// This program copies every message belonging to one project from a source
// MySQL database into a target MySQL database.
//
// Existing rows for the project are first removed from the target. Rows of
// m_stream_messages are then copied, and the auto-increment IDs assigned by
// the target are used to re-link the dependent tables m_stream_contents,
// m_context_messages and m_task_messages. All target writes share a single
// transaction, so a run that fails before the commit leaves the target
// unchanged.
package main

import (
	"context"
	"database/sql"
	"flag"
	"fmt"
	"log"
	"os"
	"time"

	// Registers the "mysql" driver used by sql.Open below.
	_ "github.com/go-sql-driver/mysql"
)

// Connection settings. DSN format:
//
//	user:password@tcp(host:port)/database?charset=utf8mb4&parseTime=True&loc=UTC
//
// Credentials must not be committed to source control, so the source DSN has
// no built-in default: supply it with -source or the MIGRATE_SOURCE_DSN
// environment variable.
const (
	// sourceDSNEnv is consulted when -source is not given.
	sourceDSNEnv = "MIGRATE_SOURCE_DSN"
	// targetDSNEnv is consulted when -target is not given.
	targetDSNEnv = "MIGRATE_TARGET_DSN"
	// defaultTargetDSN is a local development database, used when neither
	// -target nor MIGRATE_TARGET_DSN is set.
	defaultTargetDSN = "root:123456@tcp(localhost:3306)/test?charset=utf8mb4&parseTime=True&loc=Local"
)

// MigrationConfig holds the inputs shared by every migration step.
type MigrationConfig struct {
	ProjectID uint64
	NewUID    *uint64 // replacement UID for migrated rows; nil keeps each row's original UID
	SourceDB  *sql.DB
	TargetDB  *sql.DB
}

// uidFor returns the UID to write for a row whose source UID is original:
// the configured replacement when one is set, otherwise original unchanged.
func (c *MigrationConfig) uidFor(original uint64) uint64 {
	if c.NewUID != nil {
		return *c.NewUID
	}
	return original
}

// StreamMessage mirrors one row of the m_stream_messages table. Timestamp
// columns are held as strings and normalised by formatTimestamp before being
// written to the target.
type StreamMessage struct {
	ID                uint64
	UID               uint64
	ProjectID         uint64
	SenderID          sql.NullString
	SenderAgentType   sql.NullString
	Type              string
	MessageTimestamp  string
	AgentGenMessageID sql.NullString
	Content           string
	FormattedContent  string
	TaskID            sql.NullString
	TaskStatus        sql.NullString
	TurnID            sql.NullInt32
	AgentMessageID    sql.NullString
	Source            string
	Metadata          sql.NullString
	ReceivedAt        string
	CreatedAt         string
	Show              int
}

// firstNonEmpty returns the first non-empty string in values, or "" when all
// of them are empty.
func firstNonEmpty(values ...string) string {
	for _, v := range values {
		if v != "" {
			return v
		}
	}
	return ""
}

func main() {
	// Command-line flags.
	projectID := flag.Uint64("project", 0, "Project ID to migrate (required)")
	newUID := flag.Uint64("uid", 0, "New UID for migrated data (optional, use 0 to keep original)")
	// The DSN flags default to "" so that -h never prints credentials taken
	// from the environment; effective values are resolved after parsing.
	sourceDSN := flag.String("source", "", "Source database DSN (default $"+sourceDSNEnv+")")
	targetDSN := flag.String("target", "", "Target database DSN (default $"+targetDSNEnv+", else a local test database)")
	flag.Parse()

	if *projectID == 0 {
		fmt.Println("Usage: go run main.go -project <id> [-uid <uid>] [-source <dsn>] [-target <dsn>]")
		fmt.Println("\nThe source DSN may also be supplied via the " + sourceDSNEnv + " environment variable.")
		fmt.Println("\nExample:")
		fmt.Println(" go run main.go -project 123")
		fmt.Println(" go run main.go -project 123 -uid 456")
		os.Exit(1)
	}

	srcDSN := firstNonEmpty(*sourceDSN, os.Getenv(sourceDSNEnv))
	if srcDSN == "" {
		log.Fatalf("Source database DSN is required: pass -source or set %s", sourceDSNEnv)
	}
	dstDSN := firstNonEmpty(*targetDSN, os.Getenv(targetDSNEnv), defaultTargetDSN)

	// sql.Open only validates its arguments; the Ping calls below are what
	// actually establish connections.
	sourceDB, err := sql.Open("mysql", srcDSN)
	if err != nil {
		log.Fatalf("Failed to connect to source database: %v", err)
	}
	defer sourceDB.Close()

	targetDB, err := sql.Open("mysql", dstDSN)
	if err != nil {
		log.Fatalf("Failed to connect to target database: %v", err)
	}
	defer targetDB.Close()

	if err := sourceDB.Ping(); err != nil {
		log.Fatalf("Failed to ping source database: %v", err)
	}
	if err := targetDB.Ping(); err != nil {
		log.Fatalf("Failed to ping target database: %v", err)
	}

	config := &MigrationConfig{
		ProjectID: *projectID,
		SourceDB:  sourceDB,
		TargetDB:  targetDB,
	}
	if *newUID != 0 {
		config.NewUID = newUID
	}

	log.Printf("Starting migration for project_id=%d", *projectID)
	if config.NewUID != nil {
		log.Printf("Will replace UID with: %d", *config.NewUID)
	}

	if err := migrateProject(config); err != nil {
		log.Fatalf("Migration failed: %v", err)
	}

	log.Println("Migration completed successfully!")
}

// migrateProject copies all data for config.ProjectID from the source to the
// target database. Target writes share one transaction that is committed only
// after every step succeeds. On any earlier error the deferred Rollback undoes
// them; after a successful Commit, Rollback returns sql.ErrTxDone, which is
// deliberately ignored.
func migrateProject(config *MigrationConfig) error {
	ctx := context.Background()

	// The source is only read. A read-only transaction makes the server reject
	// any accidental write, and keeps every SELECT below in one transaction
	// (a single consistent snapshot under InnoDB's default REPEATABLE READ).
	txSource, err := config.SourceDB.BeginTx(ctx, &sql.TxOptions{ReadOnly: true})
	if err != nil {
		return fmt.Errorf("failed to begin source transaction: %w", err)
	}
	defer txSource.Rollback()

	txTarget, err := config.TargetDB.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("failed to begin target transaction: %w", err)
	}
	defer txTarget.Rollback()

	// Step 0: remove any data the target already holds for this project.
	log.Println("Step 0: Cleaning existing data in target database...")
	if err := cleanExistingData(txTarget, config); err != nil {
		return fmt.Errorf("failed to clean existing data: %w", err)
	}
	log.Println(" Existing data cleaned")

	// Step 1: copy m_stream_messages (the parent table) and record new IDs.
	log.Println("Step 1: Migrating m_stream_messages...")
	oldToNewStreamMsgID, err := migrateStreamMessages(txSource, txTarget, config)
	if err != nil {
		return fmt.Errorf("failed to migrate stream messages: %w", err)
	}
	log.Printf(" Migrated %d stream messages", len(oldToNewStreamMsgID))

	// Steps 2-4 re-link child rows through main_message_id.
	log.Println("Step 2: Migrating m_stream_contents...")
	count, err := migrateStreamContents(txSource, txTarget, config, oldToNewStreamMsgID)
	if err != nil {
		return fmt.Errorf("failed to migrate stream contents: %w", err)
	}
	log.Printf(" Migrated %d stream contents", count)

	log.Println("Step 3: Migrating m_context_messages...")
	count, err = migrateContextMessages(txSource, txTarget, config, oldToNewStreamMsgID)
	if err != nil {
		return fmt.Errorf("failed to migrate context messages: %w", err)
	}
	log.Printf(" Migrated %d context messages", count)

	log.Println("Step 4: Migrating m_task_messages...")
	count, err = migrateTaskMessages(txSource, txTarget, config, oldToNewStreamMsgID)
	if err != nil {
		return fmt.Errorf("failed to migrate task messages: %w", err)
	}
	log.Printf(" Migrated %d task messages", count)

	if err := txTarget.Commit(); err != nil {
		return fmt.Errorf("failed to commit target transaction: %w", err)
	}
	if err := txSource.Commit(); err != nil {
		return fmt.Errorf("failed to commit source transaction: %w", err)
	}
	return nil
}

// cleanExistingData deletes the target's existing rows for config.ProjectID.
// Child tables are deleted before m_stream_messages because they are linked
// to it through main_message_id. If the target holds no stream messages for
// the project, nothing is deleted.
func cleanExistingData(txTarget *sql.Tx, config *MigrationConfig) error {
	var existing int64
	err := txTarget.QueryRow(
		`SELECT COUNT(*) FROM m_stream_messages WHERE project_id = ?`,
		config.ProjectID,
	).Scan(&existing)
	if err != nil {
		return fmt.Errorf("failed to query existing stream messages: %w", err)
	}
	if existing == 0 {
		log.Println(" No existing data found for this project in target database")
		return nil
	}
	log.Printf(" Found %d existing stream messages to clean", existing)

	steps := []struct {
		label string
		query string
	}{
		{"task messages", `DELETE FROM m_task_messages WHERE project_id = ?`},
		{"context messages", `DELETE FROM m_context_messages WHERE project_id = ?`},
		// m_stream_contents is selected here only through main_message_id, so
		// join to its parent rows. A join avoids binding one placeholder per
		// parent ID, which fails beyond MySQL's 65,535-placeholder limit.
		{"stream contents", `DELETE c FROM m_stream_contents AS c
			JOIN m_stream_messages AS m ON m.id = c.main_message_id
			WHERE m.project_id = ?`},
		{"stream messages", `DELETE FROM m_stream_messages WHERE project_id = ?`},
	}
	for _, step := range steps {
		result, err := txTarget.Exec(step.query, config.ProjectID)
		if err != nil {
			return fmt.Errorf("failed to delete %s: %w", step.label, err)
		}
		// RowsAffected only feeds the log line, so its error is ignored.
		if affected, _ := result.RowsAffected(); affected > 0 {
			log.Printf(" Deleted %d %s", affected, step.label)
		}
	}
	return nil
}

// mysqlDateTime is the MySQL DATETIME literal layout. The optional ".999999"
// keeps up to microsecond precision (MySQL's maximum) and is omitted
// entirely for whole-second values.
const mysqlDateTime = "2006-01-02 15:04:05.999999"

// formatTimestamp converts an RFC 3339 timestamp such as
// "2025-11-07T02:37:23Z" into MySQL DATETIME form ("2025-11-07 02:37:23").
// Inputs that do not parse as RFC 3339 (including values already in MySQL
// form) are returned unchanged.
//
// The parsed time is formatted in its own offset rather than converted, so the
// wall-clock value read from the source is written to the target unchanged.
func formatTimestamp(ts string) string {
	// RFC3339Nano parsing accepts timestamps with or without fractional seconds.
	t, err := time.Parse(time.RFC3339Nano, ts)
	if err != nil {
		return ts
	}
	return t.Format(mysqlDateTime)
}

// migrateStreamMessages copies the project's m_stream_messages rows and
// returns a map from each source row ID to the ID assigned by the target.
func migrateStreamMessages(txSource, txTarget *sql.Tx, config *MigrationConfig) (map[uint64]uint64, error) {
	const selectQuery = `SELECT id, uid, project_id, sender_id, sender_agent_type, type,
		message_timestamp, agent_gen_message_id, content, formatted_content, task_id,
		task_status, turn_id, agent_message_id, source, metadata, received_at, created_at, ` + "`show`" + `
		FROM m_stream_messages WHERE project_id = ?`
	const insertQuery = `INSERT INTO m_stream_messages (uid, project_id, sender_id,
		sender_agent_type, type, message_timestamp, agent_gen_message_id, content,
		formatted_content, task_id, task_status, turn_id, agent_message_id, source,
		metadata, received_at, created_at, ` + "`show`" + `)
		VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`

	// Prepare once instead of letting every Exec prepare and close its own
	// server-side statement.
	insert, err := txTarget.Prepare(insertQuery)
	if err != nil {
		return nil, fmt.Errorf("preparing insert: %w", err)
	}
	defer insert.Close()

	rows, err := txSource.Query(selectQuery, config.ProjectID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	oldToNewID := make(map[uint64]uint64)
	for rows.Next() {
		var msg StreamMessage
		if err := rows.Scan(&msg.ID, &msg.UID, &msg.ProjectID, &msg.SenderID, &msg.SenderAgentType,
			&msg.Type, &msg.MessageTimestamp, &msg.AgentGenMessageID, &msg.Content,
			&msg.FormattedContent, &msg.TaskID, &msg.TaskStatus, &msg.TurnID, &msg.AgentMessageID,
			&msg.Source, &msg.Metadata, &msg.ReceivedAt, &msg.CreatedAt, &msg.Show); err != nil {
			return nil, err
		}

		result, err := insert.Exec(config.uidFor(msg.UID), msg.ProjectID, msg.SenderID,
			msg.SenderAgentType, msg.Type, formatTimestamp(msg.MessageTimestamp),
			msg.AgentGenMessageID, msg.Content, msg.FormattedContent, msg.TaskID, msg.TaskStatus,
			msg.TurnID, msg.AgentMessageID, msg.Source, msg.Metadata,
			formatTimestamp(msg.ReceivedAt), formatTimestamp(msg.CreatedAt), msg.Show)
		if err != nil {
			return nil, fmt.Errorf("inserting stream message %d: %w", msg.ID, err)
		}
		newID, err := result.LastInsertId()
		if err != nil {
			return nil, err
		}
		oldToNewID[msg.ID] = uint64(newID)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return oldToNewID, nil
}

// migrateStreamContents copies m_stream_contents rows whose parent stream
// message belongs to the project, re-pointing main_message_id through idMap.
// It returns the number of rows inserted.
func migrateStreamContents(txSource, txTarget *sql.Tx, config *MigrationConfig, idMap map[uint64]uint64) (int, error) {
	const selectQuery = `SELECT id, main_message_id, agent_gen_message_id, type, content,
		formatted_content, metadata, message_timestamp, ` + "`show`" + `, received_at, created_at
		FROM m_stream_contents
		WHERE main_message_id IN (SELECT id FROM m_stream_messages WHERE project_id = ?)`
	const insertQuery = `INSERT INTO m_stream_contents (main_message_id, agent_gen_message_id,
		type, content, formatted_content, metadata, message_timestamp, ` + "`show`" + `,
		received_at, created_at)
		VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`

	insert, err := txTarget.Prepare(insertQuery)
	if err != nil {
		return 0, fmt.Errorf("preparing insert: %w", err)
	}
	defer insert.Close()

	rows, err := txSource.Query(selectQuery, config.ProjectID)
	if err != nil {
		return 0, err
	}
	defer rows.Close()

	count := 0
	for rows.Next() {
		var (
			id                uint64
			mainMessageID     uint64
			agentGenMessageID sql.NullString
			msgType           string
			content           string
			formattedContent  string
			metadata          sql.NullString
			messageTimestamp  string
			show              int
			receivedAt        string
			createdAt         string
		)
		if err := rows.Scan(&id, &mainMessageID, &agentGenMessageID, &msgType, &content,
			&formattedContent, &metadata, &messageTimestamp, &show, &receivedAt, &createdAt); err != nil {
			return count, err
		}

		newMainMessageID, ok := idMap[mainMessageID]
		if !ok {
			log.Printf("Warning: main_message_id %d not found in mapping, skipping content id %d", mainMessageID, id)
			continue
		}

		if _, err := insert.Exec(newMainMessageID, agentGenMessageID, msgType, content,
			formattedContent, metadata, formatTimestamp(messageTimestamp), show,
			formatTimestamp(receivedAt), formatTimestamp(createdAt)); err != nil {
			return count, fmt.Errorf("inserting stream content %d: %w", id, err)
		}
		count++
	}
	return count, rows.Err()
}

// migrateContextMessages copies the project's m_context_messages rows,
// re-pointing main_message_id through idMap and applying the UID override.
// Rows whose parent is missing from idMap are skipped with a warning. It
// returns the number of rows inserted.
func migrateContextMessages(txSource, txTarget *sql.Tx, config *MigrationConfig, idMap map[uint64]uint64) (int, error) {
	const selectQuery = `SELECT id, main_message_id, ownership, uid, project_id, turn_id,
		agent_message_id, sender_id, sender_agent_type, type, source, content,
		formatted_content, metadata, message_timestamp, agent_gen_message_id, ` + "`show`" + `,
		received_at, created_at
		FROM m_context_messages WHERE project_id = ?`
	const insertQuery = `INSERT INTO m_context_messages (main_message_id, ownership, uid,
		project_id, turn_id, agent_message_id, sender_id, sender_agent_type, type, source,
		content, formatted_content, metadata, message_timestamp, agent_gen_message_id, ` + "`show`" + `,
		received_at, created_at)
		VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`

	insert, err := txTarget.Prepare(insertQuery)
	if err != nil {
		return 0, fmt.Errorf("preparing insert: %w", err)
	}
	defer insert.Close()

	rows, err := txSource.Query(selectQuery, config.ProjectID)
	if err != nil {
		return 0, err
	}
	defer rows.Close()

	count := 0
	for rows.Next() {
		var (
			id                uint64
			mainMessageID     uint64
			ownership         string
			uid               uint64
			projectID         uint64
			turnID            sql.NullInt32
			agentMessageID    sql.NullString
			senderID          sql.NullString
			senderAgentType   sql.NullString
			msgType           string
			source            string
			content           sql.NullString
			formattedContent  sql.NullString
			metadata          sql.NullString
			messageTimestamp  string
			agentGenMessageID sql.NullString
			show              int
			receivedAt        string
			createdAt         string
		)
		if err := rows.Scan(&id, &mainMessageID, &ownership, &uid, &projectID, &turnID,
			&agentMessageID, &senderID, &senderAgentType, &msgType, &source, &content,
			&formattedContent, &metadata, &messageTimestamp, &agentGenMessageID, &show,
			&receivedAt, &createdAt); err != nil {
			return count, err
		}

		newMainMessageID, ok := idMap[mainMessageID]
		if !ok {
			log.Printf("Warning: main_message_id %d not found in mapping, skipping context message id %d", mainMessageID, id)
			continue
		}

		if _, err := insert.Exec(newMainMessageID, ownership, config.uidFor(uid), projectID,
			turnID, agentMessageID, senderID, senderAgentType, msgType, source, content,
			formattedContent, metadata, formatTimestamp(messageTimestamp), agentGenMessageID,
			show, formatTimestamp(receivedAt), formatTimestamp(createdAt)); err != nil {
			return count, fmt.Errorf("inserting context message %d: %w", id, err)
		}
		count++
	}
	return count, rows.Err()
}

// migrateTaskMessages copies the project's m_task_messages rows, re-pointing
// main_message_id through idMap and applying the UID override. Rows whose
// parent is missing from idMap are skipped with a warning. It returns the
// number of rows inserted.
func migrateTaskMessages(txSource, txTarget *sql.Tx, config *MigrationConfig, idMap map[uint64]uint64) (int, error) {
	const selectQuery = `SELECT id, main_message_id, uid, project_id, sender_id,
		sender_agent_type, type, message_timestamp, agent_gen_message_id, content,
		formatted_content, task_status, turn_id, agent_message_id, source, metadata, ` + "`show`" + `,
		received_at, created_at
		FROM m_task_messages WHERE project_id = ?`
	const insertQuery = `INSERT INTO m_task_messages (main_message_id, uid, project_id,
		sender_id, sender_agent_type, type, message_timestamp, agent_gen_message_id, content,
		formatted_content, task_status, turn_id, agent_message_id, source, metadata, ` + "`show`" + `,
		received_at, created_at)
		VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`

	insert, err := txTarget.Prepare(insertQuery)
	if err != nil {
		return 0, fmt.Errorf("preparing insert: %w", err)
	}
	defer insert.Close()

	rows, err := txSource.Query(selectQuery, config.ProjectID)
	if err != nil {
		return 0, err
	}
	defer rows.Close()

	count := 0
	for rows.Next() {
		var (
			id                uint64
			mainMessageID     uint64
			uid               uint64
			projectID         uint64
			senderID          sql.NullString
			senderAgentType   sql.NullString
			msgType           string
			messageTimestamp  string
			agentGenMessageID sql.NullString
			content           string
			formattedContent  string
			taskStatus        sql.NullString
			turnID            sql.NullInt32
			agentMessageID    sql.NullString
			source            sql.NullString
			metadata          sql.NullString
			show              int
			receivedAt        string
			createdAt         string
		)
		if err := rows.Scan(&id, &mainMessageID, &uid, &projectID, &senderID, &senderAgentType,
			&msgType, &messageTimestamp, &agentGenMessageID, &content, &formattedContent,
			&taskStatus, &turnID, &agentMessageID, &source, &metadata, &show, &receivedAt,
			&createdAt); err != nil {
			return count, err
		}

		newMainMessageID, ok := idMap[mainMessageID]
		if !ok {
			log.Printf("Warning: main_message_id %d not found in mapping, skipping task message id %d", mainMessageID, id)
			continue
		}

		if _, err := insert.Exec(newMainMessageID, config.uidFor(uid), projectID, senderID,
			senderAgentType, msgType, formatTimestamp(messageTimestamp), agentGenMessageID,
			content, formattedContent, taskStatus, turnID, agentMessageID, source, metadata,
			show, formatTimestamp(receivedAt), formatTimestamp(createdAt)); err != nil {
			return count, fmt.Errorf("inserting task message %d: %w", id, err)
		}
		count++
	}
	return count, rows.Err()
}