570 lines
18 KiB
Go
570 lines
18 KiB
Go
package main
|
||
|
||
import (
|
||
"database/sql"
|
||
"flag"
|
||
"fmt"
|
||
"log"
|
||
"os"
|
||
"time"
|
||
|
||
_ "github.com/go-sql-driver/mysql"
|
||
)
|
||
|
||
// Default connection strings. Adjust them for your environment, or override
// them at run time with the -source and -target flags.
//
// NOTE(review): both defaults embed a user name and password directly in the
// source file. Consider rotating these credentials and supplying the DSNs
// through flags or the environment instead of committing them.
const (
	// SOURCE_DSN is the default DSN of the database rows are read from.
	// DSN format: user:password@tcp(host:port)/database?parseTime=true
	SOURCE_DSN = "goalfymax_prod:X6cQDaOLOifFBOMq@tcp(goalfyagent-aurora-mysql-staging.cb2sq6y2mg93.us-west-2.rds.amazonaws.com:3306)/goalfymax_prod?charset=utf8mb4&parseTime=True&loc=UTC"

	// TARGET_DSN is the default DSN of the database rows are written to.
	TARGET_DSN = "root:123456@tcp(localhost:3306)/test?charset=utf8mb4&parseTime=True&loc=Local"
)
|
||
|
||
// MigrationConfig bundles everything a single project migration needs: which
// project to copy, an optional replacement UID, and the two database handles.
type MigrationConfig struct {
	// ProjectID selects the rows to migrate (matched against project_id).
	ProjectID uint64

	// NewUID, when non-nil, replaces the uid of migrated rows. When nil the
	// original uid read from the source is kept.
	NewUID *uint64

	// SourceDB is read from; TargetDB is cleaned and written to.
	SourceDB, TargetDB *sql.DB
}
|
||
|
||
// StreamMessage holds one row of the m_stream_messages table.
//
// Fields are declared in the same order as the columns are selected in
// migrateStreamMessages. Columns scanned into sql.Null* types may be NULL;
// the timestamp columns are carried as strings.
type StreamMessage struct {
	ID, UID, ProjectID        uint64
	SenderID, SenderAgentType sql.NullString
	Type, MessageTimestamp    string
	AgentGenMessageID         sql.NullString
	Content, FormattedContent string
	TaskID, TaskStatus        sql.NullString
	TurnID                    sql.NullInt32
	AgentMessageID            sql.NullString
	Source                    string
	Metadata                  sql.NullString
	ReceivedAt, CreatedAt     string
	Show                      int
}
|
||
|
||
func main() {
|
||
// 命令行参数
|
||
projectID := flag.Uint64("project", 0, "Project ID to migrate (required)")
|
||
newUID := flag.Uint64("uid", 0, "New UID for migrated data (optional, use 0 to keep original)")
|
||
sourceDSN := flag.String("source", SOURCE_DSN, "Source database DSN")
|
||
targetDSN := flag.String("target", TARGET_DSN, "Target database DSN")
|
||
flag.Parse()
|
||
|
||
if *projectID == 0 {
|
||
fmt.Println("Usage: go run main.go -project <project_id> [-uid <new_uid>] [-source <dsn>] [-target <dsn>]")
|
||
fmt.Println("\nExample:")
|
||
fmt.Println(" go run main.go -project 123")
|
||
fmt.Println(" go run main.go -project 123 -uid 456")
|
||
os.Exit(1)
|
||
}
|
||
|
||
// 连接数据库
|
||
sourceDB, err := sql.Open("mysql", *sourceDSN)
|
||
if err != nil {
|
||
log.Fatalf("Failed to connect to source database: %v", err)
|
||
}
|
||
defer sourceDB.Close()
|
||
|
||
targetDB, err := sql.Open("mysql", *targetDSN)
|
||
if err != nil {
|
||
log.Fatalf("Failed to connect to target database: %v", err)
|
||
}
|
||
defer targetDB.Close()
|
||
|
||
// 测试连接
|
||
if err := sourceDB.Ping(); err != nil {
|
||
log.Fatalf("Failed to ping source database: %v", err)
|
||
}
|
||
if err := targetDB.Ping(); err != nil {
|
||
log.Fatalf("Failed to ping target database: %v", err)
|
||
}
|
||
|
||
config := &MigrationConfig{
|
||
ProjectID: *projectID,
|
||
SourceDB: sourceDB,
|
||
TargetDB: targetDB,
|
||
}
|
||
|
||
if *newUID != 0 {
|
||
config.NewUID = newUID
|
||
}
|
||
|
||
log.Printf("Starting migration for project_id=%d", *projectID)
|
||
if config.NewUID != nil {
|
||
log.Printf("Will replace UID with: %d", *config.NewUID)
|
||
}
|
||
|
||
// 执行迁移
|
||
if err := migrateProject(config); err != nil {
|
||
log.Fatalf("Migration failed: %v", err)
|
||
}
|
||
|
||
log.Println("Migration completed successfully!")
|
||
}
|
||
|
||
func migrateProject(config *MigrationConfig) error {
|
||
// 开启事务
|
||
txSource, err := config.SourceDB.Begin()
|
||
if err != nil {
|
||
return fmt.Errorf("failed to begin source transaction: %w", err)
|
||
}
|
||
defer txSource.Rollback()
|
||
|
||
txTarget, err := config.TargetDB.Begin()
|
||
if err != nil {
|
||
return fmt.Errorf("failed to begin target transaction: %w", err)
|
||
}
|
||
defer txTarget.Rollback()
|
||
|
||
// 步骤0: 清除目标库中已存在的该项目数据
|
||
log.Println("Step 0: Cleaning existing data in target database...")
|
||
if err := cleanExistingData(txTarget, config); err != nil {
|
||
return fmt.Errorf("failed to clean existing data: %w", err)
|
||
}
|
||
log.Println(" Existing data cleaned")
|
||
|
||
// 步骤1: 迁移 m_stream_messages (主表)
|
||
log.Println("Step 1: Migrating m_stream_messages...")
|
||
oldToNewStreamMsgID, err := migrateStreamMessages(txSource, txTarget, config)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to migrate stream messages: %w", err)
|
||
}
|
||
log.Printf(" Migrated %d stream messages", len(oldToNewStreamMsgID))
|
||
|
||
// 步骤2: 迁移 m_stream_contents (依赖 main_message_id)
|
||
log.Println("Step 2: Migrating m_stream_contents...")
|
||
count, err := migrateStreamContents(txSource, txTarget, config, oldToNewStreamMsgID)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to migrate stream contents: %w", err)
|
||
}
|
||
log.Printf(" Migrated %d stream contents", count)
|
||
|
||
// 步骤3: 迁移 m_context_messages (依赖 main_message_id)
|
||
log.Println("Step 3: Migrating m_context_messages...")
|
||
count, err = migrateContextMessages(txSource, txTarget, config, oldToNewStreamMsgID)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to migrate context messages: %w", err)
|
||
}
|
||
log.Printf(" Migrated %d context messages", count)
|
||
|
||
// 步骤4: 迁移 m_task_messages (依赖 main_message_id)
|
||
log.Println("Step 4: Migrating m_task_messages...")
|
||
count, err = migrateTaskMessages(txSource, txTarget, config, oldToNewStreamMsgID)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to migrate task messages: %w", err)
|
||
}
|
||
log.Printf(" Migrated %d task messages", count)
|
||
|
||
// 提交事务
|
||
if err := txTarget.Commit(); err != nil {
|
||
return fmt.Errorf("failed to commit target transaction: %w", err)
|
||
}
|
||
if err := txSource.Commit(); err != nil {
|
||
return fmt.Errorf("failed to commit source transaction: %w", err)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// cleanExistingData 清除目标库中已存在的该项目数据
|
||
// 按照依赖关系逆序删除:先删除依赖表,最后删除主表
|
||
func cleanExistingData(txTarget *sql.Tx, config *MigrationConfig) error {
|
||
// 首先获取该项目在目标库中的所有 stream_message IDs
|
||
var streamMessageIDs []uint64
|
||
query := `SELECT id FROM m_stream_messages WHERE project_id = ?`
|
||
rows, err := txTarget.Query(query, config.ProjectID)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to query existing stream messages: %w", err)
|
||
}
|
||
defer rows.Close()
|
||
|
||
for rows.Next() {
|
||
var id uint64
|
||
if err := rows.Scan(&id); err != nil {
|
||
return err
|
||
}
|
||
streamMessageIDs = append(streamMessageIDs, id)
|
||
}
|
||
if err := rows.Err(); err != nil {
|
||
return err
|
||
}
|
||
|
||
if len(streamMessageIDs) == 0 {
|
||
log.Println(" No existing data found for this project in target database")
|
||
return nil
|
||
}
|
||
|
||
log.Printf(" Found %d existing stream messages to clean", len(streamMessageIDs))
|
||
|
||
// 步骤1: 删除 m_task_messages (依赖 main_message_id)
|
||
result, err := txTarget.Exec(`DELETE FROM m_task_messages WHERE project_id = ?`, config.ProjectID)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to delete task messages: %w", err)
|
||
}
|
||
if affected, _ := result.RowsAffected(); affected > 0 {
|
||
log.Printf(" Deleted %d task messages", affected)
|
||
}
|
||
|
||
// 步骤2: 删除 m_context_messages (依赖 main_message_id)
|
||
result, err = txTarget.Exec(`DELETE FROM m_context_messages WHERE project_id = ?`, config.ProjectID)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to delete context messages: %w", err)
|
||
}
|
||
if affected, _ := result.RowsAffected(); affected > 0 {
|
||
log.Printf(" Deleted %d context messages", affected)
|
||
}
|
||
|
||
// 步骤3: 删除 m_stream_contents (依赖 main_message_id)
|
||
// 需要通过 main_message_id IN (...) 来删除
|
||
if len(streamMessageIDs) > 0 {
|
||
// 构建 IN 子句
|
||
query := `DELETE FROM m_stream_contents WHERE main_message_id IN (`
|
||
args := make([]interface{}, len(streamMessageIDs))
|
||
for i, id := range streamMessageIDs {
|
||
if i > 0 {
|
||
query += ","
|
||
}
|
||
query += "?"
|
||
args[i] = id
|
||
}
|
||
query += ")"
|
||
|
||
result, err = txTarget.Exec(query, args...)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to delete stream contents: %w", err)
|
||
}
|
||
if affected, _ := result.RowsAffected(); affected > 0 {
|
||
log.Printf(" Deleted %d stream contents", affected)
|
||
}
|
||
}
|
||
|
||
// 步骤4: 删除 m_stream_messages (主表)
|
||
result, err = txTarget.Exec(`DELETE FROM m_stream_messages WHERE project_id = ?`, config.ProjectID)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to delete stream messages: %w", err)
|
||
}
|
||
if affected, _ := result.RowsAffected(); affected > 0 {
|
||
log.Printf(" Deleted %d stream messages", affected)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// formatTimestamp converts an RFC 3339 timestamp into the
// "YYYY-MM-DD HH:MM:SS" form used for MySQL datetime values.
//
// Input may be e.g. "2025-11-07T02:37:23Z", with or without fractional
// seconds, or may already be "2025-11-07 02:37:23". A string that does not
// parse as RFC 3339 (including one already in the target form, or an empty
// string) is returned unchanged.
//
// The output carries the date and wall-clock time exactly as written in the
// input: fractional seconds and the zone offset are dropped, and no time zone
// conversion is performed.
func formatTimestamp(ts string) string {
	// When parsing, time.Parse accepts a fractional-seconds field after the
	// seconds even though the RFC3339 layout does not contain one, so a
	// separate RFC3339Nano attempt can never succeed where this one fails.
	t, err := time.Parse(time.RFC3339, ts)
	if err != nil {
		return ts
	}
	return t.Format("2006-01-02 15:04:05")
}
|
||
|
||
// migrateStreamMessages 迁移主消息表,返回旧ID到新ID的映射
|
||
func migrateStreamMessages(txSource, txTarget *sql.Tx, config *MigrationConfig) (map[uint64]uint64, error) {
|
||
query := `SELECT id, uid, project_id, sender_id, sender_agent_type, type,
|
||
message_timestamp, agent_gen_message_id, content, formatted_content,
|
||
task_id, task_status, turn_id, agent_message_id, source, metadata,
|
||
received_at, created_at, ` + "`show`" + `
|
||
FROM m_stream_messages WHERE project_id = ?`
|
||
|
||
rows, err := txSource.Query(query, config.ProjectID)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer rows.Close()
|
||
|
||
oldToNewID := make(map[uint64]uint64)
|
||
insertQuery := `INSERT INTO m_stream_messages
|
||
(uid, project_id, sender_id, sender_agent_type, type, message_timestamp,
|
||
agent_gen_message_id, content, formatted_content, task_id, task_status,
|
||
turn_id, agent_message_id, source, metadata, received_at, created_at, ` + "`show`" + `)
|
||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`
|
||
|
||
for rows.Next() {
|
||
var msg StreamMessage
|
||
err := rows.Scan(&msg.ID, &msg.UID, &msg.ProjectID, &msg.SenderID, &msg.SenderAgentType,
|
||
&msg.Type, &msg.MessageTimestamp, &msg.AgentGenMessageID, &msg.Content,
|
||
&msg.FormattedContent, &msg.TaskID, &msg.TaskStatus, &msg.TurnID,
|
||
&msg.AgentMessageID, &msg.Source, &msg.Metadata, &msg.ReceivedAt,
|
||
&msg.CreatedAt, &msg.Show)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 如果指定了新UID,替换
|
||
uid := msg.UID
|
||
if config.NewUID != nil {
|
||
uid = *config.NewUID
|
||
}
|
||
|
||
// 转换时间戳格式
|
||
messageTimestamp := formatTimestamp(msg.MessageTimestamp)
|
||
receivedAt := formatTimestamp(msg.ReceivedAt)
|
||
createdAt := formatTimestamp(msg.CreatedAt)
|
||
|
||
result, err := txTarget.Exec(insertQuery,
|
||
uid, msg.ProjectID, msg.SenderID, msg.SenderAgentType, msg.Type,
|
||
messageTimestamp, msg.AgentGenMessageID, msg.Content,
|
||
msg.FormattedContent, msg.TaskID, msg.TaskStatus, msg.TurnID,
|
||
msg.AgentMessageID, msg.Source, msg.Metadata, receivedAt,
|
||
createdAt, msg.Show)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
newID, err := result.LastInsertId()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
oldToNewID[msg.ID] = uint64(newID)
|
||
}
|
||
|
||
return oldToNewID, rows.Err()
|
||
}
|
||
|
||
// migrateStreamContents 迁移流式内容表
|
||
func migrateStreamContents(txSource, txTarget *sql.Tx, config *MigrationConfig, idMap map[uint64]uint64) (int, error) {
|
||
query := `SELECT id, main_message_id, agent_gen_message_id, type, content,
|
||
formatted_content, metadata, message_timestamp, ` + "`show`" + `, received_at, created_at
|
||
FROM m_stream_contents
|
||
WHERE main_message_id IN (
|
||
SELECT id FROM m_stream_messages WHERE project_id = ?
|
||
)`
|
||
|
||
rows, err := txSource.Query(query, config.ProjectID)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
defer rows.Close()
|
||
|
||
insertQuery := `INSERT INTO m_stream_contents
|
||
(main_message_id, agent_gen_message_id, type, content, formatted_content,
|
||
metadata, message_timestamp, ` + "`show`" + `, received_at, created_at)
|
||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`
|
||
|
||
count := 0
|
||
for rows.Next() {
|
||
var (
|
||
id uint64
|
||
mainMessageID uint64
|
||
agentGenMessageID sql.NullString
|
||
msgType string
|
||
content string
|
||
formattedContent string
|
||
metadata sql.NullString
|
||
messageTimestamp string
|
||
show int
|
||
receivedAt string
|
||
createdAt string
|
||
)
|
||
|
||
err := rows.Scan(&id, &mainMessageID, &agentGenMessageID, &msgType, &content,
|
||
&formattedContent, &metadata, &messageTimestamp, &show, &receivedAt, &createdAt)
|
||
if err != nil {
|
||
return count, err
|
||
}
|
||
|
||
// 映射新的 main_message_id
|
||
newMainMessageID, ok := idMap[mainMessageID]
|
||
if !ok {
|
||
log.Printf("Warning: main_message_id %d not found in mapping, skipping content id %d", mainMessageID, id)
|
||
continue
|
||
}
|
||
|
||
// 转换时间戳格式
|
||
messageTimestamp = formatTimestamp(messageTimestamp)
|
||
receivedAt = formatTimestamp(receivedAt)
|
||
createdAt = formatTimestamp(createdAt)
|
||
|
||
_, err = txTarget.Exec(insertQuery, newMainMessageID, agentGenMessageID, msgType,
|
||
content, formattedContent, metadata, messageTimestamp, show, receivedAt, createdAt)
|
||
if err != nil {
|
||
return count, err
|
||
}
|
||
count++
|
||
}
|
||
|
||
return count, rows.Err()
|
||
}
|
||
|
||
// migrateContextMessages 迁移上下文消息表
|
||
func migrateContextMessages(txSource, txTarget *sql.Tx, config *MigrationConfig, idMap map[uint64]uint64) (int, error) {
|
||
query := `SELECT id, main_message_id, ownership, uid, project_id, turn_id,
|
||
agent_message_id, sender_id, sender_agent_type, type, source, content,
|
||
formatted_content, metadata, message_timestamp, agent_gen_message_id,
|
||
` + "`show`" + `, received_at, created_at
|
||
FROM m_context_messages WHERE project_id = ?`
|
||
|
||
rows, err := txSource.Query(query, config.ProjectID)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
defer rows.Close()
|
||
|
||
insertQuery := `INSERT INTO m_context_messages
|
||
(main_message_id, ownership, uid, project_id, turn_id, agent_message_id,
|
||
sender_id, sender_agent_type, type, source, content, formatted_content,
|
||
metadata, message_timestamp, agent_gen_message_id, ` + "`show`" + `, received_at, created_at)
|
||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`
|
||
|
||
count := 0
|
||
for rows.Next() {
|
||
var (
|
||
id uint64
|
||
mainMessageID uint64
|
||
ownership string
|
||
uid uint64
|
||
projectID uint64
|
||
turnID sql.NullInt32
|
||
agentMessageID sql.NullString
|
||
senderID sql.NullString
|
||
senderAgentType sql.NullString
|
||
msgType string
|
||
source string
|
||
content sql.NullString
|
||
formattedContent sql.NullString
|
||
metadata sql.NullString
|
||
messageTimestamp string
|
||
agentGenMessageID sql.NullString
|
||
show int
|
||
receivedAt string
|
||
createdAt string
|
||
)
|
||
|
||
err := rows.Scan(&id, &mainMessageID, &ownership, &uid, &projectID, &turnID,
|
||
&agentMessageID, &senderID, &senderAgentType, &msgType, &source, &content,
|
||
&formattedContent, &metadata, &messageTimestamp, &agentGenMessageID,
|
||
&show, &receivedAt, &createdAt)
|
||
if err != nil {
|
||
return count, err
|
||
}
|
||
|
||
// 映射新的 main_message_id
|
||
newMainMessageID, ok := idMap[mainMessageID]
|
||
if !ok {
|
||
log.Printf("Warning: main_message_id %d not found in mapping, skipping context message id %d", mainMessageID, id)
|
||
continue
|
||
}
|
||
|
||
// 如果指定了新UID,替换
|
||
targetUID := uid
|
||
if config.NewUID != nil {
|
||
targetUID = *config.NewUID
|
||
}
|
||
|
||
// 转换时间戳格式
|
||
messageTimestamp = formatTimestamp(messageTimestamp)
|
||
receivedAt = formatTimestamp(receivedAt)
|
||
createdAt = formatTimestamp(createdAt)
|
||
|
||
_, err = txTarget.Exec(insertQuery, newMainMessageID, ownership, targetUID, projectID,
|
||
turnID, agentMessageID, senderID, senderAgentType, msgType, source, content,
|
||
formattedContent, metadata, messageTimestamp, agentGenMessageID, show,
|
||
receivedAt, createdAt)
|
||
if err != nil {
|
||
return count, err
|
||
}
|
||
count++
|
||
}
|
||
|
||
return count, rows.Err()
|
||
}
|
||
|
||
// migrateTaskMessages 迁移任务消息表
|
||
func migrateTaskMessages(txSource, txTarget *sql.Tx, config *MigrationConfig, idMap map[uint64]uint64) (int, error) {
|
||
query := `SELECT id, main_message_id, uid, project_id, sender_id, sender_agent_type,
|
||
type, message_timestamp, agent_gen_message_id, content, formatted_content,
|
||
task_status, turn_id, agent_message_id, source, metadata, ` + "`show`" + `,
|
||
received_at, created_at
|
||
FROM m_task_messages WHERE project_id = ?`
|
||
|
||
rows, err := txSource.Query(query, config.ProjectID)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
defer rows.Close()
|
||
|
||
insertQuery := `INSERT INTO m_task_messages
|
||
(main_message_id, uid, project_id, sender_id, sender_agent_type, type,
|
||
message_timestamp, agent_gen_message_id, content, formatted_content,
|
||
task_status, turn_id, agent_message_id, source, metadata, ` + "`show`" + `,
|
||
received_at, created_at)
|
||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`
|
||
|
||
count := 0
|
||
for rows.Next() {
|
||
var (
|
||
id uint64
|
||
mainMessageID uint64
|
||
uid uint64
|
||
projectID uint64
|
||
senderID sql.NullString
|
||
senderAgentType sql.NullString
|
||
msgType string
|
||
messageTimestamp string
|
||
agentGenMessageID sql.NullString
|
||
content string
|
||
formattedContent string
|
||
taskStatus sql.NullString
|
||
turnID sql.NullInt32
|
||
agentMessageID sql.NullString
|
||
source sql.NullString
|
||
metadata sql.NullString
|
||
show int
|
||
receivedAt string
|
||
createdAt string
|
||
)
|
||
|
||
err := rows.Scan(&id, &mainMessageID, &uid, &projectID, &senderID, &senderAgentType,
|
||
&msgType, &messageTimestamp, &agentGenMessageID, &content, &formattedContent,
|
||
&taskStatus, &turnID, &agentMessageID, &source, &metadata, &show,
|
||
&receivedAt, &createdAt)
|
||
if err != nil {
|
||
return count, err
|
||
}
|
||
|
||
// 映射新的 main_message_id
|
||
newMainMessageID, ok := idMap[mainMessageID]
|
||
if !ok {
|
||
log.Printf("Warning: main_message_id %d not found in mapping, skipping task message id %d", mainMessageID, id)
|
||
continue
|
||
}
|
||
|
||
// 如果指定了新UID,替换
|
||
targetUID := uid
|
||
if config.NewUID != nil {
|
||
targetUID = *config.NewUID
|
||
}
|
||
|
||
// 转换时间戳格式
|
||
messageTimestamp = formatTimestamp(messageTimestamp)
|
||
receivedAt = formatTimestamp(receivedAt)
|
||
createdAt = formatTimestamp(createdAt)
|
||
|
||
_, err = txTarget.Exec(insertQuery, newMainMessageID, targetUID, projectID, senderID,
|
||
senderAgentType, msgType, messageTimestamp, agentGenMessageID, content,
|
||
formattedContent, taskStatus, turnID, agentMessageID, source, metadata,
|
||
show, receivedAt, createdAt)
|
||
if err != nil {
|
||
return count, err
|
||
}
|
||
count++
|
||
}
|
||
|
||
return count, rows.Err()
|
||
}
|