Files
playground/cmd/message-migrator/main.go
XuanLee-HEALER df39693dff commit
2025-11-10 15:30:21 +08:00

570 lines
18 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package main
import (
"database/sql"
"flag"
"fmt"
"log"
"os"
"time"
_ "github.com/go-sql-driver/mysql"
)
// Configuration constants - adjust these to match your environment.
// They are the default values of the -source and -target flags defined in main.
// NOTE(review): both DSNs embed credentials in source control; consider passing
// them via the -source / -target flags instead and rotating these values.
const (
// Source database DSN. Format: user:password@tcp(host:port)/database?parseTime=true
SOURCE_DSN = "goalfymax_prod:X6cQDaOLOifFBOMq@tcp(goalfyagent-aurora-mysql-staging.cb2sq6y2mg93.us-west-2.rds.amazonaws.com:3306)/goalfymax_prod?charset=utf8mb4&parseTime=True&loc=UTC"
// Target database DSN.
TARGET_DSN = "root:123456@tcp(localhost:3306)/test?charset=utf8mb4&parseTime=True&loc=Local"
)
// MigrationConfig carries the parameters of a single project migration.
type MigrationConfig struct {
// ProjectID selects the rows to migrate; queries match it against project_id.
ProjectID uint64
NewUID *uint64 // if nil, the original UID of each row is kept
// SourceDB is the database the rows are read from.
SourceDB *sql.DB
// TargetDB is the database the rows are written to.
TargetDB *sql.DB
}
// StreamMessage corresponds to one row of the m_stream_messages table.
// Fields typed sql.NullString / sql.NullInt32 are scanned from columns that may
// be NULL; the remaining fields are scanned as plain values.
type StreamMessage struct {
// ID is the row id in the source database; it is not written to the target.
ID uint64
UID uint64
ProjectID uint64
SenderID sql.NullString
SenderAgentType sql.NullString
Type string
// MessageTimestamp is read as a string and normalized with formatTimestamp before insert.
MessageTimestamp string
AgentGenMessageID sql.NullString
Content string
FormattedContent string
TaskID sql.NullString
TaskStatus sql.NullString
TurnID sql.NullInt32
AgentMessageID sql.NullString
Source string
Metadata sql.NullString
// ReceivedAt and CreatedAt are read as strings and normalized with formatTimestamp before insert.
ReceivedAt string
CreatedAt string
// Show maps to the `show` column, which is backtick-quoted in the queries.
Show int
}
// main is the command-line entry point. It parses the flags, opens and pings
// the source and target databases, and migrates one project between them.
// It exits with status 1 when -project is missing or any step fails.
func main() {
	// Command-line flags.
	var (
		projectFlag = flag.Uint64("project", 0, "Project ID to migrate (required)")
		uidFlag     = flag.Uint64("uid", 0, "New UID for migrated data (optional, use 0 to keep original)")
		srcDSN      = flag.String("source", SOURCE_DSN, "Source database DSN")
		dstDSN      = flag.String("target", TARGET_DSN, "Target database DSN")
	)
	flag.Parse()

	// A project id of 0 is treated as "flag not supplied".
	if *projectFlag == 0 {
		usage := []string{
			"Usage: go run main.go -project <project_id> [-uid <new_uid>] [-source <dsn>] [-target <dsn>]",
			"\nExample:",
			" go run main.go -project 123",
			" go run main.go -project 123 -uid 456",
		}
		for _, line := range usage {
			fmt.Println(line)
		}
		os.Exit(1)
	}

	// Open both database handles.
	src, openErr := sql.Open("mysql", *srcDSN)
	if openErr != nil {
		log.Fatalf("Failed to connect to source database: %v", openErr)
	}
	defer src.Close()

	dst, openErr := sql.Open("mysql", *dstDSN)
	if openErr != nil {
		log.Fatalf("Failed to connect to target database: %v", openErr)
	}
	defer dst.Close()

	// sql.Open does not necessarily connect; Ping verifies each connection.
	if pingErr := src.Ping(); pingErr != nil {
		log.Fatalf("Failed to ping source database: %v", pingErr)
	}
	if pingErr := dst.Ping(); pingErr != nil {
		log.Fatalf("Failed to ping target database: %v", pingErr)
	}

	cfg := &MigrationConfig{
		ProjectID: *projectFlag,
		SourceDB:  src,
		TargetDB:  dst,
	}
	// A uid of 0 keeps the original UIDs, so NewUID stays nil in that case.
	if *uidFlag != 0 {
		cfg.NewUID = uidFlag
	}

	log.Printf("Starting migration for project_id=%d", *projectFlag)
	if cfg.NewUID != nil {
		log.Printf("Will replace UID with: %d", *cfg.NewUID)
	}

	// Run the migration.
	if runErr := migrateProject(cfg); runErr != nil {
		log.Fatalf("Migration failed: %v", runErr)
	}
	log.Println("Migration completed successfully!")
}
// migrateProject copies all message data of config.ProjectID from the source
// database to the target database.
//
// One transaction is opened on each database. Inside the target transaction the
// project's existing rows are deleted first, then m_stream_messages is migrated
// (producing an old-id -> new-id map), followed by the three tables that
// reference it through main_message_id: m_stream_contents, m_context_messages
// and m_task_messages. Any error before the target commit leaves the target
// unchanged, because the deferred Rollback discards the transaction.
//
// The source transaction is only used for reads. Once the target commit has
// succeeded the migrated data is durable, so a failure to end the source
// transaction is logged as a warning instead of being returned as a failure.
func migrateProject(config *MigrationConfig) error {
	// Begin both transactions.
	txSource, err := config.SourceDB.Begin()
	if err != nil {
		return fmt.Errorf("failed to begin source transaction: %w", err)
	}
	// Rollback after a successful Commit is a no-op that returns sql.ErrTxDone.
	defer txSource.Rollback()

	txTarget, err := config.TargetDB.Begin()
	if err != nil {
		return fmt.Errorf("failed to begin target transaction: %w", err)
	}
	defer txTarget.Rollback()

	// Step 0: remove rows that already exist for this project in the target.
	log.Println("Step 0: Cleaning existing data in target database...")
	if err := cleanExistingData(txTarget, config); err != nil {
		return fmt.Errorf("failed to clean existing data: %w", err)
	}
	log.Println(" Existing data cleaned")

	// Step 1: migrate m_stream_messages (main table).
	log.Println("Step 1: Migrating m_stream_messages...")
	oldToNewStreamMsgID, err := migrateStreamMessages(txSource, txTarget, config)
	if err != nil {
		return fmt.Errorf("failed to migrate stream messages: %w", err)
	}
	log.Printf(" Migrated %d stream messages", len(oldToNewStreamMsgID))

	// Step 2: migrate m_stream_contents (depends on main_message_id).
	log.Println("Step 2: Migrating m_stream_contents...")
	count, err := migrateStreamContents(txSource, txTarget, config, oldToNewStreamMsgID)
	if err != nil {
		return fmt.Errorf("failed to migrate stream contents: %w", err)
	}
	log.Printf(" Migrated %d stream contents", count)

	// Step 3: migrate m_context_messages (depends on main_message_id).
	log.Println("Step 3: Migrating m_context_messages...")
	count, err = migrateContextMessages(txSource, txTarget, config, oldToNewStreamMsgID)
	if err != nil {
		return fmt.Errorf("failed to migrate context messages: %w", err)
	}
	log.Printf(" Migrated %d context messages", count)

	// Step 4: migrate m_task_messages (depends on main_message_id).
	log.Println("Step 4: Migrating m_task_messages...")
	count, err = migrateTaskMessages(txSource, txTarget, config, oldToNewStreamMsgID)
	if err != nil {
		return fmt.Errorf("failed to migrate task messages: %w", err)
	}
	log.Printf(" Migrated %d task messages", count)

	// Commit the writes.
	if err := txTarget.Commit(); err != nil {
		return fmt.Errorf("failed to commit target transaction: %w", err)
	}

	// The target data is already committed at this point; the source
	// transaction performed reads only, so failing to end it cleanly must not
	// turn a completed migration into a reported failure.
	if err := txSource.Commit(); err != nil {
		log.Printf("Warning: failed to commit source transaction (target data is already committed): %v", err)
	}
	return nil
}
// cleanExistingData deletes the rows that already exist in the target database
// for config.ProjectID, so the project can be migrated again.
//
// Rows are removed in reverse dependency order: m_task_messages,
// m_context_messages and m_stream_contents first, the main table
// m_stream_messages last. When the target holds no m_stream_messages rows for
// the project, the function logs that fact and deletes nothing.
//
// All statements run inside txTarget; nothing is committed here.
func cleanExistingData(txTarget *sql.Tx, config *MigrationConfig) error {
	// m_stream_contents is deleted by main_message_id, so collect the ids of
	// the project's stream messages first.
	rows, err := txTarget.Query(`SELECT id FROM m_stream_messages WHERE project_id = ?`, config.ProjectID)
	if err != nil {
		return fmt.Errorf("failed to query existing stream messages: %w", err)
	}
	defer rows.Close()

	var streamMessageIDs []uint64
	for rows.Next() {
		var id uint64
		if err := rows.Scan(&id); err != nil {
			return fmt.Errorf("failed to scan existing stream message id: %w", err)
		}
		streamMessageIDs = append(streamMessageIDs, id)
	}
	if err := rows.Err(); err != nil {
		return fmt.Errorf("failed to read existing stream messages: %w", err)
	}

	if len(streamMessageIDs) == 0 {
		log.Println(" No existing data found for this project in target database")
		return nil
	}
	log.Printf(" Found %d existing stream messages to clean", len(streamMessageIDs))

	// Step 1: delete m_task_messages (depends on main_message_id).
	result, err := txTarget.Exec(`DELETE FROM m_task_messages WHERE project_id = ?`, config.ProjectID)
	if err != nil {
		return fmt.Errorf("failed to delete task messages: %w", err)
	}
	// RowsAffected is only used for logging, so its error is deliberately ignored.
	if affected, _ := result.RowsAffected(); affected > 0 {
		log.Printf(" Deleted %d task messages", affected)
	}

	// Step 2: delete m_context_messages (depends on main_message_id).
	result, err = txTarget.Exec(`DELETE FROM m_context_messages WHERE project_id = ?`, config.ProjectID)
	if err != nil {
		return fmt.Errorf("failed to delete context messages: %w", err)
	}
	if affected, _ := result.RowsAffected(); affected > 0 {
		log.Printf(" Deleted %d context messages", affected)
	}

	// Step 3: delete m_stream_contents via main_message_id IN (...).
	// The ids are sent in bounded batches: one statement with a placeholder per
	// id would exceed MySQL's limit of 65535 placeholders per prepared
	// statement on large projects.
	const batchSize = 1000
	var deletedContents int64
	for start := 0; start < len(streamMessageIDs); start += batchSize {
		end := start + batchSize
		if end > len(streamMessageIDs) {
			end = len(streamMessageIDs)
		}
		batch := streamMessageIDs[start:end]

		// Build "?,?,...,?" and the matching argument list for this batch.
		placeholders := make([]byte, 0, 2*len(batch))
		args := make([]interface{}, 0, len(batch))
		for i, id := range batch {
			if i > 0 {
				placeholders = append(placeholders, ',')
			}
			placeholders = append(placeholders, '?')
			args = append(args, id)
		}

		result, err = txTarget.Exec(
			`DELETE FROM m_stream_contents WHERE main_message_id IN (`+string(placeholders)+`)`,
			args...)
		if err != nil {
			return fmt.Errorf("failed to delete stream contents: %w", err)
		}
		if affected, _ := result.RowsAffected(); affected > 0 {
			deletedContents += affected
		}
	}
	if deletedContents > 0 {
		log.Printf(" Deleted %d stream contents", deletedContents)
	}

	// Step 4: delete m_stream_messages (main table).
	result, err = txTarget.Exec(`DELETE FROM m_stream_messages WHERE project_id = ?`, config.ProjectID)
	if err != nil {
		return fmt.Errorf("failed to delete stream messages: %w", err)
	}
	if affected, _ := result.RowsAffected(); affected > 0 {
		log.Printf(" Deleted %d stream messages", affected)
	}
	return nil
}
// formatTimestamp converts a timestamp string into MySQL datetime literal form.
//
// RFC 3339 input such as "2025-11-07T02:37:23Z" becomes "2025-11-07 02:37:23".
// A fractional-seconds part is kept (up to microseconds, trailing zeros
// dropped), so "2025-11-07T02:37:23.123456Z" becomes
// "2025-11-07 02:37:23.123456". Input that does not parse as RFC 3339 -
// including values already in "2025-11-07 02:37:23" form - is returned
// unchanged.
//
// The wall-clock fields are emitted exactly as parsed; no time-zone conversion
// is performed.
func formatTimestamp(ts string) string {
	// time.Parse accepts an optional fractional-seconds field after the seconds
	// even when the layout has none, so RFC3339 also covers RFC3339Nano input.
	t, err := time.Parse(time.RFC3339, ts)
	if err != nil {
		// Already in the desired format, or not a recognized timestamp.
		return ts
	}
	// ".999999" prints up to six fractional digits and omits the fraction
	// entirely when it is zero.
	return t.Format("2006-01-02 15:04:05.999999")
}
// migrateStreamMessages copies every m_stream_messages row of config.ProjectID
// from the source transaction into the target transaction.
//
// Rows are inserted without their source id, so the target assigns new ids.
// The returned map goes from each source id to the id reported by
// LastInsertId for the inserted row. When config.NewUID is non-nil it replaces
// the uid of every row. message_timestamp, received_at and created_at are
// passed through formatTimestamp before insert.
//
// On error the returned map is nil.
func migrateStreamMessages(txSource, txTarget *sql.Tx, config *MigrationConfig) (map[uint64]uint64, error) {
	// `show` is backtick-quoted in both statements.
	selectQuery := `SELECT id, uid, project_id, sender_id, sender_agent_type, type,
		message_timestamp, agent_gen_message_id, content, formatted_content,
		task_id, task_status, turn_id, agent_message_id, source, metadata,
		received_at, created_at, ` + "`show`" + `
		FROM m_stream_messages WHERE project_id = ?`
	insertQuery := `INSERT INTO m_stream_messages
		(uid, project_id, sender_id, sender_agent_type, type, message_timestamp,
		agent_gen_message_id, content, formatted_content, task_id, task_status,
		turn_id, agent_message_id, source, metadata, received_at, created_at, ` + "`show`" + `)
		VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`

	rows, err := txSource.Query(selectQuery, config.ProjectID)
	if err != nil {
		return nil, fmt.Errorf("failed to query source stream messages: %w", err)
	}
	defer rows.Close()

	// Prepare the INSERT once instead of re-sending the statement for each row.
	insertStmt, err := txTarget.Prepare(insertQuery)
	if err != nil {
		return nil, fmt.Errorf("failed to prepare stream message insert: %w", err)
	}
	defer insertStmt.Close()

	oldToNewID := make(map[uint64]uint64)
	for rows.Next() {
		var msg StreamMessage
		if err := rows.Scan(&msg.ID, &msg.UID, &msg.ProjectID, &msg.SenderID, &msg.SenderAgentType,
			&msg.Type, &msg.MessageTimestamp, &msg.AgentGenMessageID, &msg.Content,
			&msg.FormattedContent, &msg.TaskID, &msg.TaskStatus, &msg.TurnID,
			&msg.AgentMessageID, &msg.Source, &msg.Metadata, &msg.ReceivedAt,
			&msg.CreatedAt, &msg.Show); err != nil {
			return nil, fmt.Errorf("failed to scan source stream message: %w", err)
		}

		// Replace the UID when a new one was requested.
		uid := msg.UID
		if config.NewUID != nil {
			uid = *config.NewUID
		}

		result, err := insertStmt.Exec(
			uid, msg.ProjectID, msg.SenderID, msg.SenderAgentType, msg.Type,
			formatTimestamp(msg.MessageTimestamp), msg.AgentGenMessageID, msg.Content,
			msg.FormattedContent, msg.TaskID, msg.TaskStatus, msg.TurnID,
			msg.AgentMessageID, msg.Source, msg.Metadata, formatTimestamp(msg.ReceivedAt),
			formatTimestamp(msg.CreatedAt), msg.Show)
		if err != nil {
			return nil, fmt.Errorf("failed to insert stream message (source id %d): %w", msg.ID, err)
		}

		newID, err := result.LastInsertId()
		if err != nil {
			return nil, fmt.Errorf("failed to read new id for stream message (source id %d): %w", msg.ID, err)
		}
		oldToNewID[msg.ID] = uint64(newID)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to read source stream messages: %w", err)
	}
	return oldToNewID, nil
}
// migrateStreamContents copies the m_stream_contents rows that belong to the
// stream messages of config.ProjectID from the source transaction into the
// target transaction.
//
// idMap translates each row's source main_message_id to the id of the migrated
// stream message; rows whose main_message_id is missing from idMap are skipped
// with a warning. message_timestamp, received_at and created_at are passed
// through formatTimestamp before insert. config.NewUID is not used here: the
// statements do not touch a uid column.
//
// It returns the number of rows inserted; on error the count covers the rows
// inserted before the failure.
func migrateStreamContents(txSource, txTarget *sql.Tx, config *MigrationConfig, idMap map[uint64]uint64) (int, error) {
	selectQuery := `SELECT id, main_message_id, agent_gen_message_id, type, content,
		formatted_content, metadata, message_timestamp, ` + "`show`" + `, received_at, created_at
		FROM m_stream_contents
		WHERE main_message_id IN (
			SELECT id FROM m_stream_messages WHERE project_id = ?
		)`
	insertQuery := `INSERT INTO m_stream_contents
		(main_message_id, agent_gen_message_id, type, content, formatted_content,
		metadata, message_timestamp, ` + "`show`" + `, received_at, created_at)
		VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`

	rows, err := txSource.Query(selectQuery, config.ProjectID)
	if err != nil {
		return 0, fmt.Errorf("failed to query source stream contents: %w", err)
	}
	defer rows.Close()

	// Prepare the INSERT once instead of re-sending the statement for each row.
	insertStmt, err := txTarget.Prepare(insertQuery)
	if err != nil {
		return 0, fmt.Errorf("failed to prepare stream content insert: %w", err)
	}
	defer insertStmt.Close()

	count := 0
	for rows.Next() {
		var (
			id                uint64
			mainMessageID     uint64
			agentGenMessageID sql.NullString
			msgType           string
			content           string
			formattedContent  string
			metadata          sql.NullString
			messageTimestamp  string
			show              int
			receivedAt        string
			createdAt         string
		)
		if err := rows.Scan(&id, &mainMessageID, &agentGenMessageID, &msgType, &content,
			&formattedContent, &metadata, &messageTimestamp, &show, &receivedAt, &createdAt); err != nil {
			return count, fmt.Errorf("failed to scan source stream content: %w", err)
		}

		// Map to the new main_message_id.
		newMainMessageID, ok := idMap[mainMessageID]
		if !ok {
			log.Printf("Warning: main_message_id %d not found in mapping, skipping content id %d", mainMessageID, id)
			continue
		}

		if _, err := insertStmt.Exec(newMainMessageID, agentGenMessageID, msgType,
			content, formattedContent, metadata, formatTimestamp(messageTimestamp), show,
			formatTimestamp(receivedAt), formatTimestamp(createdAt)); err != nil {
			return count, fmt.Errorf("failed to insert stream content (source id %d): %w", id, err)
		}
		count++
	}
	if err := rows.Err(); err != nil {
		return count, fmt.Errorf("failed to read source stream contents: %w", err)
	}
	return count, nil
}
// migrateContextMessages copies the m_context_messages rows of config.ProjectID
// from the source transaction into the target transaction.
//
// idMap translates each row's source main_message_id to the id of the migrated
// stream message; rows whose main_message_id is missing from idMap are skipped
// with a warning. When config.NewUID is non-nil it replaces the uid of every
// row. message_timestamp, received_at and created_at are passed through
// formatTimestamp before insert.
//
// It returns the number of rows inserted; on error the count covers the rows
// inserted before the failure.
func migrateContextMessages(txSource, txTarget *sql.Tx, config *MigrationConfig, idMap map[uint64]uint64) (int, error) {
	selectQuery := `SELECT id, main_message_id, ownership, uid, project_id, turn_id,
		agent_message_id, sender_id, sender_agent_type, type, source, content,
		formatted_content, metadata, message_timestamp, agent_gen_message_id,
		` + "`show`" + `, received_at, created_at
		FROM m_context_messages WHERE project_id = ?`
	insertQuery := `INSERT INTO m_context_messages
		(main_message_id, ownership, uid, project_id, turn_id, agent_message_id,
		sender_id, sender_agent_type, type, source, content, formatted_content,
		metadata, message_timestamp, agent_gen_message_id, ` + "`show`" + `, received_at, created_at)
		VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`

	rows, err := txSource.Query(selectQuery, config.ProjectID)
	if err != nil {
		return 0, fmt.Errorf("failed to query source context messages: %w", err)
	}
	defer rows.Close()

	// Prepare the INSERT once instead of re-sending the statement for each row.
	insertStmt, err := txTarget.Prepare(insertQuery)
	if err != nil {
		return 0, fmt.Errorf("failed to prepare context message insert: %w", err)
	}
	defer insertStmt.Close()

	count := 0
	for rows.Next() {
		var (
			id                uint64
			mainMessageID     uint64
			ownership         string
			uid               uint64
			projectID         uint64
			turnID            sql.NullInt32
			agentMessageID    sql.NullString
			senderID          sql.NullString
			senderAgentType   sql.NullString
			msgType           string
			source            string
			content           sql.NullString
			formattedContent  sql.NullString
			metadata          sql.NullString
			messageTimestamp  string
			agentGenMessageID sql.NullString
			show              int
			receivedAt        string
			createdAt         string
		)
		if err := rows.Scan(&id, &mainMessageID, &ownership, &uid, &projectID, &turnID,
			&agentMessageID, &senderID, &senderAgentType, &msgType, &source, &content,
			&formattedContent, &metadata, &messageTimestamp, &agentGenMessageID,
			&show, &receivedAt, &createdAt); err != nil {
			return count, fmt.Errorf("failed to scan source context message: %w", err)
		}

		// Map to the new main_message_id.
		newMainMessageID, ok := idMap[mainMessageID]
		if !ok {
			log.Printf("Warning: main_message_id %d not found in mapping, skipping context message id %d", mainMessageID, id)
			continue
		}

		// Replace the UID when a new one was requested.
		targetUID := uid
		if config.NewUID != nil {
			targetUID = *config.NewUID
		}

		if _, err := insertStmt.Exec(newMainMessageID, ownership, targetUID, projectID,
			turnID, agentMessageID, senderID, senderAgentType, msgType, source, content,
			formattedContent, metadata, formatTimestamp(messageTimestamp), agentGenMessageID, show,
			formatTimestamp(receivedAt), formatTimestamp(createdAt)); err != nil {
			return count, fmt.Errorf("failed to insert context message (source id %d): %w", id, err)
		}
		count++
	}
	if err := rows.Err(); err != nil {
		return count, fmt.Errorf("failed to read source context messages: %w", err)
	}
	return count, nil
}
// migrateTaskMessages copies the m_task_messages rows of config.ProjectID from
// the source transaction into the target transaction.
//
// idMap translates each row's source main_message_id to the id of the migrated
// stream message; rows whose main_message_id is missing from idMap are skipped
// with a warning. When config.NewUID is non-nil it replaces the uid of every
// row. message_timestamp, received_at and created_at are passed through
// formatTimestamp before insert.
//
// It returns the number of rows inserted; on error the count covers the rows
// inserted before the failure.
func migrateTaskMessages(txSource, txTarget *sql.Tx, config *MigrationConfig, idMap map[uint64]uint64) (int, error) {
	selectQuery := `SELECT id, main_message_id, uid, project_id, sender_id, sender_agent_type,
		type, message_timestamp, agent_gen_message_id, content, formatted_content,
		task_status, turn_id, agent_message_id, source, metadata, ` + "`show`" + `,
		received_at, created_at
		FROM m_task_messages WHERE project_id = ?`
	insertQuery := `INSERT INTO m_task_messages
		(main_message_id, uid, project_id, sender_id, sender_agent_type, type,
		message_timestamp, agent_gen_message_id, content, formatted_content,
		task_status, turn_id, agent_message_id, source, metadata, ` + "`show`" + `,
		received_at, created_at)
		VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`

	rows, err := txSource.Query(selectQuery, config.ProjectID)
	if err != nil {
		return 0, fmt.Errorf("failed to query source task messages: %w", err)
	}
	defer rows.Close()

	// Prepare the INSERT once instead of re-sending the statement for each row.
	insertStmt, err := txTarget.Prepare(insertQuery)
	if err != nil {
		return 0, fmt.Errorf("failed to prepare task message insert: %w", err)
	}
	defer insertStmt.Close()

	count := 0
	for rows.Next() {
		var (
			id                uint64
			mainMessageID     uint64
			uid               uint64
			projectID         uint64
			senderID          sql.NullString
			senderAgentType   sql.NullString
			msgType           string
			messageTimestamp  string
			agentGenMessageID sql.NullString
			content           string
			formattedContent  string
			taskStatus        sql.NullString
			turnID            sql.NullInt32
			agentMessageID    sql.NullString
			source            sql.NullString
			metadata          sql.NullString
			show              int
			receivedAt        string
			createdAt         string
		)
		if err := rows.Scan(&id, &mainMessageID, &uid, &projectID, &senderID, &senderAgentType,
			&msgType, &messageTimestamp, &agentGenMessageID, &content, &formattedContent,
			&taskStatus, &turnID, &agentMessageID, &source, &metadata, &show,
			&receivedAt, &createdAt); err != nil {
			return count, fmt.Errorf("failed to scan source task message: %w", err)
		}

		// Map to the new main_message_id.
		newMainMessageID, ok := idMap[mainMessageID]
		if !ok {
			log.Printf("Warning: main_message_id %d not found in mapping, skipping task message id %d", mainMessageID, id)
			continue
		}

		// Replace the UID when a new one was requested.
		targetUID := uid
		if config.NewUID != nil {
			targetUID = *config.NewUID
		}

		if _, err := insertStmt.Exec(newMainMessageID, targetUID, projectID, senderID,
			senderAgentType, msgType, formatTimestamp(messageTimestamp), agentGenMessageID, content,
			formattedContent, taskStatus, turnID, agentMessageID, source, metadata,
			show, formatTimestamp(receivedAt), formatTimestamp(createdAt)); err != nil {
			return count, fmt.Errorf("failed to insert task message (source id %d): %w", id, err)
		}
		count++
	}
	if err := rows.Err(); err != nil {
		return count, fmt.Errorf("failed to read source task messages: %w", err)
	}
	return count, nil
}