package services

import (
	"bytes"
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"strconv"
	"strings"
	"sync"
	"time"

	"goalfymax-admin/internal/config"
	"goalfymax-admin/internal/notifier"
	"goalfymax-admin/internal/storage"
	pkgredis "goalfymax-admin/pkg/redis"

	mysqldriver "github.com/go-sql-driver/mysql"
	"github.com/google/uuid"
	"github.com/jackc/pgconn"
	"github.com/shopspring/decimal"
	"gorm.io/gorm"
)

// PageResult is one page of a paginated listing: the rows of the page plus the
// total number of rows matching the filters.
type PageResult struct {
	List  []map[string]interface{} `json:"list"`
	Total int64                    `json:"total"`
}

var (
	// usdScalingFactor is 1e8. NOTE(review): presumably the number of stored
	// units per USD — confirm against its users elsewhere in this file.
	usdScalingFactor = decimal.NewFromInt(100000000)

	// lowBalanceThreshold is guarded by lowBalanceThresholdMutex.
	lowBalanceThreshold      = decimal.NewFromInt(30) // default: 30
	lowBalanceThresholdMutex sync.RWMutex

	ErrMcpUsageJobAlreadyProcessed   = errors.New("mcp usage job already processed")
	ErrMcpUsageJobAlreadyInProgress  = errors.New("mcp usage job already in progress")
	ErrModelTokenJobAlreadyProcessed = errors.New("model token job already processed")
	ErrModelTokenJobAlreadyInProgress = errors.New("model token job already in progress")
)

// GetLowBalanceThreshold returns the current low-balance threshold.
// It is safe for concurrent use.
func GetLowBalanceThreshold() decimal.Decimal {
	lowBalanceThresholdMutex.RLock()
	defer lowBalanceThresholdMutex.RUnlock()
	return lowBalanceThreshold
}

// RefreshLowBalanceThreshold reloads the low-balance threshold from the
// "low_balance_threshold" system config entry. It is intended to be called
// before each scheduled job run.
//
// If the entry is missing, empty or unparsable, the threshold is reset to 30.
// Integer and decimal values (e.g. "30", "12.5") are accepted.
func RefreshLowBalanceThreshold() {
	const configKey = "low_balance_threshold"
	const defaultValue = 30

	threshold := decimal.NewFromInt(defaultValue)

	systemConfigService := NewSystemConfigService(
		storage.NewSystemConfigStorage(),
		nil, // no logger
	)
	// Named cfg (not config) so the imported config package is not shadowed.
	if cfg, err := systemConfigService.GetByKey(configKey); err == nil {
		if valueStr := strings.TrimSpace(cfg.Value); valueStr != "" {
			if parsed, parseErr := decimal.NewFromString(valueStr); parseErr == nil {
				threshold = parsed
			}
		}
	}

	lowBalanceThresholdMutex.Lock()
	lowBalanceThreshold = threshold
	lowBalanceThresholdMutex.Unlock()
}

const (
	jobStatusPending = "pending"
	jobStatusSuccess = "success"
	jobStatusFailed  = "failed"

	mcpUsageJobName   = "mcp_usage_hourly_settlement"
	modelTokenJobName = "model_token_usage_hourly_settlement"
)

// toString renders a scanned database value as a string.
//
// []byte and string values are returned as-is, and nil (SQL NULL) becomes "".
// Any other value is formatted with %v.
func toString(v interface{}) string {
	switch t := v.(type) {
	case nil:
		// Without this case, NULL would be rendered as Go's "<nil>".
		return ""
	case []byte:
		return string(t)
	case string:
		return t
	default:
		return fmt.Sprintf("%v", v)
	}
}

// toInt parses v (via toString) as a base-10 int64. ok is false for empty or
// unparsable input.
func toInt(v interface{}) (int64, bool) {
	s := toString(v)
	if s == "" {
		return 0, false
	}
	n, err := strconv.ParseInt(s, 10, 64)
	if err != nil {
		return 0, false
	}
	return n, true
}

// toFloat parses v (via toString) as a float64. ok is false for empty or
// unparsable input.
func toFloat(v interface{}) (float64, bool) {
	s := toString(v)
	if s == "" {
		return 0, false
	}
	f, err := strconv.ParseFloat(s, 64)
	if err != nil {
		return 0, false
	}
	return f, true
}

// toTimeISO normalises a scanned timestamp to an RFC3339 string.
//
// time.Time values are formatted directly. Other values are rendered with
// toString and parsed in the local time zone with a few common layouts.
// Empty or NULL input yields ("", false). Text that matches no layout is
// returned unchanged, with ok=true.
func toTimeISO(s interface{}) (string, bool) {
	if tm, ok := s.(time.Time); ok {
		return tm.Format(time.RFC3339), true
	}
	str := toString(s)
	if str == "" {
		return "", false
	}
	layouts := []string{
		time.RFC3339,
		"2006-01-02 15:04:05",
		"2006-01-02T15:04:05",
	}
	for _, l := range layouts {
		if tm, err := time.ParseInLocation(l, str, time.Local); err == nil {
			return tm.Format(time.RFC3339), true
		}
	}
	return str, true
}

// normalizeSandboxRow converts one raw sb_sandbox_record row into a
// JSON-friendly map.
//
//   - duration/price/cost columns become float64.
//   - total_cost_balance becomes int64. Numeric columns that cannot be parsed
//     become nil.
//   - Timestamp columns go through toTimeISO.
//   - info is decoded as JSON when possible, otherwise kept as a string.
//   - All other columns are rendered with toString.
func normalizeSandboxRow(cols []string, vals []interface{}) map[string]interface{} {
	m := map[string]interface{}{}
	for i, c := range cols {
		switch c {
		case "duration_minutes", "unit_price_usd", "total_cost_usd":
			if f, ok := toFloat(vals[i]); ok {
				m[c] = f
			} else {
				m[c] = nil
			}
		case "total_cost_balance":
			if n, ok := toInt(vals[i]); ok {
				m[c] = n
			} else {
				m[c] = nil
			}
		case "created_at", "updated_at", "started_at", "released_at", "last_billed_at":
			if iso, ok := toTimeISO(vals[i]); ok {
				m[c] = iso
			} else {
				m[c] = nil
			}
		case "info":
			s := toString(vals[i])
			var obj interface{}
			if err := json.Unmarshal([]byte(s), &obj); err == nil {
				m[c] = obj
			} else {
				m[c] = s
			}
		default:
			m[c] = toString(vals[i])
		}
	}
	return m
}

// normalizeGenericInts converts one raw row into a JSON-friendly map.
//
// Columns listed in intCols become int64 (nil if unparsable), and columns in
// timeCols go through toTimeISO. Every other column is rendered with toString.
func normalizeGenericInts(cols []string, vals []interface{}, intCols []string, timeCols []string) map[string]interface{} {
	m := map[string]interface{}{}
	isInt := map[string]struct{}{}
	for _, k := range intCols {
		isInt[k] = struct{}{}
	}
	isTime := map[string]struct{}{}
	for _, k := range timeCols {
		isTime[k] = struct{}{}
	}
	for i, c := range cols {
		if _, ok := isInt[c]; ok {
			if n, ok2 := toInt(vals[i]); ok2 {
				m[c] = n
			} else {
				m[c] = nil
			}
			continue
		}
		if _, ok := isTime[c]; ok {
			if iso, ok2 := toTimeISO(vals[i]); ok2 {
				m[c] = iso
			} else {
				m[c] = nil
			}
			continue
		}
		m[c] = toString(vals[i])
	}
	return m
}

// parseBool interprets a loosely-typed value as a boolean.
//
//   - bool and non-nil *bool values are returned as-is.
//   - Any signed or unsigned integer is true when non-zero.
//   - Strings and []byte (the raw type MySQL scans into) are true for "true",
//     "1" or "t" (case-insensitive, surrounding space ignored).
//   - Everything else is false.
func parseBool(val interface{}) bool {
	switch v := val.(type) {
	case bool:
		return v
	case *bool:
		return v != nil && *v
	case int, int8, int16, int32, int64:
		n, ok := toInt(val)
		return ok && n != 0
	case uint, uint8, uint16, uint32, uint64:
		// Compare textually so values above MaxInt64 are still handled.
		return toString(val) != "0"
	case string, []byte:
		switch strings.ToLower(strings.TrimSpace(toString(v))) {
		case "true", "1", "t":
			return true
		}
	}
	return false
}

// mcpUsageAggregation is one (provider, account) group of summed MCP
// invocation cost.
type mcpUsageAggregation struct {
	Provider string         `gorm:"column:provider"`
	Account  sql.NullString `gorm:"column:account"`
	CostSum  int64          `gorm:"column:cost_sum"`
}

// modelConfigInfo is the subset of a gw_model_config_v2 row kept in memory.
type modelConfigInfo struct {
	ID         uint64
	Provider   string
	ModelName  string
	PriceRatio float64
}

// applyMcpBalanceChange appends a new balance snapshot for providerID to
// mcp_account_balances. The new balance is the latest existing snapshot (by
// created_at; zero if none exists) plus delta, recorded in USD with remark.
// It returns the previous and the new balance.
//
// NOTE(review): the read-then-insert is not locked. Two concurrent callers for
// the same provider could read the same previous balance — confirm that
// callers serialise access.
func applyMcpBalanceChange(tx *gorm.DB, providerID uuid.UUID, delta decimal.Decimal, remark string) (decimal.Decimal, decimal.Decimal, error) {
	var latest struct {
		Balance decimal.Decimal `gorm:"column:balance"`
	}
	prev := decimal.Zero
	err := tx.Table("mcp_account_balances").
		Select("balance").
		Where("provider_id = ?", providerID).
		Order("created_at DESC").
		Take(&latest).Error
	switch {
	case err == nil:
		prev = latest.Balance
	case !errors.Is(err, gorm.ErrRecordNotFound):
		return decimal.Zero, decimal.Zero, err
	}

	newBalance := prev.Add(delta)
	record := map[string]interface{}{
		"provider_id": providerID,
		"balance":     newBalance,
		"currency":    "USD",
		"remark":      remark,
	}
	if err := tx.Table("mcp_account_balances").Create(record).Error; err != nil {
		return decimal.Zero, decimal.Zero, err
	}
	return prev, newBalance, nil
}

// findMcpProviderID returns the id of the most recently created mcp_providers
// row for provider. An empty account matches rows whose account is NULL or "".
func findMcpProviderID(pg *gorm.DB, provider, account string) (uuid.UUID, error) {
	var result struct {
		ID uuid.UUID `gorm:"column:id"`
	}
	query := pg.Table("mcp_providers").Select("id").Where("provider = ?", provider)
	if account != "" {
		query = query.Where("account = ?", account)
	} else {
		query = query.Where("(account IS NULL OR account = '')")
	}
	if err := query.Order("created_at DESC").Take(&result).Error; err != nil {
		return uuid.Nil, err
	}
	return result.ID, nil
}

// mcpProviderInfo is an mcp_providers row loaded for in-memory lookup.
type mcpProviderInfo struct {
	ID            uuid.UUID
	Provider      string
	Account       string // "" when the account column is NULL or empty
	FloatingRatio float64
}

// mcpProviderCache is the in-memory lookup value for one account.
type mcpProviderCache struct {
	ID            uuid.UUID
	FloatingRatio float64 // 1 + max(floating_ratio, 0)
}

// loadMcpProvidersToMemory loads every mcp_providers row and indexes it by
// account. NULL accounts map to the "" key.
//
// When several rows share an account, only the most recently created one is
// kept; rows are read ordered by created_at DESC and the first one wins. Each
// stored FloatingRatio is 1 + max(floating_ratio, 0).
//
// NOTE(review): a NULL floating_ratio is COALESCEd to 1, which gives a
// multiplier of 2. Confirm that this default is intended (0 would give 1).
func loadMcpProvidersToMemory(pg *gorm.DB) (map[string]mcpProviderCache, error) {
	var providers []mcpProviderInfo
	err := pg.Table("mcp_providers").
		Select("id, provider, COALESCE(account, '') AS account, COALESCE(floating_ratio, 1) AS floating_ratio").
		Order("created_at DESC").
		Find(&providers).Error
	if err != nil {
		return nil, err
	}

	providerMap := make(map[string]mcpProviderCache)
	for _, p := range providers {
		key := p.Account
		if _, exists := providerMap[key]; exists {
			continue // a newer row for this account was already stored
		}
		ratio := p.FloatingRatio
		if ratio < 0 {
			ratio = 0
		}
		providerMap[key] = mcpProviderCache{
			ID:            p.ID,
			FloatingRatio: 1 + ratio,
		}
	}
	return providerMap, nil
}

// findMcpProviderFromMemory looks up an entry built by loadMcpProvidersToMemory.
// Only account is used as the key; provider is accepted but ignored.
func findMcpProviderFromMemory(providerMap map[string]mcpProviderCache, provider, account string) (mcpProviderCache, error) {
	if entry, found := providerMap[account]; found {
		return entry, nil
	}
	return mcpProviderCache{}, fmt.Errorf("未找到匹配的账号: account=%s", account)
}

// getLastProcessedWindowEnd returns the window_end of the latest successful
// run of jobName in mcp_usage_balance_job_runs. It returns the zero time if no
// successful run exists.
func getLastProcessedWindowEnd(pg *gorm.DB, jobName string) (time.Time, error) {
	var result struct {
		WindowEnd time.Time `gorm:"column:window_end"`
	}
	err := pg.Table("mcp_usage_balance_job_runs").
		Select("window_end").
		Where("job_name = ? AND status = ?", jobName, jobStatusSuccess).
		Order("id DESC").
		Limit(1).
		Take(&result).Error
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return time.Time{}, nil // zero value: no prior run
		}
		return time.Time{}, err
	}
	return result.WindowEnd, nil
}

// createMcpUsageJobRun inserts a pending run row for [windowStart, windowEnd)
// and returns its id.
//
// On a PostgreSQL unique violation (23505), it looks up the existing run for
// (jobName, windowStart). It returns ErrMcpUsageJobAlreadyInProgress if that
// run is pending, ErrMcpUsageJobAlreadyProcessed if it succeeded, and a
// generic error for any other status. If the lookup itself fails, the
// original insert error is returned.
func createMcpUsageJobRun(pg *gorm.DB, jobName string, windowStart, windowEnd time.Time) (int64, error) {
	type JobRun struct {
		ID               int64           `gorm:"column:id;primaryKey;autoIncrement"`
		JobName          string          `gorm:"column:job_name"`
		WindowStart      time.Time       `gorm:"column:window_start"`
		WindowEnd        time.Time       `gorm:"column:window_end"`
		Status           string          `gorm:"column:status"`
		RecordsProcessed int             `gorm:"column:records_processed"`
		TotalCost        decimal.Decimal `gorm:"column:total_cost"`
		ErrorMessage     string          `gorm:"column:error_message"`
		CreatedAt        time.Time       `gorm:"column:created_at"`
		UpdatedAt        time.Time       `gorm:"column:updated_at"`
	}

	jobRun := JobRun{
		JobName:          jobName,
		WindowStart:      windowStart,
		WindowEnd:        windowEnd,
		Status:           jobStatusPending,
		RecordsProcessed: 0,
		TotalCost:        decimal.Zero,
		ErrorMessage:     "",
		CreatedAt:        time.Now(),
		UpdatedAt:        time.Now(),
	}

	if err := pg.Table("mcp_usage_balance_job_runs").Create(&jobRun).Error; err != nil {
		var pgErr *pgconn.PgError
		if errors.As(err, &pgErr) && pgErr.Code == "23505" {
			var existing struct {
				ID     int64  `gorm:"column:id"`
				Status string `gorm:"column:status"`
			}
			if lookupErr := pg.Table("mcp_usage_balance_job_runs").
				Select("id, status").
				Where("job_name = ? AND window_start = ?", jobName, windowStart).
				Take(&existing).Error; lookupErr == nil {
				switch existing.Status {
				case jobStatusPending:
					return existing.ID, ErrMcpUsageJobAlreadyInProgress
				case jobStatusSuccess:
					return existing.ID, ErrMcpUsageJobAlreadyProcessed
				default:
					return existing.ID, fmt.Errorf("job already exists with status %s", existing.Status)
				}
			}
		}
		return 0, err
	}

	return jobRun.ID, nil
}

// updateMcpUsageJobRun records the outcome of MCP usage run runID.
func updateMcpUsageJobRun(pg *gorm.DB, runID int64, status string, records int, total decimal.Decimal, errorMessage string) error {
	updateData := map[string]interface{}{
		"status":            status,
		"records_processed": records,
		"total_cost":        total,
		"error_message":     errorMessage,
		"updated_at":        time.Now(),
	}
	return pg.Table("mcp_usage_balance_job_runs").Where("id = ?", runID).Updates(updateData).Error
}

// getLastProcessedModelTokenWindowEnd returns the window_end of the latest
// successful run of jobName in model_token_balance_job_runs. It returns the
// zero time if no successful run exists.
func getLastProcessedModelTokenWindowEnd(db *gorm.DB, jobName string) (time.Time, error) {
	var result struct {
		WindowEnd time.Time `gorm:"column:window_end"`
	}
	err := db.Table("model_token_balance_job_runs").
		Select("window_end").
		Where("job_name = ? AND status = ?", jobName, jobStatusSuccess).
		Order("id DESC").
		Limit(1).
		Take(&result).Error
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return time.Time{}, nil
		}
		return time.Time{}, err
	}
	return result.WindowEnd, nil
}

// createModelTokenJobRun inserts a pending run row for [windowStart, windowEnd)
// and returns its id.
//
// On a MySQL duplicate-key error (1062), it looks up the newest run for
// (jobName, windowStart). It returns ErrModelTokenJobAlreadyInProgress if that
// run is pending, ErrModelTokenJobAlreadyProcessed if it succeeded, and a
// generic error for any other status. If the lookup itself fails, it also
// returns ErrModelTokenJobAlreadyProcessed.
func createModelTokenJobRun(db *gorm.DB, jobName string, windowStart, windowEnd time.Time) (int64, error) {
	type jobRun struct {
		ID               int64           `gorm:"column:id;primaryKey;autoIncrement"`
		JobName          string          `gorm:"column:job_name"`
		WindowStart      time.Time       `gorm:"column:window_start"`
		WindowEnd        time.Time       `gorm:"column:window_end"`
		Status           string          `gorm:"column:status"`
		RecordsProcessed int             `gorm:"column:records_processed"`
		TotalCost        decimal.Decimal `gorm:"column:total_cost"`
		ErrorMessage     string          `gorm:"column:error_message"`
		CreatedAt        time.Time       `gorm:"column:created_at"`
		UpdatedAt        time.Time       `gorm:"column:updated_at"`
	}

	run := jobRun{
		JobName:          jobName,
		WindowStart:      windowStart,
		WindowEnd:        windowEnd,
		Status:           jobStatusPending,
		RecordsProcessed: 0,
		TotalCost:        decimal.Zero,
		ErrorMessage:     "",
		CreatedAt:        time.Now(),
		UpdatedAt:        time.Now(),
	}

	if err := db.Table("model_token_balance_job_runs").Create(&run).Error; err != nil {
		var mysqlErr *mysqldriver.MySQLError
		if errors.As(err, &mysqlErr) && mysqlErr.Number == 1062 {
			var existing struct {
				ID     int64  `gorm:"column:id"`
				Status string `gorm:"column:status"`
			}
			if lookupErr := db.Table("model_token_balance_job_runs").
				Select("id, status").
				Where("job_name = ? AND window_start = ?", jobName, windowStart).
				Order("id DESC").
				Take(&existing).Error; lookupErr == nil {
				switch existing.Status {
				case jobStatusPending:
					return existing.ID, ErrModelTokenJobAlreadyInProgress
				case jobStatusSuccess:
					return existing.ID, ErrModelTokenJobAlreadyProcessed
				default:
					return existing.ID, fmt.Errorf("job already exists with status %s", existing.Status)
				}
			}
			return 0, ErrModelTokenJobAlreadyProcessed
		}
		return 0, err
	}

	return run.ID, nil
}

// updateModelTokenJobRun records the outcome of model token run runID.
func updateModelTokenJobRun(db *gorm.DB, runID int64, status string, records int, total decimal.Decimal, errorMessage string) error {
	updateData := map[string]interface{}{
		"status":            status,
		"records_processed": records,
		"total_cost":        total,
		"error_message":     errorMessage,
		"updated_at":        time.Now(),
	}
	return db.Table("model_token_balance_job_runs").Where("id = ?", runID).Updates(updateData).Error
}

// loadModelConfigsToMemory loads the non-deleted gw_model_config_v2 rows,
// keyed by makeModelConfigKey(provider, model_name). The first row read for a
// key wins.
func loadModelConfigsToMemory(db *gorm.DB) (map[string]modelConfigInfo, error) {
	var configs []struct {
		ID         uint64  `gorm:"column:id"`
		Provider   string  `gorm:"column:provider"`
		ModelName  string  `gorm:"column:model_name"`
		PriceRatio float64 `gorm:"column:price_ratio"`
	}
	err := db.Table("gw_model_config_v2").
		Select("id, provider, model_name, price_ratio").
		Where("deleted_at IS NULL").
		Find(&configs).Error
	if err != nil {
		return nil, err
	}

	result := make(map[string]modelConfigInfo, len(configs))
	for _, cfg := range configs {
		key := makeModelConfigKey(cfg.Provider, cfg.ModelName)
		if _, exists := result[key]; !exists {
			result[key] = modelConfigInfo{
				ID:         cfg.ID,
				Provider:   strings.TrimSpace(cfg.Provider),
				ModelName:  strings.TrimSpace(cfg.ModelName),
				PriceRatio: cfg.PriceRatio,
			}
		}
	}
	return result, nil
}

// makeModelConfigKey builds the case-insensitive, whitespace-trimmed lookup
// key "provider|model".
func makeModelConfigKey(provider, model string) string {
	return strings.ToLower(strings.TrimSpace(provider)) + "|" + strings.ToLower(strings.TrimSpace(model))
}

// AdjustMcpAccountBalance sets the balance of MCP provider providerID to
// newBalance (rounded to 8 decimal places). It does this by appending a
// balance snapshot with the required delta. It is a no-op when the current
// balance is already within 1e-8 of the target.
//
// An empty remark is replaced by a generated one; otherwise operatorName is
// appended to remark. NOTE(review): the current balance is read before the
// transaction begins.
func AdjustMcpAccountBalance(providerID string, newBalance float64, operatorName, remark string) error {
	db := storage.GetPG()
	if db == nil {
		return fmt.Errorf("PostgreSQL未初始化")
	}
	if newBalance < 0 {
		return fmt.Errorf("余额不能为负数")
	}
	uuidVal, err := uuid.Parse(providerID)
	if err != nil {
		return fmt.Errorf("无效的provider_id: %w", err)
	}

	target := decimal.NewFromFloat(newBalance).Round(8)
	current := decimal.Zero
	var latest struct {
		Balance decimal.Decimal `gorm:"column:balance"`
	}
	err = db.Table("mcp_account_balances").
		Select("balance").
		Where("provider_id = ?", uuidVal).
		Order("created_at DESC").
		Take(&latest).Error
	if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
		return fmt.Errorf("获取当前余额失败: %w", err)
	}
	if err == nil {
		current = latest.Balance
	}

	delta := target.Sub(current)
	if delta.Abs().LessThan(decimal.NewFromFloat(1e-8)) {
		// Already at the target: nothing to record, still a success.
		return nil
	}

	tx := db.Begin()
	if tx.Error != nil {
		return fmt.Errorf("开启事务失败: %w", tx.Error)
	}

	fullRemark := remark
	if strings.TrimSpace(fullRemark) == "" {
		fullRemark = fmt.Sprintf("手动调整余额至 $%s by %s", target.StringFixed(8), operatorName)
	} else {
		fullRemark = fmt.Sprintf("%s (by %s)", remark, operatorName)
	}

	if _, _, err = applyMcpBalanceChange(tx, uuidVal, delta, fullRemark); err != nil {
		tx.Rollback()
		return err
	}
	if err := tx.Commit().Error; err != nil {
		return fmt.Errorf("提交事务失败: %w", err)
	}
	return nil
}

// CreateMcpAccountBalanceRecord creates the first balance snapshot for MCP
// provider providerID, set to balance (rounded to 8 decimal places).
//
// It fails if any snapshot already exists; AdjustMcpAccountBalance should be
// used in that case. Remark handling is the same as in
// AdjustMcpAccountBalance.
func CreateMcpAccountBalanceRecord(providerID string, balance float64, operatorName, remark string) error {
	db := storage.GetPG()
	if db == nil {
		return fmt.Errorf("PostgreSQL未初始化")
	}
	if balance < 0 {
		return fmt.Errorf("余额不能为负数")
	}
	uuidVal, err := uuid.Parse(providerID)
	if err != nil {
		return fmt.Errorf("无效的provider_id: %w", err)
	}

	var exists int64
	if err := db.Table("mcp_account_balances").Where("provider_id = ?", uuidVal).Limit(1).Count(&exists).Error; err != nil {
		return fmt.Errorf("查询余额失败: %w", err)
	}
	if exists > 0 {
		return fmt.Errorf("余额记录已存在,请使用调整余额功能")
	}

	target := decimal.NewFromFloat(balance).Round(8)
	tx := db.Begin()
	if tx.Error != nil {
		return fmt.Errorf("开启事务失败: %w", tx.Error)
	}

	fullRemark := remark
	if strings.TrimSpace(fullRemark) == "" {
		fullRemark = fmt.Sprintf("新增余额 $%s by %s", target.StringFixed(8), operatorName)
	} else {
		fullRemark = fmt.Sprintf("%s (by %s)", remark, operatorName)
	}

	if _, _, err = applyMcpBalanceChange(tx, uuidVal, target, fullRemark); err != nil {
		tx.Rollback()
		return err
	}
	if err := tx.Commit().Error; err != nil {
		return fmt.Errorf("提交事务失败: %w", err)
	}
	return nil
}

// AdjustModelAccountBalance sets the balance of model account account to
// newBalance (rounded to 8 decimal places) in gw_model_account_balances. It
// mirrors AdjustMcpAccountBalance: it is a no-op within 1e-8 of the target and
// handles remarks the same way.
func AdjustModelAccountBalance(account string, newBalance float64, operatorName, remark string) error {
	db := storage.DB
	if db == nil {
		return fmt.Errorf("MySQL未初始化")
	}
	if newBalance < 0 {
		return fmt.Errorf("余额不能为负数")
	}

	target := decimal.NewFromFloat(newBalance).Round(8)
	current := decimal.Zero
	var latest struct {
		Balance decimal.Decimal `gorm:"column:balance"`
	}
	err := db.Table("gw_model_account_balances").
		Select("balance").
		Where("account = ?", account).
		Order("created_at DESC").
		Take(&latest).Error
	if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
		return fmt.Errorf("获取当前余额失败: %w", err)
	}
	if err == nil {
		current = latest.Balance
	}

	delta := target.Sub(current)
	if delta.Abs().LessThan(decimal.NewFromFloat(1e-8)) {
		return nil
	}

	tx := db.Begin()
	if tx.Error != nil {
		return fmt.Errorf("开启事务失败: %w", tx.Error)
	}

	fullRemark := remark
	if strings.TrimSpace(fullRemark) == "" {
		fullRemark = fmt.Sprintf("手动调整余额至 $%s by %s", target.StringFixed(8), operatorName)
	} else {
		fullRemark = fmt.Sprintf("%s (by %s)", remark, operatorName)
	}

	if _, _, err = applyModelAccountBalanceChange(tx, account, delta, fullRemark); err != nil {
		tx.Rollback()
		return err
	}
	if err := tx.Commit().Error; err != nil {
		return fmt.Errorf("提交事务失败: %w", err)
	}
	return nil
}

// CreateModelAccountBalanceRecord creates the first balance snapshot for model
// account account, set to balance (rounded to 8 decimal places). It fails if a
// snapshot already exists.
func CreateModelAccountBalanceRecord(account string, balance float64, operatorName, remark string) error {
	db := storage.DB
	if db == nil {
		return fmt.Errorf("MySQL未初始化")
	}
	if balance < 0 {
		return fmt.Errorf("余额不能为负数")
	}

	var exists int64
	if err := db.Table("gw_model_account_balances").Where("account = ?", account).Limit(1).Count(&exists).Error; err != nil {
		return fmt.Errorf("查询余额失败: %w", err)
	}
	if exists > 0 {
		return fmt.Errorf("余额记录已存在,请使用调整余额功能")
	}

	target := decimal.NewFromFloat(balance).Round(8)
	tx := db.Begin()
	if tx.Error != nil {
		return fmt.Errorf("开启事务失败: %w", tx.Error)
	}

	fullRemark := remark
	if strings.TrimSpace(fullRemark) == "" {
		fullRemark = fmt.Sprintf("新增余额 $%s by %s", target.StringFixed(8), operatorName)
	} else {
		fullRemark = fmt.Sprintf("%s (by %s)", remark, operatorName)
	}

	if _, _, err := applyModelAccountBalanceChange(tx, account, target, fullRemark); err != nil {
		tx.Rollback()
		return err
	}
	if err := tx.Commit().Error; err != nil {
		return fmt.Errorf("提交事务失败: %w", err)
	}
	return nil
}

// queryNormalizedPage counts the rows matched by q, then fetches one page
// (offset/limit, ordered by orderBy if non-empty) and converts each row with
// normalize. label appears in error messages as "count <label>: ..." and
// "query <label>: ...".
func queryNormalizedPage(q *gorm.DB, label, orderBy string, offset, limit int, normalize func(cols []string, vals []interface{}) map[string]interface{}) (*PageResult, error) {
	var total int64
	if err := q.Count(&total).Error; err != nil {
		return nil, fmt.Errorf("count %s: %w", label, err)
	}
	if orderBy != "" {
		q = q.Order(orderBy)
	}
	rows, err := q.Offset(offset).Limit(limit).Rows()
	if err != nil {
		return nil, fmt.Errorf("query %s: %w", label, err)
	}
	defer rows.Close()

	list, err := scanNormalizedRows(rows, normalize)
	if err != nil {
		return nil, err
	}
	return &PageResult{List: list, Total: total}, nil
}

// scanNormalizedRows drains rows. It scans every column into a generic value
// and converts each row with normalize. It returns a non-nil slice, so an
// empty result encodes as JSON [] rather than null. It reports errors from
// Columns, Scan and iteration (rows.Err).
func scanNormalizedRows(rows *sql.Rows, normalize func(cols []string, vals []interface{}) map[string]interface{}) ([]map[string]interface{}, error) {
	cols, err := rows.Columns()
	if err != nil {
		return nil, fmt.Errorf("read columns: %w", err)
	}
	list := make([]map[string]interface{}, 0)
	for rows.Next() {
		vals := make([]interface{}, len(cols))
		valPtrs := make([]interface{}, len(cols))
		for i := range vals {
			valPtrs[i] = &vals[i]
		}
		if err := rows.Scan(valPtrs...); err != nil {
			return nil, err
		}
		list = append(list, normalize(cols, vals))
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterate rows: %w", err)
	}
	return list, nil
}

// ListSandboxRecords returns one page of sb_sandbox_record rows. Rows can be
// filtered by user, by project, and by a created_at range; the range is
// applied only when both start and end are given.
func ListSandboxRecords(offset, limit int, userID, projectID, start, end string) (*PageResult, error) {
	q := storage.DB.Table("sb_sandbox_record")
	if userID != "" {
		q = q.Where("user_id = ?", userID)
	}
	if projectID != "" {
		q = q.Where("project_id = ?", projectID)
	}
	if start != "" && end != "" {
		q = q.Where("created_at BETWEEN ? AND ?", start, end)
	}
	return queryNormalizedPage(q, "sandbox records", "", offset, limit, normalizeSandboxRow)
}

// ListTokenUsages returns one page of gw_token_usages rows. Rows can be
// filtered by user, by project, and by a day range; the range is applied only
// when both startDay and endDay are given.
func ListTokenUsages(offset, limit int, userID, projectID, startDay, endDay string) (*PageResult, error) {
	q := storage.DB.Table("gw_token_usages")
	if userID != "" {
		q = q.Where("user_id = ?", userID)
	}
	if projectID != "" {
		q = q.Where("project_id = ?", projectID)
	}
	if startDay != "" && endDay != "" {
		q = q.Where("day BETWEEN ? AND ?", startDay, endDay)
	}
	intCols := []string{"id", "hour", "cost", "prompt_token", "completion_token", "cache_create_token", "cache_read_token"}
	return queryNormalizedPage(q, "token usages", "", offset, limit, func(cols []string, vals []interface{}) map[string]interface{} {
		return normalizeGenericInts(cols, vals, intCols, nil)
	})
}

// ListMCPUsages returns one page of mcp_invoke_usages rows. Rows can be
// filtered by user, by project, and by a created_at range; the range is
// applied only when both start and end are given.
func ListMCPUsages(offset, limit int, userID, projectID, start, end string) (*PageResult, error) {
	q := storage.DB.Table("mcp_invoke_usages")
	if userID != "" {
		q = q.Where("user_id = ?", userID)
	}
	if projectID != "" {
		q = q.Where("project_id = ?", projectID)
	}
	if start != "" && end != "" {
		q = q.Where("created_at BETWEEN ? AND ?", start, end)
	}
	intCols := []string{"id", "hour", "cost", "call_count"}
	timeCols := []string{"created_at", "updated_at"}
	return queryNormalizedPage(q, "mcp usages", "", offset, limit, func(cols []string, vals []interface{}) map[string]interface{} {
		return normalizeGenericInts(cols, vals, intCols, timeCols)
	})
}

// ListTransactionLogs returns one page of m_transaction_logs rows. Each filter
// is applied only when it is non-empty; the created_at range is applied only
// when both start and end are given.
func ListTransactionLogs(offset, limit int, userID, orderID, txType, status, start, end string) (*PageResult, error) {
	q := storage.DB.Table("m_transaction_logs")
	if userID != "" {
		q = q.Where("user_id = ?", userID)
	}
	if orderID != "" {
		q = q.Where("order_id = ?", orderID)
	}
	if txType != "" {
		q = q.Where("type = ?", txType)
	}
	if status != "" {
		q = q.Where("status = ?", status)
	}
	if start != "" && end != "" {
		q = q.Where("created_at BETWEEN ? AND ?", start, end)
	}
	intCols := []string{"id", "amount", "balance_before", "balance_after"}
	timeCols := []string{"created_at"}
	return queryNormalizedPage(q, "transaction logs", "", offset, limit, func(cols []string, vals []interface{}) map[string]interface{} {
		return normalizeGenericInts(cols, vals, intCols, timeCols)
	})
}

// ListPaymentRecords returns one page of top-up payment records
// (m_payment_records). Each filter is applied only when it is non-empty.
// payerEmail is a substring match, and the created_at range is applied only
// when both start and end are given.
func ListPaymentRecords(offset, limit int, userID, orderID, paypalOrderID, status, refundStatus, payerEmail, start, end string) (*PageResult, error) {
	q := storage.DB.Table("m_payment_records")
	if userID != "" {
		q = q.Where("user_id = ?", userID)
	}
	if orderID != "" {
		q = q.Where("order_id = ?", orderID)
	}
	if paypalOrderID != "" {
		q = q.Where("paypal_order_id = ?", paypalOrderID)
	}
	if status != "" {
		q = q.Where("status = ?", status)
	}
	if refundStatus != "" {
		q = q.Where("refund_status = ?", refundStatus)
	}
	if payerEmail != "" {
		q = q.Where("payer_email LIKE ?", "%"+payerEmail+"%")
	}
	if start != "" && end != "" {
		q = q.Where("created_at BETWEEN ? AND ?", start, end)
	}
	intCols := []string{"id", "amount", "refunded_amount"}
	timeCols := []string{"created_at", "updated_at"}
	return queryNormalizedPage(q, "payment records", "", offset, limit, func(cols []string, vals []interface{}) map[string]interface{} {
		return normalizeGenericInts(cols, vals, intCols, timeCols)
	})
}

// RefundPaymentRecord asks the payment service (POST {Pay.BaseURL}/api/refund)
// to refund an order. The order is identified by orderID and/or
// paypalCaptureID. amount is in cents; nil means a full refund.
//
// After a 200 response, it decrements the user's Redis quota key
// "GW:QU_<user_id>" by refundCents*1e6. This step is best-effort: local lookup
// or Redis failures are ignored. It is skipped when no local payment record
// with a user id is found for orderID.
func RefundPaymentRecord(orderID, paypalCaptureID string, amount *int64) error {
	cfg := config.GetConfig()
	payBaseURL := cfg.Pay.BaseURL
	timeout := time.Duration(cfg.Pay.Timeout) * time.Second

	reqBody := make(map[string]interface{})
	if orderID != "" {
		reqBody["order_id"] = orderID
	}
	if paypalCaptureID != "" {
		reqBody["paypal_capture_id"] = paypalCaptureID
	}
	if amount != nil {
		reqBody["amount"] = *amount
	}

	jsonData, err := json.Marshal(reqBody)
	if err != nil {
		return fmt.Errorf("序列化请求失败: %w", err)
	}

	req, err := http.NewRequest("POST", payBaseURL+"/api/refund", bytes.NewBuffer(jsonData))
	if err != nil {
		return fmt.Errorf("创建请求失败: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	client := &http.Client{Timeout: timeout}
	resp, err := client.Do(req)
	if err != nil {
		return fmt.Errorf("发送退款请求失败: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return fmt.Errorf("读取响应失败: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("退款接口返回错误: %d, 响应: %s", resp.StatusCode, string(body))
	}

	// Fetch user_id and amount of the order in a single query.
	var rec struct {
		UserID string `json:"user_id"`
		Amount int64  `json:"amount"`
	}
	if err := storage.DB.Table("m_payment_records").Select("user_id, amount").Where("order_id = ?", orderID).Scan(&rec).Error; err != nil {
		return nil // best-effort: do not fail the already-completed refund
	}
	if rec.UserID == "" {
		// No matching local record (e.g. refund by capture id only). Without a
		// user id, we would decrement the bogus key "GW:QU_".
		return nil
	}

	// Amount to deduct, in cents.
	var refundCents int64
	if amount != nil {
		refundCents = *amount
	} else {
		refundCents = rec.Amount
	}
	if refundCents <= 0 {
		return nil
	}

	decrBy := refundCents * 1_000_000 // 1 cent = 1e6 Redis units

	rclient, err := pkgredis.NewClient(cfg.Redis)
	if err != nil {
		return nil // best-effort: ignore Redis failures
	}
	defer rclient.Close()

	ctx := context.Background()
	key := fmt.Sprintf("GW:QU_%s", rec.UserID)
	_, _ = rclient.Rdb.DecrBy(ctx, key, decrBy).Result() // best-effort, result ignored
	return nil
}

// ListMcpAccountRechargeRecords returns one page of MCP account recharge
// records, joined with their provider and account, newest first.
//
// provider and account are case-insensitive substring filters. start and end
// bound recharge_date independently. amount is returned as float64 and the
// created_at/updated_at columns as RFC3339 strings.
func ListMcpAccountRechargeRecords(offset, limit int, provider, account, start, end string) (*PageResult, error) {
	db := storage.GetPG()
	if db == nil {
		return nil, fmt.Errorf("PostgreSQL未初始化")
	}

	q := db.Table("mcp_account_recharge_records").
		Select(`
			mcp_account_recharge_records.*,
			mcp_providers.provider,
			mcp_providers.account
		`).
		Joins("LEFT JOIN mcp_providers ON mcp_account_recharge_records.provider_id = mcp_providers.id")
	if provider != "" {
		q = q.Where("mcp_providers.provider ILIKE ?", "%"+provider+"%")
	}
	if account != "" {
		q = q.Where("mcp_providers.account ILIKE ?", "%"+account+"%")
	}
	if start != "" {
		q = q.Where("mcp_account_recharge_records.recharge_date >= ?", start)
	}
	if end != "" {
		q = q.Where("mcp_account_recharge_records.recharge_date <= ?", end)
	}

	// Column-type sets are built once, not per row.
	isFloat := map[string]struct{}{"amount": {}}
	isTime := map[string]struct{}{"created_at": {}, "updated_at": {}}
	normalize := func(cols []string, vals []interface{}) map[string]interface{} {
		m := map[string]interface{}{}
		for i, c := range cols {
			if _, ok := isFloat[c]; ok {
				if f, ok2 := toFloat(vals[i]); ok2 {
					m[c] = f
				} else {
					m[c] = nil
				}
				continue
			}
			if _, ok := isTime[c]; ok {
				if iso, ok2 := toTimeISO(vals[i]); ok2 {
					m[c] = iso
				} else {
					m[c] = nil
				}
				continue
			}
			m[c] = toString(vals[i])
		}
		return m
	}

	return queryNormalizedPage(q, "mcp account recharge records", "mcp_account_recharge_records.created_at DESC", offset, limit, normalize)
}

// CreateMcpAccountRechargeRecord inserts a recharge record and, in the same
// transaction, credits amount (rounded to 8 decimal places) to the provider's
// balance via applyMcpBalanceChange.
//
// rechargeDate must be YYYY-MM-DD. operatorID is stored only when it is a
// uuid.UUID or a string that parses as a UUID; otherwise operator_id is NULL.
func CreateMcpAccountRechargeRecord(providerID string, amount float64, rechargeDate string, operatorID interface{}, operatorName, remark string) error {
	db := storage.GetPG()
	if db == nil {
		return fmt.Errorf("PostgreSQL未初始化")
	}

	providerUUID, err := uuid.Parse(providerID)
	if err != nil {
		return fmt.Errorf("无效的provider_id: %w", err)
	}

	date, err := time.Parse("2006-01-02", rechargeDate)
	if err != nil {
		return fmt.Errorf("无效的日期格式: %w", err)
	}

	// Non-UUID operator ids (e.g. numeric ids) are stored as NULL.
	var operatorUUID *uuid.UUID
	switch v := operatorID.(type) {
	case string:
		if parsed, parseErr := uuid.Parse(v); parseErr == nil {
			operatorUUID = &parsed
		}
	case uuid.UUID:
		operatorUUID = &v
	}

	tx := db.Begin()
	if tx.Error != nil {
		return fmt.Errorf("开始事务失败: %w", tx.Error)
	}
	defer func() {
		if r := recover(); r != nil {
			tx.Rollback()
			// Re-raise: swallowing the panic would report success for a failed write.
			panic(r)
		}
	}()

	// 1. Insert the recharge record.
	record := map[string]interface{}{
		"provider_id":   providerUUID,
		"amount":        amount,
		"currency":      "USD",
		"recharge_date": date,
		"operator_id":   operatorUUID,
		"operator_name": operatorName,
		"remark":        remark,
	}
	if err := tx.Table("mcp_account_recharge_records").Create(record).Error; err != nil {
		tx.Rollback()
		return fmt.Errorf("创建充值记录失败: %w", err)
	}

	// 2. Credit the balance.
	amountDecimal := decimal.NewFromFloat(amount).Round(8)
	remarkText := fmt.Sprintf("手动充值 $%s by %s", amountDecimal.StringFixed(2), operatorName)
	if _, _, err = applyMcpBalanceChange(tx, providerUUID, amountDecimal, remarkText); err != nil {
		tx.Rollback()
		return fmt.Errorf("创建余额记录失败: %w", err)
	}

	if err := tx.Commit().Error; err != nil {
		return fmt.Errorf("提交事务失败: %w", err)
	}
	return nil
}

// UpdateMcpAccountRechargeRecord updates the non-nil fields of recharge
// record id and bumps updated_at. rechargeDate must be YYYY-MM-DD.
//
// NOTE(review): changing amount here does not adjust the provider's balance
// history — confirm that this is intended.
func UpdateMcpAccountRechargeRecord(id string, amount *float64, rechargeDate *string, remark *string) error {
	db := storage.GetPG()
	if db == nil {
		return fmt.Errorf("PostgreSQL未初始化")
	}

	recordID, err := uuid.Parse(id)
	if err != nil {
		return fmt.Errorf("无效的记录ID: %w", err)
	}

	updates := map[string]interface{}{}
	if amount != nil {
		updates["amount"] = *amount
	}
	if rechargeDate != nil {
		date, err := time.Parse("2006-01-02", *rechargeDate)
		if err != nil {
			return fmt.Errorf("无效的日期格式: %w", err)
		}
		updates["recharge_date"] = date
	}
	if remark != nil {
		updates["remark"] = *remark
	}
	if len(updates) == 0 {
		return fmt.Errorf("没有需要更新的字段")
	}
	updates["updated_at"] = time.Now()

	return db.Table("mcp_account_recharge_records").
		Where("id = ?", recordID).
		Updates(updates).Error
}

// DeleteMcpAccountRechargeRecord deletes recharge record id. Deleting an id
// that does not exist is not reported as an error.
//
// NOTE(review): the balance credited by the recharge is not reversed — confirm
// that this is intended.
func DeleteMcpAccountRechargeRecord(id string) error {
	db := storage.GetPG()
	if db == nil {
		return fmt.Errorf("PostgreSQL未初始化")
	}

	recordID, err := uuid.Parse(id)
	if err != nil {
		return fmt.Errorf("无效的记录ID: %w", err)
	}

	return db.Table("mcp_account_recharge_records").
		Where("id = ?", recordID).
		Delete(nil).Error
}

// GetMcpProviderAccounts lists MCP provider accounts for dropdown selection,
// ordered by provider and then account, with every value rendered as a
// string. status and isUsed are optional filters.
func GetMcpProviderAccounts(status string, isUsed *bool) ([]map[string]interface{}, error) {
	db := storage.GetPG()
	if db == nil {
		return nil, fmt.Errorf("PostgreSQL未初始化")
	}

	q := db.Table("mcp_providers").
		Select("id, provider, account, status, is_used")
	if status != "" {
		q = q.Where("status = ?", status)
	}
	if isUsed != nil {
		q = q.Where("is_used = ?", *isUsed)
	}

	rows, err := q.Order("provider, account").Rows()
	if err != nil {
		return nil, fmt.Errorf("query mcp providers: %w", err)
	}
	defer rows.Close()

	return scanNormalizedRows(rows, func(cols []string, vals []interface{}) map[string]interface{} {
		m := make(map[string]interface{}, len(cols))
		for i, c := range cols {
			m[c] = toString(vals[i])
		}
		return m
	})
}

// GetMcpAccountLatestBalance returns the most recent balance snapshot of
// providerID, joined with its provider and account. Values are returned as
// scanned, without normalisation. If the provider has no snapshot, an empty
// (non-nil) map is returned. Database errors are returned, not swallowed.
func GetMcpAccountLatestBalance(providerID string) (map[string]interface{}, error) {
	db := storage.GetPG()
	if db == nil {
		return nil, fmt.Errorf("PostgreSQL未初始化")
	}

	providerUUID, err := uuid.Parse(providerID)
	if err != nil {
		return nil, fmt.Errorf("无效的provider_id: %w", err)
	}

	var balance map[string]interface{}
	err = db.Table("mcp_account_balances").
		Select(`
			mcp_account_balances.*,
			mcp_providers.provider,
			mcp_providers.account
		`).
		Joins("LEFT JOIN mcp_providers ON mcp_account_balances.provider_id = mcp_providers.id").
		Where("mcp_account_balances.provider_id = ?", providerUUID).
		Order("mcp_account_balances.created_at DESC").
		Limit(1).
		Find(&balance).Error
	if err != nil {
		return nil, fmt.Errorf("查询余额失败: %w", err)
	}
	if len(balance) == 0 {
		// Find does not report "not found"; with no row the map stays nil.
		return map[string]interface{}{}, nil
	}
	return balance, nil
}

// GetMcpAccountBalances returns the latest balance snapshot of every provider
// that has one, joined with its provider and account.
//
// balance is converted to float64 when parsable, and created_at to an RFC3339
// string. All other values are rendered with toString.
func GetMcpAccountBalances() ([]map[string]interface{}, error) {
	db := storage.GetPG()
	if db == nil {
		return nil, fmt.Errorf("PostgreSQL未初始化")
	}

	// PostgreSQL DISTINCT ON keeps the first row per provider_id in ORDER BY
	// order, i.e. the newest snapshot. This single query replaces a
	// MAX(created_at) query followed by one exact-timestamp lookup per
	// provider.
	var balances []map[string]interface{}
	err := db.Table("mcp_account_balances").
		Select(`
			DISTINCT ON (mcp_account_balances.provider_id)
			mcp_account_balances.*,
			mcp_providers.provider,
			mcp_providers.account
		`).
		Joins("LEFT JOIN mcp_providers ON mcp_account_balances.provider_id = mcp_providers.id").
		Order("mcp_account_balances.provider_id, mcp_account_balances.created_at DESC").
		Find(&balances).Error
	if err != nil {
		return nil, fmt.Errorf("查询余额失败: %w", err)
	}

	result := make([]map[string]interface{}, 0, len(balances))
	for _, bal := range balances {
		m := make(map[string]interface{}, len(bal))
		for k, v := range bal {
			switch k {
			case "balance":
				if f, ok := toFloat(v); ok {
					m[k] = f
				} else {
					m[k] = v
				}
			case "created_at":
				if iso, ok := toTimeISO(v); ok {
					m[k] = iso
				} else {
					m[k] = toString(v)
				}
			default:
				m[k] = toString(v)
			}
		}
		result = append(result, m)
	}
	return result, nil
}

// RunMcpUsageBalanceJob aggregates MCP invocation costs and updates account
// balances. The time range to process is derived from the end of the last
// successfully processed window.
func RunMcpUsageBalanceJob() error {
	// Refresh the low-balance threshold config before the scheduled job runs.
	RefreshLowBalanceThreshold()

	mysqlDB := storage.DB
	if mysqlDB == nil {
		return fmt.Errorf("MySQL未初始化")
	}
	pg := storage.GetPG()
	if pg == nil {
		return fmt.Errorf("PostgreSQL未初始化")
	}

	// End of the last successfully processed window.
	lastWindowEnd, err := getLastProcessedWindowEnd(pg, mcpUsageJobName)
	if err != nil {
		return fmt.Errorf("获取上次处理时间失败: %w", err)
	}

	// Top of the current hour.
	now := time.Now().UTC()
	currentWindowEnd := now.Truncate(time.Hour)

	var windowStart, windowEnd time.Time
	if lastWindowEnd.IsZero() {
		// No history: only process the previous hour.
		windowEnd = currentWindowEnd
		windowStart = windowEnd.Add(-time.Hour)
	} else {
		// History exists: process from the last window end up to the current hour.
		windowStart = lastWindowEnd
		windowEnd = currentWindowEnd
	}

	// Empty window (e.g. last processed up to 10:00 and it is still 10:xx): nothing to do.
	if !windowStart.Before(windowEnd) {
		return nil
	}

	windowStart = windowStart.UTC()
	windowEnd = windowEnd.UTC()

	// Load all mcp_providers rows into memory.
	providerMap, err := loadMcpProvidersToMemory(pg)
	if err != nil {
		return fmt.Errorf("加载账号信息失败: %w", err)
	}

	runID, err := createMcpUsageJobRun(pg, mcpUsageJobName, windowStart, windowEnd)
	if err != nil {
		if errors.Is(err, ErrMcpUsageJobAlreadyProcessed) || errors.Is(err, ErrMcpUsageJobAlreadyInProgress) {
			return nil
		}
		return fmt.Errorf("创建任务记录失败: %w", err)
	}

	var usageRows []mcpUsageAggregation
	err = mysqlDB.Table("mcp_invoke_usages").
		Select("provider, account, SUM(cost) AS cost_sum").
		Where("updated_at >= ?
AND updated_at < ?", windowStart, windowEnd). Group("provider, account"). Having("SUM(cost) <> 0"). Find(&usageRows).Error if err != nil { updateErr := updateMcpUsageJobRun(pg, runID, jobStatusFailed, 0, decimal.Zero, err.Error()) if updateErr != nil { return fmt.Errorf("更新任务记录失败: %v; 原始错误: %w", updateErr, err) } return fmt.Errorf("聚合调用费用失败: %w", err) } if len(usageRows) == 0 { return updateMcpUsageJobRun(pg, runID, jobStatusSuccess, 0, decimal.Zero, "") } var ( recordsProcessed int totalCost = decimal.Zero hadFailure bool errorMessages []string ) for _, row := range usageRows { provider := strings.TrimSpace(row.Provider) account := "" if row.Account.Valid { account = strings.TrimSpace(row.Account.String) } costCents := row.CostSum costUSD := decimal.NewFromInt(costCents).Div(usdScalingFactor) detailStatus := jobStatusSuccess detailError := "" var ( providerEntry mcpProviderCache providerID uuid.UUID actualCostUSD = costUSD floatingRatio float64 = 1 ) if account == "" { detailStatus = jobStatusFailed detailError = "account为空" } else { if providerEntry, err = findMcpProviderFromMemory(providerMap, provider, account); err != nil { detailStatus = jobStatusFailed detailError = fmt.Sprintf("查找账号失败: %v", err) } else { providerID = providerEntry.ID floatingRatio = providerEntry.FloatingRatio } } if detailStatus == jobStatusSuccess { if floatingRatio <= 0 { floatingRatio = 1 } if floatingRatio != 1 { actualCostUSD = costUSD.Div(decimal.NewFromFloat(floatingRatio)) } actualCostUSD = actualCostUSD.Round(8) if actualCostUSD.LessThanOrEqual(decimal.Zero) { detailStatus = jobStatusFailed detailError = "结算金额为0" } } var previousBalance decimal.Decimal var newBalance decimal.Decimal if detailStatus == jobStatusSuccess { tx := pg.Begin() if tx.Error != nil { detailStatus = jobStatusFailed detailError = fmt.Sprintf("开启事务失败: %v", tx.Error) } else { remark := fmt.Sprintf( "自动扣费 %s - %s (原始$%s, 上浮%0.2f, 实际$%s)", windowStart.Format("2006-01-02 15:04"), windowEnd.Format("15:04"), 
costUSD.StringFixed(8), floatingRatio, actualCostUSD.StringFixed(8), ) previousBalance, newBalance, err = applyMcpBalanceChange(tx, providerID, actualCostUSD.Neg(), remark) if err != nil { tx.Rollback() detailStatus = jobStatusFailed detailError = fmt.Sprintf("扣减余额失败: %v", err) } else { if commitErr := tx.Commit().Error; commitErr != nil { detailStatus = jobStatusFailed detailError = fmt.Sprintf("提交事务失败: %v", commitErr) } } } } if detailStatus == jobStatusSuccess { recordsProcessed++ totalCost = totalCost.Add(actualCostUSD) threshold := GetLowBalanceThreshold() if newBalance.LessThan(threshold) { notifier.NotifyMcpLowBalance(provider, account, newBalance, threshold) } } else { hadFailure = true errorMessages = append(errorMessages, fmt.Sprintf("%s/%s: %s", provider, account, detailError)) } detailData := map[string]interface{}{ "run_id": runID, "provider": provider, "account": account, "cost_cents": costCents, "cost_usd": actualCostUSD, "status": detailStatus, "created_at": time.Now(), "error_message": detailError, } if providerID != uuid.Nil { detailData["provider_id"] = providerID } if detailStatus == jobStatusSuccess { detailData["previous_balance"] = previousBalance detailData["new_balance"] = newBalance } if err := pg.Table("mcp_usage_balance_job_details").Create(detailData).Error; err != nil { hadFailure = true errorMessages = append(errorMessages, fmt.Sprintf("记录明细失败 %s/%s: %v", provider, account, err)) } } finalStatus := jobStatusSuccess finalError := "" if hadFailure { finalStatus = jobStatusFailed finalError = strings.Join(errorMessages, "; ") if len(finalError) > 1000 { finalError = finalError[:1000] } } if err := updateMcpUsageJobRun(pg, runID, finalStatus, recordsProcessed, totalCost, finalError); err != nil { return fmt.Errorf("更新任务记录失败: %w", err) } if hadFailure { return fmt.Errorf("部分账户结算失败: %s", finalError) } return nil } // RunModelTokenBalanceJob 汇总模型Token消费并扣减余额 func RunModelTokenBalanceJob() error { // 定时任务执行前刷新低余额阈值配置 
RefreshLowBalanceThreshold() mysqlDB := storage.DB if mysqlDB == nil { return fmt.Errorf("MySQL未初始化") } lastWindowEnd, err := getLastProcessedModelTokenWindowEnd(mysqlDB, modelTokenJobName) if err != nil { return fmt.Errorf("获取上次处理时间失败: %w", err) } now := time.Now().UTC() currentWindowEnd := now.Truncate(time.Hour) var windowStart, windowEnd time.Time if lastWindowEnd.IsZero() { windowEnd = currentWindowEnd windowStart = windowEnd.Add(-time.Hour) } else { windowStart = lastWindowEnd windowEnd = currentWindowEnd } if !windowStart.Before(windowEnd) { return nil } windowStart = windowStart.UTC() windowEnd = windowEnd.UTC() runID, err := createModelTokenJobRun(mysqlDB, modelTokenJobName, windowStart, windowEnd) if err != nil { if errors.Is(err, ErrModelTokenJobAlreadyProcessed) || errors.Is(err, ErrModelTokenJobAlreadyInProgress) { return nil } return fmt.Errorf("创建任务记录失败: %w", err) } configMap, err := loadModelConfigsToMemory(mysqlDB) if err != nil { updateErr := updateModelTokenJobRun(mysqlDB, runID, jobStatusFailed, 0, decimal.Zero, err.Error()) if updateErr != nil { return fmt.Errorf("加载模型配置失败: %v; 原始错误: %w", updateErr, err) } return fmt.Errorf("加载模型配置失败: %w", err) } startDay := windowStart.Format("2006-01-02") startHour := windowStart.Hour() endDay := windowEnd.Format("2006-01-02") endHour := windowEnd.Hour() query := mysqlDB.Table("gw_token_usages"). Select("provider, account, model, SUM(prompt_cost + completion_cost + cache_create_cost + cache_read_cost) AS total_cost"). Where("(day > ?) OR (day = ? AND hour >= ?)", startDay, startDay, startHour). Where("(day < ?) OR (day = ? AND hour < ?)", endDay, endDay, endHour). Group("provider, account, model"). 
Having("SUM(prompt_cost + completion_cost + cache_create_cost + cache_read_cost) <> 0") rows, err := query.Rows() if err != nil { updateErr := updateModelTokenJobRun(mysqlDB, runID, jobStatusFailed, 0, decimal.Zero, err.Error()) if updateErr != nil { return fmt.Errorf("query error update failed: %v; 原始错误: %w", updateErr, err) } return fmt.Errorf("聚合Token消费失败: %w", err) } defer rows.Close() hasRows := false var ( recordsProcessed int totalCostUSD = decimal.Zero hadFailure bool errorMessages []string ) for rows.Next() { hasRows = true var ( provider string account sql.NullString model string totalCost sql.NullInt64 ) if scanErr := rows.Scan(&provider, &account, &model, &totalCost); scanErr != nil { hadFailure = true errorMsg := fmt.Sprintf("扫描聚合结果失败: %v", scanErr) errorMessages = append(errorMessages, errorMsg) updateErr := updateModelTokenJobRun(mysqlDB, runID, jobStatusFailed, recordsProcessed, totalCostUSD, errorMsg) if updateErr != nil { return fmt.Errorf("更新任务记录失败: %v; 原始错误: %w", updateErr, scanErr) } return fmt.Errorf("扫描聚合结果失败: %w", scanErr) } if !totalCost.Valid { continue } costCents := totalCost.Int64 if costCents == 0 { continue } provider = strings.TrimSpace(provider) model = strings.TrimSpace(model) accountStr := "" if account.Valid { accountStr = strings.TrimSpace(account.String) } costUSD := decimal.NewFromInt(costCents).Div(usdScalingFactor) detailStatus := jobStatusSuccess detailError := "" var ( modelConfigID uint64 previousBalance decimal.Decimal newBalance decimal.Decimal ) configInfo, exists := configMap[makeModelConfigKey(provider, model)] if !exists { detailStatus = jobStatusFailed detailError = "未找到匹配的模型配置" } else { modelConfigID = configInfo.ID } priceRatio := configInfo.PriceRatio if priceRatio <= 0 { priceRatio = 1 } actualCostUSD := costUSD if detailStatus == jobStatusSuccess { if priceRatio != 1 { actualCostUSD = costUSD.Div(decimal.NewFromFloat(priceRatio)) } actualCostUSD = actualCostUSD.Round(8) if 
actualCostUSD.LessThanOrEqual(decimal.Zero) { detailStatus = jobStatusFailed detailError = "结算金额为0" } } if detailStatus == jobStatusSuccess { tx := mysqlDB.Begin() if tx.Error != nil { detailStatus = jobStatusFailed detailError = fmt.Sprintf("开启事务失败: %v", tx.Error) } else { remark := fmt.Sprintf( "自动扣费 %s - %s (原始$%s, 上浮%0.2f, 实际$%s)", windowStart.Format("2006-01-02 15:04"), windowEnd.Format("15:04"), costUSD.StringFixed(8), priceRatio, actualCostUSD.StringFixed(8), ) previousBalance, newBalance, err = applyModelAccountBalanceChange(tx, accountStr, actualCostUSD.Neg(), remark) if err != nil { tx.Rollback() detailStatus = jobStatusFailed detailError = fmt.Sprintf("扣减余额失败: %v", err) } else if commitErr := tx.Commit().Error; commitErr != nil { detailStatus = jobStatusFailed detailError = fmt.Sprintf("提交事务失败: %v", commitErr) } } } if detailStatus == jobStatusSuccess { recordsProcessed++ totalCostUSD = totalCostUSD.Add(actualCostUSD) threshold := GetLowBalanceThreshold() if newBalance.LessThan(threshold) { notifier.NotifyModelLowBalance(provider, accountStr, model, newBalance, threshold) } } else { hadFailure = true errorMessages = append(errorMessages, fmt.Sprintf("%s/%s/%s: %s", provider, accountStr, model, detailError)) } detailData := map[string]interface{}{ "run_id": runID, "provider": provider, "account": accountStr, "model": model, "total_cost_cents": costCents, "total_cost_usd": actualCostUSD, "status": detailStatus, "created_at": time.Now(), "error_message": detailError, } if modelConfigID != 0 { detailData["model_config_id"] = modelConfigID } if detailStatus == jobStatusSuccess { detailData["previous_balance"] = previousBalance detailData["new_balance"] = newBalance } if detailErr := mysqlDB.Table("model_token_balance_job_details").Create(detailData).Error; detailErr != nil { hadFailure = true errorMessages = append(errorMessages, fmt.Sprintf("记录明细失败 %s/%s/%s: %v", provider, accountStr, model, detailErr)) } } if err := rows.Err(); err != nil { updateErr := 
updateModelTokenJobRun(mysqlDB, runID, jobStatusFailed, recordsProcessed, totalCostUSD, err.Error()) if updateErr != nil { return fmt.Errorf("更新任务记录失败: %v; 原始错误: %w", updateErr, err) } return fmt.Errorf("遍历聚合结果失败: %w", err) } if !hasRows { return updateModelTokenJobRun(mysqlDB, runID, jobStatusSuccess, 0, decimal.Zero, "") } finalStatus := jobStatusSuccess finalError := "" if hadFailure { finalStatus = jobStatusFailed finalError = strings.Join(errorMessages, "; ") if len(finalError) > 1000 { finalError = finalError[:1000] } } if err := updateModelTokenJobRun(mysqlDB, runID, finalStatus, recordsProcessed, totalCostUSD, finalError); err != nil { return fmt.Errorf("更新任务记录失败: %w", err) } if hadFailure { return fmt.Errorf("部分模型结算失败: %s", finalError) } return nil } // GetMcpAccountBalanceHistory 获取账户余额历史记录 func GetMcpAccountBalanceHistory(providerID string, start, end string) ([]map[string]interface{}, error) { db := storage.GetPG() if db == nil { return nil, fmt.Errorf("PostgreSQL未初始化") } providerUUID, err := uuid.Parse(providerID) if err != nil { return nil, fmt.Errorf("无效的provider_id: %w", err) } q := db.Table("mcp_account_balances"). Select(` mcp_account_balances.*, mcp_providers.provider, mcp_providers.account `). Joins("LEFT JOIN mcp_providers ON mcp_account_balances.provider_id = mcp_providers.id"). 
Where("mcp_account_balances.provider_id = ?", providerUUID) if start != "" { q = q.Where("mcp_account_balances.created_at >= ?", start) } if end != "" { q = q.Where("mcp_account_balances.created_at <= ?", end) } rows, err := q.Order("mcp_account_balances.created_at DESC").Rows() if err != nil { return nil, fmt.Errorf("查询余额历史失败: %w", err) } defer rows.Close() cols, _ := rows.Columns() list := make([]map[string]interface{}, 0) floatCols := []string{"balance"} timeCols := []string{"created_at"} for rows.Next() { vals := make([]interface{}, len(cols)) valPtrs := make([]interface{}, len(cols)) for i := range vals { valPtrs[i] = &vals[i] } if err := rows.Scan(valPtrs...); err != nil { return nil, err } m := map[string]interface{}{} isFloat := map[string]struct{}{} for _, k := range floatCols { isFloat[k] = struct{}{} } isTime := map[string]struct{}{} for _, k := range timeCols { isTime[k] = struct{}{} } for i, c := range cols { if _, ok := isFloat[c]; ok { if f, ok2 := toFloat(vals[i]); ok2 { m[c] = f } else { m[c] = nil } continue } if _, ok := isTime[c]; ok { if iso, ok2 := toTimeISO(vals[i]); ok2 { m[c] = iso } else { m[c] = toString(vals[i]) } continue } m[c] = toString(vals[i]) } list = append(list, m) } return list, nil } // ========== 模型账号充值记录和余额管理 ========== // applyModelAccountBalanceChange 应用模型账号余额变更(追加式记录) func applyModelAccountBalanceChange(tx *gorm.DB, account string, delta decimal.Decimal, remark string) (decimal.Decimal, decimal.Decimal, error) { var latest struct { Balance decimal.Decimal `gorm:"column:balance"` } prev := decimal.Zero err := tx.Table("gw_model_account_balances"). Select("balance"). Where("account = ?", account). Order("created_at DESC"). 
Take(&latest).Error if err != nil { if !errors.Is(err, gorm.ErrRecordNotFound) { return decimal.Zero, decimal.Zero, err } } else { prev = latest.Balance } newBalance := prev.Add(delta) record := map[string]interface{}{ "account": account, "balance": newBalance, "currency": "USD", "remark": remark, "created_at": time.Now(), } if err := tx.Table("gw_model_account_balances").Create(record).Error; err != nil { return decimal.Zero, decimal.Zero, err } return prev, newBalance, nil } // ListModelAccountRechargeRecords 查询模型账号充值记录 func ListModelAccountRechargeRecords(offset, limit int, provider, modelName, start, end string) (*PageResult, error) { db := storage.DB if db == nil { return nil, fmt.Errorf("MySQL未初始化") } var total int64 q := db.Table("gw_model_account_recharge_records r"). Select(` r.*, gp.name AS provider `). Joins("LEFT JOIN gw_providers gp ON r.account = gp.account"). Where("r.deleted_at IS NULL") if provider != "" { q = q.Where("gp.name LIKE ?", "%"+provider+"%") } if start != "" { q = q.Where("r.recharge_date >= ?", start) } if end != "" { q = q.Where("r.recharge_date <= ?", end) } if err := q.Count(&total).Error; err != nil { return nil, fmt.Errorf("count model account recharge records: %w", err) } rows, err := q.Order("r.created_at DESC"). 
Offset(offset).Limit(limit).Rows() if err != nil { return nil, fmt.Errorf("query model account recharge records: %w", err) } defer rows.Close() cols, _ := rows.Columns() list := make([]map[string]interface{}, 0) intCols := []string{"id", "operator_id"} floatCols := []string{"amount"} timeCols := []string{"created_at", "updated_at"} for rows.Next() { vals := make([]interface{}, len(cols)) valPtrs := make([]interface{}, len(cols)) for i := range vals { valPtrs[i] = &vals[i] } if err := rows.Scan(valPtrs...); err != nil { return nil, err } // 需要处理float类型 m := normalizeGenericInts(cols, vals, intCols, timeCols) // 手动处理float字段 for i, c := range cols { for _, fc := range floatCols { if c == fc { if f, ok := toFloat(vals[i]); ok { m[c] = f } break } } } list = append(list, m) } return &PageResult{List: list, Total: total}, nil } // CreateModelAccountRechargeRecord 创建模型账号充值记录 func CreateModelAccountRechargeRecord(account string, amount float64, rechargeDate string, operatorID interface{}, operatorName, remark string) error { db := storage.DB if db == nil { return fmt.Errorf("MySQL未初始化") } account = strings.TrimSpace(account) if account == "" { return fmt.Errorf("账号不能为空") } // 验证账号是否存在 var count int64 if err := db.Table("gw_providers").Where("account = ?", account).Count(&count).Error; err != nil { return fmt.Errorf("验证账号失败: %w", err) } if count == 0 { return fmt.Errorf("账号不存在: %s", account) } // 解析日期 date, err := time.Parse("2006-01-02", rechargeDate) if err != nil { return fmt.Errorf("无效的日期格式: %w", err) } // 处理operator_id var operatorIDUint *uint64 if operatorID != nil { var opID uint64 switch v := operatorID.(type) { case string: parsed, err := strconv.ParseUint(v, 10, 64) if err != nil { operatorIDUint = nil } else { opID = parsed operatorIDUint = &opID } case uint64: operatorIDUint = &v case int64: u := uint64(v) operatorIDUint = &u case int: u := uint64(v) operatorIDUint = &u default: operatorIDUint = nil } } // 开始事务 tx := db.Begin() if tx.Error != nil { return 
fmt.Errorf("开始事务失败: %w", tx.Error) } defer func() { if r := recover(); r != nil { tx.Rollback() } }() // 1. 创建充值记录 now := time.Now() record := map[string]interface{}{ "account": account, "amount": amount, "currency": "USD", "recharge_date": date, "operator_id": operatorIDUint, "operator_name": operatorName, "remark": remark, "created_at": now, "updated_at": now, } if err := tx.Table("gw_model_account_recharge_records").Create(record).Error; err != nil { tx.Rollback() return fmt.Errorf("创建充值记录失败: %w", err) } amountDecimal := decimal.NewFromFloat(amount).Round(8) remarkText := fmt.Sprintf("手动充值 $%s by %s", amountDecimal.StringFixed(2), operatorName) _, _, err = applyModelAccountBalanceChange(tx, account, amountDecimal, remarkText) if err != nil { tx.Rollback() return fmt.Errorf("创建余额记录失败: %w", err) } // 提交事务 if err := tx.Commit().Error; err != nil { return fmt.Errorf("提交事务失败: %w", err) } return nil } // UpdateModelAccountRechargeRecord 更新模型账号充值记录 func UpdateModelAccountRechargeRecord(id string, amount *float64, rechargeDate *string, remark *string) error { db := storage.DB if db == nil { return fmt.Errorf("MySQL未初始化") } recordID, err := strconv.ParseUint(id, 10, 64) if err != nil { return fmt.Errorf("无效的记录ID: %w", err) } updates := map[string]interface{}{} if amount != nil { updates["amount"] = *amount } if rechargeDate != nil { date, err := time.Parse("2006-01-02", *rechargeDate) if err != nil { return fmt.Errorf("无效的日期格式: %w", err) } updates["recharge_date"] = date } if remark != nil { updates["remark"] = *remark } if len(updates) == 0 { return fmt.Errorf("没有需要更新的字段") } updates["updated_at"] = time.Now() return db.Table("gw_model_account_recharge_records"). Where("id = ? AND deleted_at IS NULL", recordID). 
Updates(updates).Error } // DeleteModelAccountRechargeRecord 删除模型账号充值记录(软删除) func DeleteModelAccountRechargeRecord(id string) error { db := storage.DB if db == nil { return fmt.Errorf("MySQL未初始化") } recordID, err := strconv.ParseUint(id, 10, 64) if err != nil { return fmt.Errorf("无效的记录ID: %w", err) } return db.Table("gw_model_account_recharge_records"). Where("id = ?", recordID). Update("deleted_at", time.Now()).Error } // GetModelConfigAccounts 获取模型配置列表(用于下拉选择) func GetModelConfigAccounts(enabled *bool) ([]map[string]interface{}, error) { db := storage.DB if db == nil { return nil, fmt.Errorf("MySQL未初始化") } q := db.Table("gw_providers"). Select("name, api_type, account, status, priority"). Where("deleted_at IS NULL") if enabled != nil { if *enabled { q = q.Where("status = ?", "active") } else { q = q.Where("status <> ?", "active") } } var accounts []map[string]interface{} rows, err := q.Order("name, account").Rows() if err != nil { return nil, fmt.Errorf("query model accounts: %w", err) } defer rows.Close() cols, _ := rows.Columns() for rows.Next() { vals := make([]interface{}, len(cols)) valPtrs := make([]interface{}, len(cols)) for i := range vals { valPtrs[i] = &vals[i] } if err := rows.Scan(valPtrs...); err != nil { return nil, err } m := make(map[string]interface{}) for i, c := range cols { switch strings.ToLower(c) { case "priority": if val, ok := toInt(vals[i]); ok { m[c] = val } else { m[c] = nil } default: m[c] = toString(vals[i]) } } accounts = append(accounts, m) } return accounts, nil } // GetModelAccountLatestBalance 获取模型账号最新余额 func GetModelAccountLatestBalance(account string) (map[string]interface{}, error) { db := storage.DB if db == nil { return nil, fmt.Errorf("MySQL未初始化") } account = strings.TrimSpace(account) if account == "" { return nil, fmt.Errorf("账号不能为空") } var balance map[string]interface{} err := db.Table("gw_model_account_balances b"). Select(` b.*, gp.name AS provider, gp.api_type `). 
Joins("LEFT JOIN gw_providers gp ON b.account = gp.account"). Where("b.account = ?", account). Order("b.created_at DESC"). Limit(1). Find(&balance).Error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return map[string]interface{}{}, nil } return nil, err } // 处理数据类型 result := make(map[string]interface{}) for k, v := range balance { if k == "balance" { if f, ok := toFloat(v); ok { result[k] = f } else { result[k] = v } } else if k == "created_at" { if iso, ok := toTimeISO(v); ok { result[k] = iso } else { result[k] = toString(v) } } else { result[k] = toString(v) } } return result, nil } // GetModelAccountBalances 获取所有模型账号最新余额列表 func GetModelAccountBalances() ([]map[string]interface{}, error) { db := storage.DB if db == nil { return nil, fmt.Errorf("MySQL未初始化") } // 先获取每个账号最新的创建时间 var latestRecords []struct { Account string MaxCreated time.Time } err := db.Table("gw_model_account_balances"). Select("account, MAX(created_at) as max_created"). Group("account"). Find(&latestRecords).Error if err != nil { return nil, fmt.Errorf("查询余额失败: %w", err) } // 为每个账号获取最新余额 balances := make([]map[string]interface{}, 0) for _, lr := range latestRecords { var balance map[string]interface{} err = db.Table("gw_model_account_balances b"). Select(` b.*, gp.name AS provider, gp.api_type `). Joins("LEFT JOIN gw_providers gp ON b.account = gp.account"). Where("b.account = ? AND b.created_at = ?", lr.Account, lr.MaxCreated). Limit(1). 
Find(&balance).Error if err == nil && len(balance) > 0 { balances = append(balances, balance) } } // 处理数据类型 result := make([]map[string]interface{}, 0) for _, bal := range balances { m := make(map[string]interface{}) for k, v := range bal { if k == "balance" { if f, ok := toFloat(v); ok { m[k] = f } else { m[k] = v } } else if k == "created_at" { if iso, ok := toTimeISO(v); ok { m[k] = iso } else { m[k] = toString(v) } } else { m[k] = toString(v) } } result = append(result, m) } return result, nil } // GetModelAccountBalanceHistory 获取模型账号余额历史 func GetModelAccountBalanceHistory(account string, start, end string) ([]map[string]interface{}, error) { db := storage.DB if db == nil { return nil, fmt.Errorf("MySQL未初始化") } q := db.Table("gw_model_account_balances"). Where("account = ?", account) if start != "" { q = q.Where("created_at >= ?", start) } if end != "" { q = q.Where("created_at <= ?", end) } rows, err := q.Order("created_at DESC").Rows() if err != nil { return nil, fmt.Errorf("query balance history: %w", err) } defer rows.Close() cols, _ := rows.Columns() list := make([]map[string]interface{}, 0) intCols := []string{"id"} floatCols := []string{"balance"} timeCols := []string{"created_at"} for rows.Next() { vals := make([]interface{}, len(cols)) valPtrs := make([]interface{}, len(cols)) for i := range vals { valPtrs[i] = &vals[i] } if err := rows.Scan(valPtrs...); err != nil { return nil, err } m := normalizeGenericInts(cols, vals, intCols, timeCols) // 手动处理float字段 for i, c := range cols { for _, fc := range floatCols { if c == fc { if f, ok := toFloat(vals[i]); ok { m[c] = f } break } } } list = append(list, m) } return list, nil }