2514 lines
66 KiB
Go
2514 lines
66 KiB
Go
package services
|
||
|
||
import (
	"bytes"
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"math"
	"net/http"
	"strconv"
	"strings"
	"sync"
	"time"

	"goalfymax-admin/internal/config"
	"goalfymax-admin/internal/notifier"
	"goalfymax-admin/internal/storage"
	pkgredis "goalfymax-admin/pkg/redis"

	mysqldriver "github.com/go-sql-driver/mysql"
	"github.com/google/uuid"
	"github.com/jackc/pgconn"
	"github.com/shopspring/decimal"
	"gorm.io/gorm"
)
|
||
|
||
// PageResult is one page of a paginated listing: the rows on this page (each row a
// column-name → value map) and the number of rows matching the filters before
// offset/limit were applied.
type PageResult struct {
	List  []map[string]interface{} `json:"list"`  // rows on the current page
	Total int64                    `json:"total"` // filtered count, ignoring paging
}
|
||
|
||
var (
|
||
usdScalingFactor = decimal.NewFromInt(100000000)
|
||
lowBalanceThreshold = decimal.NewFromInt(30) // 默认值30
|
||
lowBalanceThresholdMutex sync.RWMutex
|
||
ErrMcpUsageJobAlreadyProcessed = errors.New("mcp usage job already processed")
|
||
ErrMcpUsageJobAlreadyInProgress = errors.New("mcp usage job already in progress")
|
||
ErrModelTokenJobAlreadyProcessed = errors.New("model token job already processed")
|
||
ErrModelTokenJobAlreadyInProgress = errors.New("model token job already in progress")
|
||
)
|
||
|
||
// GetLowBalanceThreshold 获取低余额阈值(线程安全)
|
||
func GetLowBalanceThreshold() decimal.Decimal {
|
||
lowBalanceThresholdMutex.RLock()
|
||
defer lowBalanceThresholdMutex.RUnlock()
|
||
return lowBalanceThreshold
|
||
}
|
||
|
||
// RefreshLowBalanceThreshold 从系统配置刷新低余额阈值(定时任务执行前调用)
|
||
func RefreshLowBalanceThreshold() {
|
||
const configKey = "low_balance_threshold"
|
||
const defaultValue = 30
|
||
|
||
// 创建系统配置服务实例
|
||
systemConfigService := NewSystemConfigService(
|
||
storage.NewSystemConfigStorage(),
|
||
nil, // 不使用 logger
|
||
)
|
||
|
||
config, err := systemConfigService.GetByKey(configKey)
|
||
if err != nil {
|
||
// 配置不存在,使用默认值
|
||
lowBalanceThresholdMutex.Lock()
|
||
lowBalanceThreshold = decimal.NewFromInt(defaultValue)
|
||
lowBalanceThresholdMutex.Unlock()
|
||
return
|
||
}
|
||
|
||
// 解析配置值
|
||
valueStr := strings.TrimSpace(config.Value)
|
||
if valueStr == "" {
|
||
// 值为空,使用默认值
|
||
lowBalanceThresholdMutex.Lock()
|
||
lowBalanceThreshold = decimal.NewFromInt(defaultValue)
|
||
lowBalanceThresholdMutex.Unlock()
|
||
return
|
||
}
|
||
|
||
// 尝试解析为整数
|
||
valueInt, err := strconv.ParseInt(valueStr, 10, 64)
|
||
if err != nil {
|
||
// 解析失败,使用默认值
|
||
lowBalanceThresholdMutex.Lock()
|
||
lowBalanceThreshold = decimal.NewFromInt(defaultValue)
|
||
lowBalanceThresholdMutex.Unlock()
|
||
return
|
||
}
|
||
|
||
// 设置配置值
|
||
lowBalanceThresholdMutex.Lock()
|
||
lowBalanceThreshold = decimal.NewFromInt(valueInt)
|
||
lowBalanceThresholdMutex.Unlock()
|
||
}
|
||
|
||
// Status values written to the status column of the *_job_runs tables.
const (
	jobStatusPending = "pending"
	jobStatusSuccess = "success"
	jobStatusFailed  = "failed"
)

// Job names identifying rows in the hourly settlement job-run tables.
const (
	mcpUsageJobName   = "mcp_usage_hourly_settlement"
	modelTokenJobName = "model_token_usage_hourly_settlement"
)
|
||
|
||
// toString renders a value scanned through database/sql (or any other value) as text.
//
//   - nil (an SQL NULL scanned into interface{}) becomes "", so callers such as
//     toInt/toFloat/toTimeISO treat it as missing instead of seeing fmt's "<nil>".
//   - []byte and string are returned as text unchanged.
//   - time.Time is formatted as RFC 3339, a layout toTimeISO can parse back.
//   - anything else falls back to fmt's %v formatting.
func toString(v interface{}) string {
	switch t := v.(type) {
	case nil:
		return ""
	case []byte:
		return string(t)
	case string:
		return t
	case time.Time:
		return t.Format(time.RFC3339)
	default:
		return fmt.Sprintf("%v", v)
	}
}
|
||
|
||
func toInt(v interface{}) (int64, bool) {
|
||
s := toString(v)
|
||
if s == "" {
|
||
return 0, false
|
||
}
|
||
n, err := strconv.ParseInt(s, 10, 64)
|
||
if err != nil {
|
||
return 0, false
|
||
}
|
||
return n, true
|
||
}
|
||
|
||
func toFloat(v interface{}) (float64, bool) {
|
||
s := toString(v)
|
||
if s == "" {
|
||
return 0, false
|
||
}
|
||
f, err := strconv.ParseFloat(s, 64)
|
||
if err != nil {
|
||
return 0, false
|
||
}
|
||
return f, true
|
||
}
|
||
|
||
func toTimeISO(s interface{}) (string, bool) {
|
||
str := toString(s)
|
||
if str == "" {
|
||
return "", false
|
||
}
|
||
// Try common layouts
|
||
layouts := []string{
|
||
time.RFC3339,
|
||
"2006-01-02 15:04:05",
|
||
"2006-01-02T15:04:05",
|
||
}
|
||
for _, l := range layouts {
|
||
if tm, err := time.ParseInLocation(l, str, time.Local); err == nil {
|
||
return tm.Format(time.RFC3339), true
|
||
}
|
||
}
|
||
return str, true
|
||
}
|
||
|
||
func normalizeSandboxRow(cols []string, vals []interface{}) map[string]interface{} {
|
||
m := map[string]interface{}{}
|
||
for i, c := range cols {
|
||
switch c {
|
||
case "duration_minutes", "unit_price_usd", "total_cost_usd":
|
||
if f, ok := toFloat(vals[i]); ok {
|
||
m[c] = f
|
||
} else {
|
||
m[c] = nil
|
||
}
|
||
case "total_cost_balance":
|
||
if n, ok := toInt(vals[i]); ok {
|
||
m[c] = n
|
||
} else {
|
||
m[c] = nil
|
||
}
|
||
case "created_at", "updated_at", "started_at", "released_at", "last_billed_at":
|
||
if iso, ok := toTimeISO(vals[i]); ok {
|
||
m[c] = iso
|
||
} else {
|
||
m[c] = nil
|
||
}
|
||
case "info":
|
||
s := toString(vals[i])
|
||
var obj interface{}
|
||
if err := json.Unmarshal([]byte(s), &obj); err == nil {
|
||
m[c] = obj
|
||
} else {
|
||
m[c] = s
|
||
}
|
||
default:
|
||
m[c] = toString(vals[i])
|
||
}
|
||
}
|
||
return m
|
||
}
|
||
|
||
func normalizeGenericInts(cols []string, vals []interface{}, intCols []string, timeCols []string) map[string]interface{} {
|
||
m := map[string]interface{}{}
|
||
isInt := map[string]struct{}{}
|
||
for _, k := range intCols {
|
||
isInt[k] = struct{}{}
|
||
}
|
||
isTime := map[string]struct{}{}
|
||
for _, k := range timeCols {
|
||
isTime[k] = struct{}{}
|
||
}
|
||
for i, c := range cols {
|
||
if _, ok := isInt[c]; ok {
|
||
if n, ok2 := toInt(vals[i]); ok2 {
|
||
m[c] = n
|
||
} else {
|
||
m[c] = nil
|
||
}
|
||
continue
|
||
}
|
||
if _, ok := isTime[c]; ok {
|
||
if iso, ok2 := toTimeISO(vals[i]); ok2 {
|
||
m[c] = iso
|
||
} else {
|
||
m[c] = nil
|
||
}
|
||
continue
|
||
}
|
||
m[c] = toString(vals[i])
|
||
}
|
||
return m
|
||
}
|
||
|
||
// parseBool interprets a loosely typed value, such as a scanned database column, as a
// boolean.
//
// Accepted forms:
//   - bool, and non-nil *bool
//   - any signed or unsigned integer type (non-zero is true)
//   - string or []byte equal, after trimming and lower-casing, to "true", "1" or "t"
//
// Anything else, including nil and a nil *bool, is false.
func parseBool(val interface{}) bool {
	switch v := val.(type) {
	case bool:
		return v
	case *bool:
		return v != nil && *v
	// Integers are compared directly. A round-trip through a signed string parse would
	// reject uint64 values above MaxInt64 and report them as false.
	case int:
		return v != 0
	case int8:
		return v != 0
	case int16:
		return v != 0
	case int32:
		return v != 0
	case int64:
		return v != 0
	case uint:
		return v != 0
	case uint8:
		return v != 0
	case uint16:
		return v != 0
	case uint32:
		return v != 0
	case uint64:
		return v != 0
	case []byte:
		// SQL drivers may hand back numeric/boolean columns as raw bytes when scanning
		// into interface{}; treat them like their text form.
		return parseBool(string(v))
	case string:
		lower := strings.ToLower(strings.TrimSpace(v))
		return lower == "true" || lower == "1" || lower == "t"
	}
	return false
}
|
||
|
||
// mcpUsageAggregation is the row shape of an aggregated MCP usage query (the query is
// not in this section). NOTE(review): presumably one summed cost per provider/account
// pair; confirm at the query site. Columns are bound via the explicit gorm tags.
type mcpUsageAggregation struct {
	Provider string         `gorm:"column:provider"`
	Account  sql.NullString `gorm:"column:account"` // Valid=false when the column is NULL
	CostSum  int64          `gorm:"column:cost_sum"`
}
|
||
|
||
// modelConfigInfo is the in-memory form of a gw_model_config_v2 row as built by
// loadModelConfigsToMemory, which stores Provider and ModelName whitespace-trimmed.
type modelConfigInfo struct {
	ID         uint64
	Provider   string
	ModelName  string
	PriceRatio float64
}
|
||
|
||
func applyMcpBalanceChange(tx *gorm.DB, providerID uuid.UUID, delta decimal.Decimal, remark string) (decimal.Decimal, decimal.Decimal, error) {
|
||
var latest struct {
|
||
Balance decimal.Decimal `gorm:"column:balance"`
|
||
}
|
||
prev := decimal.Zero
|
||
|
||
err := tx.Table("mcp_account_balances").
|
||
Select("balance").
|
||
Where("provider_id = ?", providerID).
|
||
Order("created_at DESC").
|
||
Take(&latest).Error
|
||
|
||
if err != nil {
|
||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return decimal.Zero, decimal.Zero, err
|
||
}
|
||
} else {
|
||
prev = latest.Balance
|
||
}
|
||
|
||
newBalance := prev.Add(delta)
|
||
|
||
record := map[string]interface{}{
|
||
"provider_id": providerID,
|
||
"balance": newBalance,
|
||
"currency": "USD",
|
||
"remark": remark,
|
||
}
|
||
|
||
if err := tx.Table("mcp_account_balances").Create(record).Error; err != nil {
|
||
return decimal.Zero, decimal.Zero, err
|
||
}
|
||
|
||
return prev, newBalance, nil
|
||
}
|
||
|
||
func findMcpProviderID(pg *gorm.DB, provider, account string) (uuid.UUID, error) {
|
||
var result struct {
|
||
ID uuid.UUID `gorm:"column:id"`
|
||
}
|
||
|
||
query := pg.Table("mcp_providers").Select("id").Where("provider = ?", provider)
|
||
if account != "" {
|
||
query = query.Where("account = ?", account)
|
||
} else {
|
||
query = query.Where("(account IS NULL OR account = '')")
|
||
}
|
||
|
||
err := query.Order("created_at DESC").Take(&result).Error
|
||
if err != nil {
|
||
return uuid.Nil, err
|
||
}
|
||
return result.ID, nil
|
||
}
|
||
|
||
// mcpProviderInfo 存储 provider 信息用于内存查找
|
||
type mcpProviderInfo struct {
|
||
ID uuid.UUID
|
||
Provider string
|
||
Account string // 空字符串表示 account 为空
|
||
FloatingRatio float64
|
||
}
|
||
|
||
type mcpProviderCache struct {
|
||
ID uuid.UUID
|
||
FloatingRatio float64
|
||
}
|
||
|
||
// loadMcpProvidersToMemory 加载所有 mcp_providers 记录到内存
|
||
// 构建 map: key = account, value = provider_id
|
||
func loadMcpProvidersToMemory(pg *gorm.DB) (map[string]mcpProviderCache, error) {
|
||
var providers []mcpProviderInfo
|
||
err := pg.Table("mcp_providers").
|
||
Select("id, provider, COALESCE(account, '') AS account, COALESCE(floating_ratio, 1) AS floating_ratio").
|
||
Order("created_at DESC").
|
||
Find(&providers).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 构建 map: key = account, value = provider_id
|
||
// 如果同一个 account 有多条记录,保留最新的(因为已经按 created_at DESC 排序)
|
||
providerMap := make(map[string]mcpProviderCache)
|
||
for _, p := range providers {
|
||
key := p.Account
|
||
// 如果该 key 已存在,保留最新的(因为已经按 created_at DESC 排序)
|
||
if _, exists := providerMap[key]; !exists {
|
||
ratio := p.FloatingRatio
|
||
if ratio < 0 {
|
||
ratio = 0
|
||
}
|
||
ratio = 1 + ratio
|
||
providerMap[key] = mcpProviderCache{
|
||
ID: p.ID,
|
||
FloatingRatio: ratio,
|
||
}
|
||
}
|
||
}
|
||
|
||
return providerMap, nil
|
||
}
|
||
|
||
// findMcpProviderFromMemory 从内存 map 中查找 provider 信息
|
||
// 直接通过 account 查找,不使用 provider
|
||
func findMcpProviderFromMemory(providerMap map[string]mcpProviderCache, provider, account string) (mcpProviderCache, error) {
|
||
// 直接使用 account 作为 key
|
||
if entry, found := providerMap[account]; found {
|
||
return entry, nil
|
||
}
|
||
|
||
return mcpProviderCache{}, fmt.Errorf("未找到匹配的账号: account=%s", account)
|
||
}
|
||
|
||
func getLastProcessedWindowEnd(pg *gorm.DB, jobName string) (time.Time, error) {
|
||
var result struct {
|
||
WindowEnd time.Time `gorm:"column:window_end"`
|
||
}
|
||
err := pg.Table("mcp_usage_balance_job_runs").
|
||
Select("window_end").
|
||
Where("job_name = ? AND status = ?", jobName, jobStatusSuccess).
|
||
Order("id DESC").
|
||
Limit(1).
|
||
Take(&result).Error
|
||
if err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return time.Time{}, nil // 返回零值表示没有记录
|
||
}
|
||
return time.Time{}, err
|
||
}
|
||
return result.WindowEnd, nil
|
||
}
|
||
|
||
func createMcpUsageJobRun(pg *gorm.DB, jobName string, windowStart, windowEnd time.Time) (int64, error) {
|
||
type JobRun struct {
|
||
ID int64 `gorm:"column:id;primaryKey;autoIncrement"`
|
||
JobName string `gorm:"column:job_name"`
|
||
WindowStart time.Time `gorm:"column:window_start"`
|
||
WindowEnd time.Time `gorm:"column:window_end"`
|
||
Status string `gorm:"column:status"`
|
||
RecordsProcessed int `gorm:"column:records_processed"`
|
||
TotalCost decimal.Decimal `gorm:"column:total_cost"`
|
||
ErrorMessage string `gorm:"column:error_message"`
|
||
CreatedAt time.Time `gorm:"column:created_at"`
|
||
UpdatedAt time.Time `gorm:"column:updated_at"`
|
||
}
|
||
|
||
jobRun := JobRun{
|
||
JobName: jobName,
|
||
WindowStart: windowStart,
|
||
WindowEnd: windowEnd,
|
||
Status: jobStatusPending,
|
||
RecordsProcessed: 0,
|
||
TotalCost: decimal.Zero,
|
||
ErrorMessage: "",
|
||
CreatedAt: time.Now(),
|
||
UpdatedAt: time.Now(),
|
||
}
|
||
|
||
if err := pg.Table("mcp_usage_balance_job_runs").Create(&jobRun).Error; err != nil {
|
||
var pgErr *pgconn.PgError
|
||
if errors.As(err, &pgErr) && pgErr.Code == "23505" {
|
||
var existing struct {
|
||
ID int64 `gorm:"column:id"`
|
||
Status string `gorm:"column:status"`
|
||
}
|
||
if lookupErr := pg.Table("mcp_usage_balance_job_runs").
|
||
Select("id, status").
|
||
Where("job_name = ? AND window_start = ?", jobName, windowStart).
|
||
Take(&existing).Error; lookupErr == nil {
|
||
switch existing.Status {
|
||
case jobStatusPending:
|
||
return existing.ID, ErrMcpUsageJobAlreadyInProgress
|
||
case jobStatusSuccess:
|
||
return existing.ID, ErrMcpUsageJobAlreadyProcessed
|
||
default:
|
||
return existing.ID, fmt.Errorf("job already exists with status %s", existing.Status)
|
||
}
|
||
}
|
||
}
|
||
return 0, err
|
||
}
|
||
return jobRun.ID, nil
|
||
}
|
||
|
||
func updateMcpUsageJobRun(pg *gorm.DB, runID int64, status string, records int, total decimal.Decimal, errorMessage string) error {
|
||
updateData := map[string]interface{}{
|
||
"status": status,
|
||
"records_processed": records,
|
||
"total_cost": total,
|
||
"error_message": errorMessage,
|
||
"updated_at": time.Now(),
|
||
}
|
||
return pg.Table("mcp_usage_balance_job_runs").Where("id = ?", runID).Updates(updateData).Error
|
||
}
|
||
|
||
func getLastProcessedModelTokenWindowEnd(db *gorm.DB, jobName string) (time.Time, error) {
|
||
var result struct {
|
||
WindowEnd time.Time `gorm:"column:window_end"`
|
||
}
|
||
err := db.Table("model_token_balance_job_runs").
|
||
Select("window_end").
|
||
Where("job_name = ? AND status = ?", jobName, jobStatusSuccess).
|
||
Order("id DESC").
|
||
Limit(1).
|
||
Take(&result).Error
|
||
if err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return time.Time{}, nil
|
||
}
|
||
return time.Time{}, err
|
||
}
|
||
return result.WindowEnd, nil
|
||
}
|
||
|
||
func createModelTokenJobRun(db *gorm.DB, jobName string, windowStart, windowEnd time.Time) (int64, error) {
|
||
type jobRun struct {
|
||
ID int64 `gorm:"column:id;primaryKey;autoIncrement"`
|
||
JobName string `gorm:"column:job_name"`
|
||
WindowStart time.Time `gorm:"column:window_start"`
|
||
WindowEnd time.Time `gorm:"column:window_end"`
|
||
Status string `gorm:"column:status"`
|
||
RecordsProcessed int `gorm:"column:records_processed"`
|
||
TotalCost decimal.Decimal
|
||
ErrorMessage string `gorm:"column:error_message"`
|
||
CreatedAt time.Time `gorm:"column:created_at"`
|
||
UpdatedAt time.Time `gorm:"column:updated_at"`
|
||
}
|
||
|
||
run := jobRun{
|
||
JobName: jobName,
|
||
WindowStart: windowStart,
|
||
WindowEnd: windowEnd,
|
||
Status: jobStatusPending,
|
||
RecordsProcessed: 0,
|
||
TotalCost: decimal.Zero,
|
||
ErrorMessage: "",
|
||
CreatedAt: time.Now(),
|
||
UpdatedAt: time.Now(),
|
||
}
|
||
|
||
if err := db.Table("model_token_balance_job_runs").Create(&run).Error; err != nil {
|
||
var mysqlErr *mysqldriver.MySQLError
|
||
if errors.As(err, &mysqlErr) && mysqlErr.Number == 1062 {
|
||
var existing struct {
|
||
ID int64 `gorm:"column:id"`
|
||
Status string `gorm:"column:status"`
|
||
}
|
||
if lookupErr := db.Table("model_token_balance_job_runs").
|
||
Select("id, status").
|
||
Where("job_name = ? AND window_start = ?", jobName, windowStart).
|
||
Order("id DESC").
|
||
Take(&existing).Error; lookupErr == nil {
|
||
switch existing.Status {
|
||
case jobStatusPending:
|
||
return existing.ID, ErrModelTokenJobAlreadyInProgress
|
||
case jobStatusSuccess:
|
||
return existing.ID, ErrModelTokenJobAlreadyProcessed
|
||
default:
|
||
return existing.ID, fmt.Errorf("job already exists with status %s", existing.Status)
|
||
}
|
||
}
|
||
return 0, ErrModelTokenJobAlreadyProcessed
|
||
}
|
||
return 0, err
|
||
}
|
||
|
||
return run.ID, nil
|
||
}
|
||
|
||
func updateModelTokenJobRun(db *gorm.DB, runID int64, status string, records int, total decimal.Decimal, errorMessage string) error {
|
||
updateData := map[string]interface{}{
|
||
"status": status,
|
||
"records_processed": records,
|
||
"total_cost": total,
|
||
"error_message": errorMessage,
|
||
"updated_at": time.Now(),
|
||
}
|
||
return db.Table("model_token_balance_job_runs").Where("id = ?", runID).Updates(updateData).Error
|
||
}
|
||
|
||
func loadModelConfigsToMemory(db *gorm.DB) (map[string]modelConfigInfo, error) {
|
||
var configs []struct {
|
||
ID uint64 `gorm:"column:id"`
|
||
Provider string `gorm:"column:provider"`
|
||
ModelName string `gorm:"column:model_name"`
|
||
PriceRatio float64 `gorm:"column:price_ratio"`
|
||
}
|
||
|
||
err := db.Table("gw_model_config_v2").
|
||
Select("id, provider, model_name, price_ratio").
|
||
Where("deleted_at IS NULL").
|
||
Find(&configs).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
result := make(map[string]modelConfigInfo, len(configs))
|
||
for _, cfg := range configs {
|
||
key := makeModelConfigKey(cfg.Provider, cfg.ModelName)
|
||
if _, exists := result[key]; !exists {
|
||
result[key] = modelConfigInfo{
|
||
ID: cfg.ID,
|
||
Provider: strings.TrimSpace(cfg.Provider),
|
||
ModelName: strings.TrimSpace(cfg.ModelName),
|
||
PriceRatio: cfg.PriceRatio,
|
||
}
|
||
}
|
||
}
|
||
return result, nil
|
||
}
|
||
|
||
// makeModelConfigKey builds the lookup key "<provider>|<model>". Both parts are
// trimmed of surrounding whitespace and lower-cased, so lookups are case- and
// padding-insensitive.
func makeModelConfigKey(provider, model string) string {
	normalize := func(part string) string {
		return strings.ToLower(strings.TrimSpace(part))
	}
	return normalize(provider) + "|" + normalize(model)
}
|
||
|
||
func AdjustMcpAccountBalance(providerID string, newBalance float64, operatorName, remark string) error {
|
||
db := storage.GetPG()
|
||
if db == nil {
|
||
return fmt.Errorf("PostgreSQL未初始化")
|
||
}
|
||
|
||
if newBalance < 0 {
|
||
return fmt.Errorf("余额不能为负数")
|
||
}
|
||
|
||
uuidVal, err := uuid.Parse(providerID)
|
||
if err != nil {
|
||
return fmt.Errorf("无效的provider_id: %w", err)
|
||
}
|
||
|
||
target := decimal.NewFromFloat(newBalance).Round(8)
|
||
current := decimal.Zero
|
||
|
||
var latest struct {
|
||
Balance decimal.Decimal `gorm:"column:balance"`
|
||
}
|
||
err = db.Table("mcp_account_balances").
|
||
Select("balance").
|
||
Where("provider_id = ?", uuidVal).
|
||
Order("created_at DESC").
|
||
Take(&latest).Error
|
||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return fmt.Errorf("获取当前余额失败: %w", err)
|
||
}
|
||
if err == nil {
|
||
current = latest.Balance
|
||
}
|
||
|
||
delta := target.Sub(current)
|
||
if delta.Abs().LessThan(decimal.NewFromFloat(1e-8)) {
|
||
// 无需更新,但仍返回成功
|
||
return nil
|
||
}
|
||
|
||
tx := db.Begin()
|
||
if tx.Error != nil {
|
||
return fmt.Errorf("开启事务失败: %w", tx.Error)
|
||
}
|
||
|
||
fullRemark := remark
|
||
if strings.TrimSpace(fullRemark) == "" {
|
||
fullRemark = fmt.Sprintf("手动调整余额至 $%s by %s", target.StringFixed(8), operatorName)
|
||
} else {
|
||
fullRemark = fmt.Sprintf("%s (by %s)", remark, operatorName)
|
||
}
|
||
|
||
_, _, err = applyMcpBalanceChange(tx, uuidVal, delta, fullRemark)
|
||
if err != nil {
|
||
tx.Rollback()
|
||
return err
|
||
}
|
||
|
||
if err := tx.Commit().Error; err != nil {
|
||
return fmt.Errorf("提交事务失败: %w", err)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func CreateMcpAccountBalanceRecord(providerID string, balance float64, operatorName, remark string) error {
|
||
db := storage.GetPG()
|
||
if db == nil {
|
||
return fmt.Errorf("PostgreSQL未初始化")
|
||
}
|
||
|
||
if balance < 0 {
|
||
return fmt.Errorf("余额不能为负数")
|
||
}
|
||
|
||
uuidVal, err := uuid.Parse(providerID)
|
||
if err != nil {
|
||
return fmt.Errorf("无效的provider_id: %w", err)
|
||
}
|
||
|
||
var exists int64
|
||
if err := db.Table("mcp_account_balances").Where("provider_id = ?", uuidVal).Limit(1).Count(&exists).Error; err != nil {
|
||
return fmt.Errorf("查询余额失败: %w", err)
|
||
}
|
||
if exists > 0 {
|
||
return fmt.Errorf("余额记录已存在,请使用调整余额功能")
|
||
}
|
||
|
||
target := decimal.NewFromFloat(balance).Round(8)
|
||
|
||
tx := db.Begin()
|
||
if tx.Error != nil {
|
||
return fmt.Errorf("开启事务失败: %w", tx.Error)
|
||
}
|
||
|
||
fullRemark := remark
|
||
if strings.TrimSpace(fullRemark) == "" {
|
||
fullRemark = fmt.Sprintf("新增余额 $%s by %s", target.StringFixed(8), operatorName)
|
||
} else {
|
||
fullRemark = fmt.Sprintf("%s (by %s)", remark, operatorName)
|
||
}
|
||
|
||
_, _, err = applyMcpBalanceChange(tx, uuidVal, target, fullRemark)
|
||
if err != nil {
|
||
tx.Rollback()
|
||
return err
|
||
}
|
||
|
||
if err := tx.Commit().Error; err != nil {
|
||
return fmt.Errorf("提交事务失败: %w", err)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func AdjustModelAccountBalance(account string, newBalance float64, operatorName, remark string) error {
|
||
db := storage.DB
|
||
if db == nil {
|
||
return fmt.Errorf("MySQL未初始化")
|
||
}
|
||
|
||
if newBalance < 0 {
|
||
return fmt.Errorf("余额不能为负数")
|
||
}
|
||
|
||
target := decimal.NewFromFloat(newBalance).Round(8)
|
||
current := decimal.Zero
|
||
|
||
var latest struct {
|
||
Balance decimal.Decimal `gorm:"column:balance"`
|
||
}
|
||
err := db.Table("gw_model_account_balances").
|
||
Select("balance").
|
||
Where("account = ?", account).
|
||
Order("created_at DESC").
|
||
Take(&latest).Error
|
||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return fmt.Errorf("获取当前余额失败: %w", err)
|
||
}
|
||
if err == nil {
|
||
current = latest.Balance
|
||
}
|
||
|
||
delta := target.Sub(current)
|
||
if delta.Abs().LessThan(decimal.NewFromFloat(1e-8)) {
|
||
return nil
|
||
}
|
||
|
||
tx := db.Begin()
|
||
if tx.Error != nil {
|
||
return fmt.Errorf("开启事务失败: %w", tx.Error)
|
||
}
|
||
|
||
fullRemark := remark
|
||
if strings.TrimSpace(fullRemark) == "" {
|
||
fullRemark = fmt.Sprintf("手动调整余额至 $%s by %s", target.StringFixed(8), operatorName)
|
||
} else {
|
||
fullRemark = fmt.Sprintf("%s (by %s)", remark, operatorName)
|
||
}
|
||
|
||
_, _, err = applyModelAccountBalanceChange(tx, account, delta, fullRemark)
|
||
if err != nil {
|
||
tx.Rollback()
|
||
return err
|
||
}
|
||
|
||
if err := tx.Commit().Error; err != nil {
|
||
return fmt.Errorf("提交事务失败: %w", err)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func CreateModelAccountBalanceRecord(account string, balance float64, operatorName, remark string) error {
|
||
db := storage.DB
|
||
if db == nil {
|
||
return fmt.Errorf("MySQL未初始化")
|
||
}
|
||
|
||
if balance < 0 {
|
||
return fmt.Errorf("余额不能为负数")
|
||
}
|
||
|
||
var exists int64
|
||
if err := db.Table("gw_model_account_balances").Where("account = ?", account).Limit(1).Count(&exists).Error; err != nil {
|
||
return fmt.Errorf("查询余额失败: %w", err)
|
||
}
|
||
if exists > 0 {
|
||
return fmt.Errorf("余额记录已存在,请使用调整余额功能")
|
||
}
|
||
|
||
target := decimal.NewFromFloat(balance).Round(8)
|
||
|
||
tx := db.Begin()
|
||
if tx.Error != nil {
|
||
return fmt.Errorf("开启事务失败: %w", tx.Error)
|
||
}
|
||
|
||
fullRemark := remark
|
||
if strings.TrimSpace(fullRemark) == "" {
|
||
fullRemark = fmt.Sprintf("新增余额 $%s by %s", target.StringFixed(8), operatorName)
|
||
} else {
|
||
fullRemark = fmt.Sprintf("%s (by %s)", remark, operatorName)
|
||
}
|
||
|
||
_, _, err := applyModelAccountBalanceChange(tx, account, target, fullRemark)
|
||
if err != nil {
|
||
tx.Rollback()
|
||
return err
|
||
}
|
||
if err := tx.Commit().Error; err != nil {
|
||
return fmt.Errorf("提交事务失败: %w", err)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func ListSandboxRecords(offset, limit int, userID, projectID, start, end string) (*PageResult, error) {
|
||
var total int64
|
||
q := storage.DB.Table("sb_sandbox_record")
|
||
if userID != "" {
|
||
q = q.Where("user_id = ?", userID)
|
||
}
|
||
if projectID != "" {
|
||
q = q.Where("project_id = ?", projectID)
|
||
}
|
||
if start != "" && end != "" {
|
||
q = q.Where("created_at BETWEEN ? AND ?", start, end)
|
||
}
|
||
if err := q.Count(&total).Error; err != nil {
|
||
return nil, fmt.Errorf("count sandbox records: %w", err)
|
||
}
|
||
rows, err := q.Offset(offset).Limit(limit).Rows()
|
||
if err != nil {
|
||
return nil, fmt.Errorf("query sandbox records: %w", err)
|
||
}
|
||
defer rows.Close()
|
||
cols, _ := rows.Columns()
|
||
list := make([]map[string]interface{}, 0)
|
||
for rows.Next() {
|
||
vals := make([]interface{}, len(cols))
|
||
valPtrs := make([]interface{}, len(cols))
|
||
for i := range vals {
|
||
valPtrs[i] = &vals[i]
|
||
}
|
||
if err := rows.Scan(valPtrs...); err != nil {
|
||
return nil, err
|
||
}
|
||
list = append(list, normalizeSandboxRow(cols, vals))
|
||
}
|
||
return &PageResult{List: list, Total: total}, nil
|
||
}
|
||
|
||
func ListTokenUsages(offset, limit int, userID, projectID, startDay, endDay string) (*PageResult, error) {
|
||
var total int64
|
||
q := storage.DB.Table("gw_token_usages")
|
||
if userID != "" {
|
||
q = q.Where("user_id = ?", userID)
|
||
}
|
||
if projectID != "" {
|
||
q = q.Where("project_id = ?", projectID)
|
||
}
|
||
if startDay != "" && endDay != "" {
|
||
q = q.Where("day BETWEEN ? AND ?", startDay, endDay)
|
||
}
|
||
if err := q.Count(&total).Error; err != nil {
|
||
return nil, fmt.Errorf("count token usages: %w", err)
|
||
}
|
||
rows, err := q.Offset(offset).Limit(limit).Rows()
|
||
if err != nil {
|
||
return nil, fmt.Errorf("query token usages: %w", err)
|
||
}
|
||
defer rows.Close()
|
||
cols, _ := rows.Columns()
|
||
list := make([]map[string]interface{}, 0)
|
||
intCols := []string{"id", "hour", "cost", "prompt_token", "completion_token", "cache_create_token", "cache_read_token"}
|
||
for rows.Next() {
|
||
vals := make([]interface{}, len(cols))
|
||
valPtrs := make([]interface{}, len(cols))
|
||
for i := range vals {
|
||
valPtrs[i] = &vals[i]
|
||
}
|
||
if err := rows.Scan(valPtrs...); err != nil {
|
||
return nil, err
|
||
}
|
||
list = append(list, normalizeGenericInts(cols, vals, intCols, nil))
|
||
}
|
||
return &PageResult{List: list, Total: total}, nil
|
||
}
|
||
|
||
func ListMCPUsages(offset, limit int, userID, projectID, start, end string) (*PageResult, error) {
|
||
var total int64
|
||
q := storage.DB.Table("mcp_invoke_usages")
|
||
if userID != "" {
|
||
q = q.Where("user_id = ?", userID)
|
||
}
|
||
if projectID != "" {
|
||
q = q.Where("project_id = ?", projectID)
|
||
}
|
||
if start != "" && end != "" {
|
||
q = q.Where("created_at BETWEEN ? AND ?", start, end)
|
||
}
|
||
if err := q.Count(&total).Error; err != nil {
|
||
return nil, fmt.Errorf("count mcp usages: %w", err)
|
||
}
|
||
rows, err := q.Offset(offset).Limit(limit).Rows()
|
||
if err != nil {
|
||
return nil, fmt.Errorf("query mcp usages: %w", err)
|
||
}
|
||
defer rows.Close()
|
||
cols, _ := rows.Columns()
|
||
list := make([]map[string]interface{}, 0)
|
||
intCols := []string{"id", "hour", "cost", "call_count"}
|
||
timeCols := []string{"created_at", "updated_at"}
|
||
for rows.Next() {
|
||
vals := make([]interface{}, len(cols))
|
||
valPtrs := make([]interface{}, len(cols))
|
||
for i := range vals {
|
||
valPtrs[i] = &vals[i]
|
||
}
|
||
if err := rows.Scan(valPtrs...); err != nil {
|
||
return nil, err
|
||
}
|
||
list = append(list, normalizeGenericInts(cols, vals, intCols, timeCols))
|
||
}
|
||
return &PageResult{List: list, Total: total}, nil
|
||
}
|
||
|
||
func ListTransactionLogs(offset, limit int, userID, orderID, txType, status, start, end string) (*PageResult, error) {
|
||
var total int64
|
||
q := storage.DB.Table("m_transaction_logs")
|
||
if userID != "" {
|
||
q = q.Where("user_id = ?", userID)
|
||
}
|
||
if orderID != "" {
|
||
q = q.Where("order_id = ?", orderID)
|
||
}
|
||
if txType != "" {
|
||
q = q.Where("type = ?", txType)
|
||
}
|
||
if status != "" {
|
||
q = q.Where("status = ?", status)
|
||
}
|
||
if start != "" && end != "" {
|
||
q = q.Where("created_at BETWEEN ? AND ?", start, end)
|
||
}
|
||
if err := q.Count(&total).Error; err != nil {
|
||
return nil, fmt.Errorf("count transaction logs: %w", err)
|
||
}
|
||
rows, err := q.Offset(offset).Limit(limit).Rows()
|
||
if err != nil {
|
||
return nil, fmt.Errorf("query transaction logs: %w", err)
|
||
}
|
||
defer rows.Close()
|
||
cols, _ := rows.Columns()
|
||
list := make([]map[string]interface{}, 0)
|
||
intCols := []string{"id", "amount", "balance_before", "balance_after"}
|
||
timeCols := []string{"created_at"}
|
||
for rows.Next() {
|
||
vals := make([]interface{}, len(cols))
|
||
valPtrs := make([]interface{}, len(cols))
|
||
for i := range vals {
|
||
valPtrs[i] = &vals[i]
|
||
}
|
||
if err := rows.Scan(valPtrs...); err != nil {
|
||
return nil, err
|
||
}
|
||
list = append(list, normalizeGenericInts(cols, vals, intCols, timeCols))
|
||
}
|
||
return &PageResult{List: list, Total: total}, nil
|
||
}
|
||
|
||
// ListPaymentRecords 查询充值支付记录
|
||
func ListPaymentRecords(offset, limit int, userID, orderID, paypalOrderID, status, refundStatus, payerEmail, start, end string) (*PageResult, error) {
|
||
var total int64
|
||
q := storage.DB.Table("m_payment_records")
|
||
if userID != "" {
|
||
q = q.Where("user_id = ?", userID)
|
||
}
|
||
if orderID != "" {
|
||
q = q.Where("order_id = ?", orderID)
|
||
}
|
||
if paypalOrderID != "" {
|
||
q = q.Where("paypal_order_id = ?", paypalOrderID)
|
||
}
|
||
if status != "" {
|
||
q = q.Where("status = ?", status)
|
||
}
|
||
if refundStatus != "" {
|
||
q = q.Where("refund_status = ?", refundStatus)
|
||
}
|
||
if payerEmail != "" {
|
||
q = q.Where("payer_email LIKE ?", "%"+payerEmail+"%")
|
||
}
|
||
if start != "" && end != "" {
|
||
q = q.Where("created_at BETWEEN ? AND ?", start, end)
|
||
}
|
||
if err := q.Count(&total).Error; err != nil {
|
||
return nil, fmt.Errorf("count payment records: %w", err)
|
||
}
|
||
rows, err := q.Offset(offset).Limit(limit).Rows()
|
||
if err != nil {
|
||
return nil, fmt.Errorf("query payment records: %w", err)
|
||
}
|
||
defer rows.Close()
|
||
cols, _ := rows.Columns()
|
||
list := make([]map[string]interface{}, 0)
|
||
intCols := []string{"id", "amount", "refunded_amount"}
|
||
timeCols := []string{"created_at", "updated_at"}
|
||
for rows.Next() {
|
||
vals := make([]interface{}, len(cols))
|
||
valPtrs := make([]interface{}, len(cols))
|
||
for i := range vals {
|
||
valPtrs[i] = &vals[i]
|
||
}
|
||
if err := rows.Scan(valPtrs...); err != nil {
|
||
return nil, err
|
||
}
|
||
list = append(list, normalizeGenericInts(cols, vals, intCols, timeCols))
|
||
}
|
||
return &PageResult{List: list, Total: total}, nil
|
||
}
|
||
|
||
// RefundPaymentRecord 退款支付记录
|
||
func RefundPaymentRecord(orderID, paypalCaptureID string, amount *int64) error {
|
||
cfg := config.GetConfig()
|
||
payBaseURL := cfg.Pay.BaseURL
|
||
timeout := time.Duration(cfg.Pay.Timeout) * time.Second
|
||
|
||
// 构建请求体
|
||
reqBody := make(map[string]interface{})
|
||
if orderID != "" {
|
||
reqBody["order_id"] = orderID
|
||
}
|
||
if paypalCaptureID != "" {
|
||
reqBody["paypal_capture_id"] = paypalCaptureID
|
||
}
|
||
if amount != nil {
|
||
reqBody["amount"] = *amount
|
||
}
|
||
|
||
jsonData, err := json.Marshal(reqBody)
|
||
if err != nil {
|
||
return fmt.Errorf("序列化请求失败: %w", err)
|
||
}
|
||
|
||
// 创建HTTP请求
|
||
req, err := http.NewRequest("POST", payBaseURL+"/api/refund", bytes.NewBuffer(jsonData))
|
||
if err != nil {
|
||
return fmt.Errorf("创建请求失败: %w", err)
|
||
}
|
||
|
||
req.Header.Set("Content-Type", "application/json")
|
||
|
||
// 发送请求
|
||
client := &http.Client{Timeout: timeout}
|
||
resp, err := client.Do(req)
|
||
if err != nil {
|
||
return fmt.Errorf("发送退款请求失败: %w", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
// 读取响应
|
||
body, err := io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
return fmt.Errorf("读取响应失败: %w", err)
|
||
}
|
||
|
||
if resp.StatusCode != http.StatusOK {
|
||
return fmt.Errorf("退款接口返回错误: %d, 响应: %s", resp.StatusCode, string(body))
|
||
}
|
||
|
||
// 读取订单基础信息(一次查询拿到 user_id 和 amount)
|
||
var rec struct {
|
||
UserID string `json:"user_id"`
|
||
Amount int64 `json:"amount"`
|
||
}
|
||
if err := storage.DB.Table("m_payment_records").Select("user_id, amount").Where("order_id = ?", orderID).Scan(&rec).Error; err != nil {
|
||
return nil // 不阻断主流程
|
||
}
|
||
|
||
// 确定应扣减的退款金额(单位:分)
|
||
var refundCents int64
|
||
if amount != nil {
|
||
refundCents = *amount
|
||
} else {
|
||
refundCents = rec.Amount
|
||
}
|
||
|
||
if refundCents <= 0 {
|
||
return nil
|
||
}
|
||
|
||
// 计算递减值(Redis单位)
|
||
decrBy := refundCents * 1_000_000 // 1 cent = 1e6 in Redis unit
|
||
|
||
// 连接 Redis 并扣减余额
|
||
rclient, err := pkgredis.NewClient(cfg.Redis)
|
||
if err != nil {
|
||
return nil // 忽略Redis失败
|
||
}
|
||
defer rclient.Close()
|
||
|
||
ctx := context.Background()
|
||
key := fmt.Sprintf("GW:QU_%s", rec.UserID)
|
||
_, _ = rclient.Rdb.DecrBy(ctx, key, decrBy).Result()
|
||
|
||
return nil
|
||
}
|
||
|
||
// ListMcpAccountRechargeRecords 查询MCP账号充值记录
|
||
func ListMcpAccountRechargeRecords(offset, limit int, provider, account, start, end string) (*PageResult, error) {
|
||
db := storage.GetPG()
|
||
if db == nil {
|
||
return nil, fmt.Errorf("PostgreSQL未初始化")
|
||
}
|
||
|
||
var total int64
|
||
q := db.Table("mcp_account_recharge_records").
|
||
Select(`
|
||
mcp_account_recharge_records.*,
|
||
mcp_providers.provider,
|
||
mcp_providers.account
|
||
`).
|
||
Joins("LEFT JOIN mcp_providers ON mcp_account_recharge_records.provider_id = mcp_providers.id")
|
||
|
||
if provider != "" {
|
||
q = q.Where("mcp_providers.provider ILIKE ?", "%"+provider+"%")
|
||
}
|
||
if account != "" {
|
||
q = q.Where("mcp_providers.account ILIKE ?", "%"+account+"%")
|
||
}
|
||
if start != "" {
|
||
q = q.Where("mcp_account_recharge_records.recharge_date >= ?", start)
|
||
}
|
||
if end != "" {
|
||
q = q.Where("mcp_account_recharge_records.recharge_date <= ?", end)
|
||
}
|
||
|
||
if err := q.Count(&total).Error; err != nil {
|
||
return nil, fmt.Errorf("count mcp account recharge records: %w", err)
|
||
}
|
||
|
||
rows, err := q.Order("mcp_account_recharge_records.created_at DESC").
|
||
Offset(offset).Limit(limit).Rows()
|
||
if err != nil {
|
||
return nil, fmt.Errorf("query mcp account recharge records: %w", err)
|
||
}
|
||
defer rows.Close()
|
||
|
||
cols, _ := rows.Columns()
|
||
list := make([]map[string]interface{}, 0)
|
||
floatCols := []string{"amount"}
|
||
timeCols := []string{"created_at", "updated_at"}
|
||
|
||
for rows.Next() {
|
||
vals := make([]interface{}, len(cols))
|
||
valPtrs := make([]interface{}, len(cols))
|
||
for i := range vals {
|
||
valPtrs[i] = &vals[i]
|
||
}
|
||
if err := rows.Scan(valPtrs...); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 处理数据类型
|
||
m := map[string]interface{}{}
|
||
isFloat := map[string]struct{}{}
|
||
for _, k := range floatCols {
|
||
isFloat[k] = struct{}{}
|
||
}
|
||
isTime := map[string]struct{}{}
|
||
for _, k := range timeCols {
|
||
isTime[k] = struct{}{}
|
||
}
|
||
|
||
for i, c := range cols {
|
||
if _, ok := isFloat[c]; ok {
|
||
if f, ok2 := toFloat(vals[i]); ok2 {
|
||
m[c] = f
|
||
} else {
|
||
m[c] = nil
|
||
}
|
||
continue
|
||
}
|
||
if _, ok := isTime[c]; ok {
|
||
if iso, ok2 := toTimeISO(vals[i]); ok2 {
|
||
m[c] = iso
|
||
} else {
|
||
m[c] = nil
|
||
}
|
||
continue
|
||
}
|
||
m[c] = toString(vals[i])
|
||
}
|
||
list = append(list, m)
|
||
}
|
||
|
||
return &PageResult{List: list, Total: total}, nil
|
||
}
|
||
|
||
// CreateMcpAccountRechargeRecord 创建MCP账号充值记录
|
||
func CreateMcpAccountRechargeRecord(providerID string, amount float64, rechargeDate string, operatorID interface{}, operatorName, remark string) error {
|
||
db := storage.GetPG()
|
||
if db == nil {
|
||
return fmt.Errorf("PostgreSQL未初始化")
|
||
}
|
||
|
||
// 解析provider_id为UUID
|
||
providerUUID, err := uuid.Parse(providerID)
|
||
if err != nil {
|
||
return fmt.Errorf("无效的provider_id: %w", err)
|
||
}
|
||
|
||
// 解析日期
|
||
date, err := time.Parse("2006-01-02", rechargeDate)
|
||
if err != nil {
|
||
return fmt.Errorf("无效的日期格式: %w", err)
|
||
}
|
||
|
||
// 处理operator_id
|
||
var operatorUUID *uuid.UUID
|
||
if operatorID != nil {
|
||
var opID uuid.UUID
|
||
switch v := operatorID.(type) {
|
||
case string:
|
||
opID, err = uuid.Parse(v)
|
||
if err != nil {
|
||
// 如果不是UUID格式,可能是数字ID,设为nil
|
||
operatorUUID = nil
|
||
} else {
|
||
operatorUUID = &opID
|
||
}
|
||
case uuid.UUID:
|
||
operatorUUID = &v
|
||
default:
|
||
operatorUUID = nil
|
||
}
|
||
}
|
||
|
||
// 开始事务
|
||
tx := db.Begin()
|
||
if tx.Error != nil {
|
||
return fmt.Errorf("开始事务失败: %w", tx.Error)
|
||
}
|
||
defer func() {
|
||
if r := recover(); r != nil {
|
||
tx.Rollback()
|
||
}
|
||
}()
|
||
|
||
// 1. 创建充值记录
|
||
record := map[string]interface{}{
|
||
"provider_id": providerUUID,
|
||
"amount": amount,
|
||
"currency": "USD",
|
||
"recharge_date": date,
|
||
"operator_id": operatorUUID,
|
||
"operator_name": operatorName,
|
||
"remark": remark,
|
||
}
|
||
|
||
if err := tx.Table("mcp_account_recharge_records").Create(record).Error; err != nil {
|
||
tx.Rollback()
|
||
return fmt.Errorf("创建充值记录失败: %w", err)
|
||
}
|
||
|
||
amountDecimal := decimal.NewFromFloat(amount).Round(8)
|
||
remarkText := fmt.Sprintf("手动充值 $%s by %s", amountDecimal.StringFixed(2), operatorName)
|
||
|
||
_, _, err = applyMcpBalanceChange(tx, providerUUID, amountDecimal, remarkText)
|
||
if err != nil {
|
||
tx.Rollback()
|
||
return fmt.Errorf("创建余额记录失败: %w", err)
|
||
}
|
||
|
||
// 提交事务
|
||
if err := tx.Commit().Error; err != nil {
|
||
return fmt.Errorf("提交事务失败: %w", err)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// UpdateMcpAccountRechargeRecord 更新MCP账号充值记录
|
||
func UpdateMcpAccountRechargeRecord(id string, amount *float64, rechargeDate *string, remark *string) error {
|
||
db := storage.GetPG()
|
||
if db == nil {
|
||
return fmt.Errorf("PostgreSQL未初始化")
|
||
}
|
||
|
||
recordID, err := uuid.Parse(id)
|
||
if err != nil {
|
||
return fmt.Errorf("无效的记录ID: %w", err)
|
||
}
|
||
|
||
updates := map[string]interface{}{}
|
||
if amount != nil {
|
||
updates["amount"] = *amount
|
||
}
|
||
if rechargeDate != nil {
|
||
date, err := time.Parse("2006-01-02", *rechargeDate)
|
||
if err != nil {
|
||
return fmt.Errorf("无效的日期格式: %w", err)
|
||
}
|
||
updates["recharge_date"] = date
|
||
}
|
||
if remark != nil {
|
||
updates["remark"] = *remark
|
||
}
|
||
|
||
if len(updates) == 0 {
|
||
return fmt.Errorf("没有需要更新的字段")
|
||
}
|
||
|
||
updates["updated_at"] = time.Now()
|
||
|
||
return db.Table("mcp_account_recharge_records").
|
||
Where("id = ?", recordID).
|
||
Updates(updates).Error
|
||
}
|
||
|
||
// DeleteMcpAccountRechargeRecord 删除MCP账号充值记录
|
||
func DeleteMcpAccountRechargeRecord(id string) error {
|
||
db := storage.GetPG()
|
||
if db == nil {
|
||
return fmt.Errorf("PostgreSQL未初始化")
|
||
}
|
||
|
||
recordID, err := uuid.Parse(id)
|
||
if err != nil {
|
||
return fmt.Errorf("无效的记录ID: %w", err)
|
||
}
|
||
|
||
return db.Table("mcp_account_recharge_records").
|
||
Where("id = ?", recordID).
|
||
Delete(nil).Error
|
||
}
|
||
|
||
// GetMcpProviderAccounts 获取MCP账号列表(用于下拉选择)
|
||
func GetMcpProviderAccounts(status string, isUsed *bool) ([]map[string]interface{}, error) {
|
||
db := storage.GetPG()
|
||
if db == nil {
|
||
return nil, fmt.Errorf("PostgreSQL未初始化")
|
||
}
|
||
|
||
q := db.Table("mcp_providers").
|
||
Select("id, provider, account, status, is_used")
|
||
|
||
if status != "" {
|
||
q = q.Where("status = ?", status)
|
||
}
|
||
if isUsed != nil {
|
||
q = q.Where("is_used = ?", *isUsed)
|
||
}
|
||
|
||
var list []map[string]interface{}
|
||
rows, err := q.Order("provider, account").Rows()
|
||
if err != nil {
|
||
return nil, fmt.Errorf("query mcp providers: %w", err)
|
||
}
|
||
defer rows.Close()
|
||
|
||
cols, _ := rows.Columns()
|
||
for rows.Next() {
|
||
vals := make([]interface{}, len(cols))
|
||
valPtrs := make([]interface{}, len(cols))
|
||
for i := range vals {
|
||
valPtrs[i] = &vals[i]
|
||
}
|
||
if err := rows.Scan(valPtrs...); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
m := map[string]interface{}{}
|
||
for i, c := range cols {
|
||
m[c] = toString(vals[i])
|
||
}
|
||
list = append(list, m)
|
||
}
|
||
|
||
return list, nil
|
||
}
|
||
|
||
// GetMcpAccountLatestBalance 获取账户最新余额
|
||
func GetMcpAccountLatestBalance(providerID string) (map[string]interface{}, error) {
|
||
db := storage.GetPG()
|
||
if db == nil {
|
||
return nil, fmt.Errorf("PostgreSQL未初始化")
|
||
}
|
||
|
||
providerUUID, err := uuid.Parse(providerID)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("无效的provider_id: %w", err)
|
||
}
|
||
|
||
var balance map[string]interface{}
|
||
err = db.Table("mcp_account_balances").
|
||
Select(`
|
||
mcp_account_balances.*,
|
||
mcp_providers.provider,
|
||
mcp_providers.account
|
||
`).
|
||
Joins("LEFT JOIN mcp_providers ON mcp_account_balances.provider_id = mcp_providers.id").
|
||
Where("mcp_account_balances.provider_id = ?", providerUUID).
|
||
Order("mcp_account_balances.created_at DESC").
|
||
Limit(1).
|
||
Find(&balance).Error
|
||
|
||
if err != nil {
|
||
// 如果没有记录,返回空map而不是错误
|
||
return map[string]interface{}{}, nil
|
||
}
|
||
|
||
return balance, nil
|
||
}
|
||
|
||
// GetMcpAccountBalances 获取所有账户最新余额列表
|
||
func GetMcpAccountBalances() ([]map[string]interface{}, error) {
|
||
db := storage.GetPG()
|
||
if db == nil {
|
||
return nil, fmt.Errorf("PostgreSQL未初始化")
|
||
}
|
||
|
||
// 先获取每个账户最新的创建时间
|
||
var latestRecords []struct {
|
||
ProviderID uuid.UUID
|
||
MaxCreated time.Time
|
||
}
|
||
|
||
err := db.Table("mcp_account_balances").
|
||
Select("provider_id, MAX(created_at) as max_created").
|
||
Group("provider_id").
|
||
Find(&latestRecords).Error
|
||
|
||
if err != nil {
|
||
return nil, fmt.Errorf("查询余额失败: %w", err)
|
||
}
|
||
|
||
// 为每个账户获取最新余额
|
||
balances := make([]map[string]interface{}, 0)
|
||
for _, lr := range latestRecords {
|
||
var balance map[string]interface{}
|
||
err = db.Table("mcp_account_balances").
|
||
Select(`
|
||
mcp_account_balances.*,
|
||
mcp_providers.provider,
|
||
mcp_providers.account
|
||
`).
|
||
Joins("LEFT JOIN mcp_providers ON mcp_account_balances.provider_id = mcp_providers.id").
|
||
Where("mcp_account_balances.provider_id = ? AND mcp_account_balances.created_at = ?", lr.ProviderID, lr.MaxCreated).
|
||
Limit(1).
|
||
Find(&balance).Error
|
||
|
||
if err == nil && len(balance) > 0 {
|
||
balances = append(balances, balance)
|
||
}
|
||
}
|
||
|
||
// 处理数据类型
|
||
result := make([]map[string]interface{}, 0)
|
||
for _, bal := range balances {
|
||
m := make(map[string]interface{})
|
||
for k, v := range bal {
|
||
if k == "balance" {
|
||
if f, ok := toFloat(v); ok {
|
||
m[k] = f
|
||
} else {
|
||
m[k] = v
|
||
}
|
||
} else if k == "created_at" {
|
||
if iso, ok := toTimeISO(v); ok {
|
||
m[k] = iso
|
||
} else {
|
||
m[k] = toString(v)
|
||
}
|
||
} else {
|
||
m[k] = toString(v)
|
||
}
|
||
}
|
||
result = append(result, m)
|
||
}
|
||
|
||
return result, nil
|
||
}
|
||
|
||
// RunMcpUsageBalanceJob 汇总MCP调用费用并更新余额
|
||
// 根据上次处理的时间窗口自动计算本次需要处理的时间范围
|
||
func RunMcpUsageBalanceJob() error {
|
||
// 定时任务执行前刷新低余额阈值配置
|
||
RefreshLowBalanceThreshold()
|
||
|
||
mysqlDB := storage.DB
|
||
if mysqlDB == nil {
|
||
return fmt.Errorf("MySQL未初始化")
|
||
}
|
||
pg := storage.GetPG()
|
||
if pg == nil {
|
||
return fmt.Errorf("PostgreSQL未初始化")
|
||
}
|
||
|
||
// 获取上次成功处理的窗口结束时间
|
||
lastWindowEnd, err := getLastProcessedWindowEnd(pg, mcpUsageJobName)
|
||
if err != nil {
|
||
return fmt.Errorf("获取上次处理时间失败: %w", err)
|
||
}
|
||
|
||
// 计算当前时间前一个小时的整点时间
|
||
now := time.Now().UTC()
|
||
currentWindowEnd := now.Truncate(time.Hour)
|
||
|
||
var windowStart, windowEnd time.Time
|
||
if lastWindowEnd.IsZero() {
|
||
// 如果没有历史记录,只处理上一个小时的数据
|
||
windowEnd = currentWindowEnd
|
||
windowStart = windowEnd.Add(-time.Hour)
|
||
} else {
|
||
// 如果有历史记录,从上次处理的时间到当前时间前一个小时
|
||
windowStart = lastWindowEnd
|
||
windowEnd = currentWindowEnd
|
||
}
|
||
|
||
// 如果窗口为空(例如上次处理到10:00,当前也是10:00),直接返回
|
||
if !windowStart.Before(windowEnd) {
|
||
return nil
|
||
}
|
||
|
||
windowStart = windowStart.UTC()
|
||
windowEnd = windowEnd.UTC()
|
||
|
||
// 加载所有 mcp_providers 记录到内存
|
||
providerMap, err := loadMcpProvidersToMemory(pg)
|
||
if err != nil {
|
||
return fmt.Errorf("加载账号信息失败: %w", err)
|
||
}
|
||
|
||
runID, err := createMcpUsageJobRun(pg, mcpUsageJobName, windowStart, windowEnd)
|
||
if err != nil {
|
||
if errors.Is(err, ErrMcpUsageJobAlreadyProcessed) || errors.Is(err, ErrMcpUsageJobAlreadyInProgress) {
|
||
return nil
|
||
}
|
||
return fmt.Errorf("创建任务记录失败: %w", err)
|
||
}
|
||
|
||
var usageRows []mcpUsageAggregation
|
||
err = mysqlDB.Table("mcp_invoke_usages").
|
||
Select("provider, account, SUM(cost) AS cost_sum").
|
||
Where("updated_at >= ? AND updated_at < ?", windowStart, windowEnd).
|
||
Group("provider, account").
|
||
Having("SUM(cost) <> 0").
|
||
Find(&usageRows).Error
|
||
if err != nil {
|
||
updateErr := updateMcpUsageJobRun(pg, runID, jobStatusFailed, 0, decimal.Zero, err.Error())
|
||
if updateErr != nil {
|
||
return fmt.Errorf("更新任务记录失败: %v; 原始错误: %w", updateErr, err)
|
||
}
|
||
return fmt.Errorf("聚合调用费用失败: %w", err)
|
||
}
|
||
|
||
if len(usageRows) == 0 {
|
||
return updateMcpUsageJobRun(pg, runID, jobStatusSuccess, 0, decimal.Zero, "")
|
||
}
|
||
|
||
var (
|
||
recordsProcessed int
|
||
totalCost = decimal.Zero
|
||
hadFailure bool
|
||
errorMessages []string
|
||
)
|
||
|
||
for _, row := range usageRows {
|
||
provider := strings.TrimSpace(row.Provider)
|
||
account := ""
|
||
if row.Account.Valid {
|
||
account = strings.TrimSpace(row.Account.String)
|
||
}
|
||
|
||
costCents := row.CostSum
|
||
costUSD := decimal.NewFromInt(costCents).Div(usdScalingFactor)
|
||
|
||
detailStatus := jobStatusSuccess
|
||
detailError := ""
|
||
|
||
var (
|
||
providerEntry mcpProviderCache
|
||
providerID uuid.UUID
|
||
actualCostUSD = costUSD
|
||
floatingRatio float64 = 1
|
||
)
|
||
if account == "" {
|
||
detailStatus = jobStatusFailed
|
||
detailError = "account为空"
|
||
} else {
|
||
if providerEntry, err = findMcpProviderFromMemory(providerMap, provider, account); err != nil {
|
||
detailStatus = jobStatusFailed
|
||
detailError = fmt.Sprintf("查找账号失败: %v", err)
|
||
} else {
|
||
providerID = providerEntry.ID
|
||
floatingRatio = providerEntry.FloatingRatio
|
||
}
|
||
}
|
||
|
||
if detailStatus == jobStatusSuccess {
|
||
if floatingRatio <= 0 {
|
||
floatingRatio = 1
|
||
}
|
||
if floatingRatio != 1 {
|
||
actualCostUSD = costUSD.Div(decimal.NewFromFloat(floatingRatio))
|
||
}
|
||
actualCostUSD = actualCostUSD.Round(8)
|
||
if actualCostUSD.LessThanOrEqual(decimal.Zero) {
|
||
detailStatus = jobStatusFailed
|
||
detailError = "结算金额为0"
|
||
}
|
||
}
|
||
|
||
var previousBalance decimal.Decimal
|
||
var newBalance decimal.Decimal
|
||
|
||
if detailStatus == jobStatusSuccess {
|
||
tx := pg.Begin()
|
||
if tx.Error != nil {
|
||
detailStatus = jobStatusFailed
|
||
detailError = fmt.Sprintf("开启事务失败: %v", tx.Error)
|
||
} else {
|
||
remark := fmt.Sprintf(
|
||
"自动扣费 %s - %s (原始$%s, 上浮%0.2f, 实际$%s)",
|
||
windowStart.Format("2006-01-02 15:04"),
|
||
windowEnd.Format("15:04"),
|
||
costUSD.StringFixed(8),
|
||
floatingRatio,
|
||
actualCostUSD.StringFixed(8),
|
||
)
|
||
previousBalance, newBalance, err = applyMcpBalanceChange(tx, providerID, actualCostUSD.Neg(), remark)
|
||
if err != nil {
|
||
tx.Rollback()
|
||
detailStatus = jobStatusFailed
|
||
detailError = fmt.Sprintf("扣减余额失败: %v", err)
|
||
} else {
|
||
if commitErr := tx.Commit().Error; commitErr != nil {
|
||
detailStatus = jobStatusFailed
|
||
detailError = fmt.Sprintf("提交事务失败: %v", commitErr)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
if detailStatus == jobStatusSuccess {
|
||
recordsProcessed++
|
||
totalCost = totalCost.Add(actualCostUSD)
|
||
threshold := GetLowBalanceThreshold()
|
||
if newBalance.LessThan(threshold) {
|
||
notifier.NotifyMcpLowBalance(provider, account, newBalance, threshold)
|
||
}
|
||
} else {
|
||
hadFailure = true
|
||
errorMessages = append(errorMessages, fmt.Sprintf("%s/%s: %s", provider, account, detailError))
|
||
}
|
||
|
||
detailData := map[string]interface{}{
|
||
"run_id": runID,
|
||
"provider": provider,
|
||
"account": account,
|
||
"cost_cents": costCents,
|
||
"cost_usd": actualCostUSD,
|
||
"status": detailStatus,
|
||
"created_at": time.Now(),
|
||
"error_message": detailError,
|
||
}
|
||
if providerID != uuid.Nil {
|
||
detailData["provider_id"] = providerID
|
||
}
|
||
if detailStatus == jobStatusSuccess {
|
||
detailData["previous_balance"] = previousBalance
|
||
detailData["new_balance"] = newBalance
|
||
}
|
||
|
||
if err := pg.Table("mcp_usage_balance_job_details").Create(detailData).Error; err != nil {
|
||
hadFailure = true
|
||
errorMessages = append(errorMessages, fmt.Sprintf("记录明细失败 %s/%s: %v", provider, account, err))
|
||
}
|
||
}
|
||
|
||
finalStatus := jobStatusSuccess
|
||
finalError := ""
|
||
if hadFailure {
|
||
finalStatus = jobStatusFailed
|
||
finalError = strings.Join(errorMessages, "; ")
|
||
if len(finalError) > 1000 {
|
||
finalError = finalError[:1000]
|
||
}
|
||
}
|
||
|
||
if err := updateMcpUsageJobRun(pg, runID, finalStatus, recordsProcessed, totalCost, finalError); err != nil {
|
||
return fmt.Errorf("更新任务记录失败: %w", err)
|
||
}
|
||
|
||
if hadFailure {
|
||
return fmt.Errorf("部分账户结算失败: %s", finalError)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// RunModelTokenBalanceJob 汇总模型Token消费并扣减余额
|
||
func RunModelTokenBalanceJob() error {
|
||
// 定时任务执行前刷新低余额阈值配置
|
||
RefreshLowBalanceThreshold()
|
||
|
||
mysqlDB := storage.DB
|
||
if mysqlDB == nil {
|
||
return fmt.Errorf("MySQL未初始化")
|
||
}
|
||
|
||
lastWindowEnd, err := getLastProcessedModelTokenWindowEnd(mysqlDB, modelTokenJobName)
|
||
if err != nil {
|
||
return fmt.Errorf("获取上次处理时间失败: %w", err)
|
||
}
|
||
|
||
now := time.Now().UTC()
|
||
currentWindowEnd := now.Truncate(time.Hour)
|
||
|
||
var windowStart, windowEnd time.Time
|
||
if lastWindowEnd.IsZero() {
|
||
windowEnd = currentWindowEnd
|
||
windowStart = windowEnd.Add(-time.Hour)
|
||
} else {
|
||
windowStart = lastWindowEnd
|
||
windowEnd = currentWindowEnd
|
||
}
|
||
|
||
if !windowStart.Before(windowEnd) {
|
||
return nil
|
||
}
|
||
|
||
windowStart = windowStart.UTC()
|
||
windowEnd = windowEnd.UTC()
|
||
|
||
runID, err := createModelTokenJobRun(mysqlDB, modelTokenJobName, windowStart, windowEnd)
|
||
if err != nil {
|
||
if errors.Is(err, ErrModelTokenJobAlreadyProcessed) || errors.Is(err, ErrModelTokenJobAlreadyInProgress) {
|
||
return nil
|
||
}
|
||
return fmt.Errorf("创建任务记录失败: %w", err)
|
||
}
|
||
|
||
configMap, err := loadModelConfigsToMemory(mysqlDB)
|
||
if err != nil {
|
||
updateErr := updateModelTokenJobRun(mysqlDB, runID, jobStatusFailed, 0, decimal.Zero, err.Error())
|
||
if updateErr != nil {
|
||
return fmt.Errorf("加载模型配置失败: %v; 原始错误: %w", updateErr, err)
|
||
}
|
||
return fmt.Errorf("加载模型配置失败: %w", err)
|
||
}
|
||
|
||
startDay := windowStart.Format("2006-01-02")
|
||
startHour := windowStart.Hour()
|
||
endDay := windowEnd.Format("2006-01-02")
|
||
endHour := windowEnd.Hour()
|
||
|
||
query := mysqlDB.Table("gw_token_usages").
|
||
Select("provider, account, model, SUM(prompt_cost + completion_cost + cache_create_cost + cache_read_cost) AS total_cost").
|
||
Where("(day > ?) OR (day = ? AND hour >= ?)", startDay, startDay, startHour).
|
||
Where("(day < ?) OR (day = ? AND hour < ?)", endDay, endDay, endHour).
|
||
Group("provider, account, model").
|
||
Having("SUM(prompt_cost + completion_cost + cache_create_cost + cache_read_cost) <> 0")
|
||
|
||
rows, err := query.Rows()
|
||
if err != nil {
|
||
updateErr := updateModelTokenJobRun(mysqlDB, runID, jobStatusFailed, 0, decimal.Zero, err.Error())
|
||
if updateErr != nil {
|
||
return fmt.Errorf("query error update failed: %v; 原始错误: %w", updateErr, err)
|
||
}
|
||
return fmt.Errorf("聚合Token消费失败: %w", err)
|
||
}
|
||
defer rows.Close()
|
||
|
||
hasRows := false
|
||
var (
|
||
recordsProcessed int
|
||
totalCostUSD = decimal.Zero
|
||
hadFailure bool
|
||
errorMessages []string
|
||
)
|
||
|
||
for rows.Next() {
|
||
hasRows = true
|
||
|
||
var (
|
||
provider string
|
||
account sql.NullString
|
||
model string
|
||
totalCost sql.NullInt64
|
||
)
|
||
|
||
if scanErr := rows.Scan(&provider, &account, &model, &totalCost); scanErr != nil {
|
||
hadFailure = true
|
||
errorMsg := fmt.Sprintf("扫描聚合结果失败: %v", scanErr)
|
||
errorMessages = append(errorMessages, errorMsg)
|
||
updateErr := updateModelTokenJobRun(mysqlDB, runID, jobStatusFailed, recordsProcessed, totalCostUSD, errorMsg)
|
||
if updateErr != nil {
|
||
return fmt.Errorf("更新任务记录失败: %v; 原始错误: %w", updateErr, scanErr)
|
||
}
|
||
return fmt.Errorf("扫描聚合结果失败: %w", scanErr)
|
||
}
|
||
|
||
if !totalCost.Valid {
|
||
continue
|
||
}
|
||
costCents := totalCost.Int64
|
||
if costCents == 0 {
|
||
continue
|
||
}
|
||
|
||
provider = strings.TrimSpace(provider)
|
||
model = strings.TrimSpace(model)
|
||
accountStr := ""
|
||
if account.Valid {
|
||
accountStr = strings.TrimSpace(account.String)
|
||
}
|
||
|
||
costUSD := decimal.NewFromInt(costCents).Div(usdScalingFactor)
|
||
detailStatus := jobStatusSuccess
|
||
detailError := ""
|
||
|
||
var (
|
||
modelConfigID uint64
|
||
previousBalance decimal.Decimal
|
||
newBalance decimal.Decimal
|
||
)
|
||
|
||
configInfo, exists := configMap[makeModelConfigKey(provider, model)]
|
||
if !exists {
|
||
detailStatus = jobStatusFailed
|
||
detailError = "未找到匹配的模型配置"
|
||
} else {
|
||
modelConfigID = configInfo.ID
|
||
}
|
||
|
||
priceRatio := configInfo.PriceRatio
|
||
if priceRatio <= 0 {
|
||
priceRatio = 1
|
||
}
|
||
actualCostUSD := costUSD
|
||
|
||
if detailStatus == jobStatusSuccess {
|
||
if priceRatio != 1 {
|
||
actualCostUSD = costUSD.Div(decimal.NewFromFloat(priceRatio))
|
||
}
|
||
actualCostUSD = actualCostUSD.Round(8)
|
||
if actualCostUSD.LessThanOrEqual(decimal.Zero) {
|
||
detailStatus = jobStatusFailed
|
||
detailError = "结算金额为0"
|
||
}
|
||
}
|
||
|
||
if detailStatus == jobStatusSuccess {
|
||
tx := mysqlDB.Begin()
|
||
if tx.Error != nil {
|
||
detailStatus = jobStatusFailed
|
||
detailError = fmt.Sprintf("开启事务失败: %v", tx.Error)
|
||
} else {
|
||
remark := fmt.Sprintf(
|
||
"自动扣费 %s - %s (原始$%s, 上浮%0.2f, 实际$%s)",
|
||
windowStart.Format("2006-01-02 15:04"),
|
||
windowEnd.Format("15:04"),
|
||
costUSD.StringFixed(8),
|
||
priceRatio,
|
||
actualCostUSD.StringFixed(8),
|
||
)
|
||
previousBalance, newBalance, err = applyModelAccountBalanceChange(tx, accountStr, actualCostUSD.Neg(), remark)
|
||
if err != nil {
|
||
tx.Rollback()
|
||
detailStatus = jobStatusFailed
|
||
detailError = fmt.Sprintf("扣减余额失败: %v", err)
|
||
} else if commitErr := tx.Commit().Error; commitErr != nil {
|
||
detailStatus = jobStatusFailed
|
||
detailError = fmt.Sprintf("提交事务失败: %v", commitErr)
|
||
}
|
||
}
|
||
}
|
||
|
||
if detailStatus == jobStatusSuccess {
|
||
recordsProcessed++
|
||
totalCostUSD = totalCostUSD.Add(actualCostUSD)
|
||
threshold := GetLowBalanceThreshold()
|
||
if newBalance.LessThan(threshold) {
|
||
notifier.NotifyModelLowBalance(provider, accountStr, model, newBalance, threshold)
|
||
}
|
||
} else {
|
||
hadFailure = true
|
||
errorMessages = append(errorMessages, fmt.Sprintf("%s/%s/%s: %s", provider, accountStr, model, detailError))
|
||
}
|
||
|
||
detailData := map[string]interface{}{
|
||
"run_id": runID,
|
||
"provider": provider,
|
||
"account": accountStr,
|
||
"model": model,
|
||
"total_cost_cents": costCents,
|
||
"total_cost_usd": actualCostUSD,
|
||
"status": detailStatus,
|
||
"created_at": time.Now(),
|
||
"error_message": detailError,
|
||
}
|
||
if modelConfigID != 0 {
|
||
detailData["model_config_id"] = modelConfigID
|
||
}
|
||
if detailStatus == jobStatusSuccess {
|
||
detailData["previous_balance"] = previousBalance
|
||
detailData["new_balance"] = newBalance
|
||
}
|
||
|
||
if detailErr := mysqlDB.Table("model_token_balance_job_details").Create(detailData).Error; detailErr != nil {
|
||
hadFailure = true
|
||
errorMessages = append(errorMessages, fmt.Sprintf("记录明细失败 %s/%s/%s: %v", provider, accountStr, model, detailErr))
|
||
}
|
||
}
|
||
|
||
if err := rows.Err(); err != nil {
|
||
updateErr := updateModelTokenJobRun(mysqlDB, runID, jobStatusFailed, recordsProcessed, totalCostUSD, err.Error())
|
||
if updateErr != nil {
|
||
return fmt.Errorf("更新任务记录失败: %v; 原始错误: %w", updateErr, err)
|
||
}
|
||
return fmt.Errorf("遍历聚合结果失败: %w", err)
|
||
}
|
||
|
||
if !hasRows {
|
||
return updateModelTokenJobRun(mysqlDB, runID, jobStatusSuccess, 0, decimal.Zero, "")
|
||
}
|
||
|
||
finalStatus := jobStatusSuccess
|
||
finalError := ""
|
||
if hadFailure {
|
||
finalStatus = jobStatusFailed
|
||
finalError = strings.Join(errorMessages, "; ")
|
||
if len(finalError) > 1000 {
|
||
finalError = finalError[:1000]
|
||
}
|
||
}
|
||
|
||
if err := updateModelTokenJobRun(mysqlDB, runID, finalStatus, recordsProcessed, totalCostUSD, finalError); err != nil {
|
||
return fmt.Errorf("更新任务记录失败: %w", err)
|
||
}
|
||
|
||
if hadFailure {
|
||
return fmt.Errorf("部分模型结算失败: %s", finalError)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// GetMcpAccountBalanceHistory 获取账户余额历史记录
|
||
func GetMcpAccountBalanceHistory(providerID string, start, end string) ([]map[string]interface{}, error) {
|
||
db := storage.GetPG()
|
||
if db == nil {
|
||
return nil, fmt.Errorf("PostgreSQL未初始化")
|
||
}
|
||
|
||
providerUUID, err := uuid.Parse(providerID)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("无效的provider_id: %w", err)
|
||
}
|
||
|
||
q := db.Table("mcp_account_balances").
|
||
Select(`
|
||
mcp_account_balances.*,
|
||
mcp_providers.provider,
|
||
mcp_providers.account
|
||
`).
|
||
Joins("LEFT JOIN mcp_providers ON mcp_account_balances.provider_id = mcp_providers.id").
|
||
Where("mcp_account_balances.provider_id = ?", providerUUID)
|
||
|
||
if start != "" {
|
||
q = q.Where("mcp_account_balances.created_at >= ?", start)
|
||
}
|
||
if end != "" {
|
||
q = q.Where("mcp_account_balances.created_at <= ?", end)
|
||
}
|
||
|
||
rows, err := q.Order("mcp_account_balances.created_at DESC").Rows()
|
||
if err != nil {
|
||
return nil, fmt.Errorf("查询余额历史失败: %w", err)
|
||
}
|
||
defer rows.Close()
|
||
|
||
cols, _ := rows.Columns()
|
||
list := make([]map[string]interface{}, 0)
|
||
floatCols := []string{"balance"}
|
||
timeCols := []string{"created_at"}
|
||
|
||
for rows.Next() {
|
||
vals := make([]interface{}, len(cols))
|
||
valPtrs := make([]interface{}, len(cols))
|
||
for i := range vals {
|
||
valPtrs[i] = &vals[i]
|
||
}
|
||
if err := rows.Scan(valPtrs...); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
m := map[string]interface{}{}
|
||
isFloat := map[string]struct{}{}
|
||
for _, k := range floatCols {
|
||
isFloat[k] = struct{}{}
|
||
}
|
||
isTime := map[string]struct{}{}
|
||
for _, k := range timeCols {
|
||
isTime[k] = struct{}{}
|
||
}
|
||
|
||
for i, c := range cols {
|
||
if _, ok := isFloat[c]; ok {
|
||
if f, ok2 := toFloat(vals[i]); ok2 {
|
||
m[c] = f
|
||
} else {
|
||
m[c] = nil
|
||
}
|
||
continue
|
||
}
|
||
if _, ok := isTime[c]; ok {
|
||
if iso, ok2 := toTimeISO(vals[i]); ok2 {
|
||
m[c] = iso
|
||
} else {
|
||
m[c] = toString(vals[i])
|
||
}
|
||
continue
|
||
}
|
||
m[c] = toString(vals[i])
|
||
}
|
||
list = append(list, m)
|
||
}
|
||
|
||
return list, nil
|
||
}
|
||
|
||
// ========== 模型账号充值记录和余额管理 ==========
|
||
|
||
// applyModelAccountBalanceChange 应用模型账号余额变更(追加式记录)
|
||
func applyModelAccountBalanceChange(tx *gorm.DB, account string, delta decimal.Decimal, remark string) (decimal.Decimal, decimal.Decimal, error) {
|
||
var latest struct {
|
||
Balance decimal.Decimal `gorm:"column:balance"`
|
||
}
|
||
prev := decimal.Zero
|
||
|
||
err := tx.Table("gw_model_account_balances").
|
||
Select("balance").
|
||
Where("account = ?", account).
|
||
Order("created_at DESC").
|
||
Take(&latest).Error
|
||
|
||
if err != nil {
|
||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return decimal.Zero, decimal.Zero, err
|
||
}
|
||
} else {
|
||
prev = latest.Balance
|
||
}
|
||
|
||
newBalance := prev.Add(delta)
|
||
|
||
record := map[string]interface{}{
|
||
"account": account,
|
||
"balance": newBalance,
|
||
"currency": "USD",
|
||
"remark": remark,
|
||
"created_at": time.Now(),
|
||
}
|
||
|
||
if err := tx.Table("gw_model_account_balances").Create(record).Error; err != nil {
|
||
return decimal.Zero, decimal.Zero, err
|
||
}
|
||
|
||
return prev, newBalance, nil
|
||
}
|
||
|
||
// ListModelAccountRechargeRecords 查询模型账号充值记录
|
||
func ListModelAccountRechargeRecords(offset, limit int, provider, modelName, start, end string) (*PageResult, error) {
|
||
db := storage.DB
|
||
if db == nil {
|
||
return nil, fmt.Errorf("MySQL未初始化")
|
||
}
|
||
|
||
var total int64
|
||
q := db.Table("gw_model_account_recharge_records r").
|
||
Select(`
|
||
r.*,
|
||
gp.name AS provider
|
||
`).
|
||
Joins("LEFT JOIN gw_providers gp ON r.account = gp.account").
|
||
Where("r.deleted_at IS NULL")
|
||
|
||
if provider != "" {
|
||
q = q.Where("gp.name LIKE ?", "%"+provider+"%")
|
||
}
|
||
if start != "" {
|
||
q = q.Where("r.recharge_date >= ?", start)
|
||
}
|
||
if end != "" {
|
||
q = q.Where("r.recharge_date <= ?", end)
|
||
}
|
||
|
||
if err := q.Count(&total).Error; err != nil {
|
||
return nil, fmt.Errorf("count model account recharge records: %w", err)
|
||
}
|
||
|
||
rows, err := q.Order("r.created_at DESC").
|
||
Offset(offset).Limit(limit).Rows()
|
||
if err != nil {
|
||
return nil, fmt.Errorf("query model account recharge records: %w", err)
|
||
}
|
||
defer rows.Close()
|
||
|
||
cols, _ := rows.Columns()
|
||
list := make([]map[string]interface{}, 0)
|
||
intCols := []string{"id", "operator_id"}
|
||
floatCols := []string{"amount"}
|
||
timeCols := []string{"created_at", "updated_at"}
|
||
|
||
for rows.Next() {
|
||
vals := make([]interface{}, len(cols))
|
||
valPtrs := make([]interface{}, len(cols))
|
||
for i := range vals {
|
||
valPtrs[i] = &vals[i]
|
||
}
|
||
if err := rows.Scan(valPtrs...); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 需要处理float类型
|
||
m := normalizeGenericInts(cols, vals, intCols, timeCols)
|
||
// 手动处理float字段
|
||
for i, c := range cols {
|
||
for _, fc := range floatCols {
|
||
if c == fc {
|
||
if f, ok := toFloat(vals[i]); ok {
|
||
m[c] = f
|
||
}
|
||
break
|
||
}
|
||
}
|
||
}
|
||
list = append(list, m)
|
||
}
|
||
|
||
return &PageResult{List: list, Total: total}, nil
|
||
}
|
||
|
||
// CreateModelAccountRechargeRecord 创建模型账号充值记录
|
||
func CreateModelAccountRechargeRecord(account string, amount float64, rechargeDate string, operatorID interface{}, operatorName, remark string) error {
|
||
db := storage.DB
|
||
if db == nil {
|
||
return fmt.Errorf("MySQL未初始化")
|
||
}
|
||
|
||
account = strings.TrimSpace(account)
|
||
if account == "" {
|
||
return fmt.Errorf("账号不能为空")
|
||
}
|
||
|
||
// 验证账号是否存在
|
||
var count int64
|
||
if err := db.Table("gw_providers").Where("account = ?", account).Count(&count).Error; err != nil {
|
||
return fmt.Errorf("验证账号失败: %w", err)
|
||
}
|
||
if count == 0 {
|
||
return fmt.Errorf("账号不存在: %s", account)
|
||
}
|
||
|
||
// 解析日期
|
||
date, err := time.Parse("2006-01-02", rechargeDate)
|
||
if err != nil {
|
||
return fmt.Errorf("无效的日期格式: %w", err)
|
||
}
|
||
|
||
// 处理operator_id
|
||
var operatorIDUint *uint64
|
||
if operatorID != nil {
|
||
var opID uint64
|
||
switch v := operatorID.(type) {
|
||
case string:
|
||
parsed, err := strconv.ParseUint(v, 10, 64)
|
||
if err != nil {
|
||
operatorIDUint = nil
|
||
} else {
|
||
opID = parsed
|
||
operatorIDUint = &opID
|
||
}
|
||
case uint64:
|
||
operatorIDUint = &v
|
||
case int64:
|
||
u := uint64(v)
|
||
operatorIDUint = &u
|
||
case int:
|
||
u := uint64(v)
|
||
operatorIDUint = &u
|
||
default:
|
||
operatorIDUint = nil
|
||
}
|
||
}
|
||
|
||
// 开始事务
|
||
tx := db.Begin()
|
||
if tx.Error != nil {
|
||
return fmt.Errorf("开始事务失败: %w", tx.Error)
|
||
}
|
||
defer func() {
|
||
if r := recover(); r != nil {
|
||
tx.Rollback()
|
||
}
|
||
}()
|
||
|
||
// 1. 创建充值记录
|
||
now := time.Now()
|
||
record := map[string]interface{}{
|
||
"account": account,
|
||
"amount": amount,
|
||
"currency": "USD",
|
||
"recharge_date": date,
|
||
"operator_id": operatorIDUint,
|
||
"operator_name": operatorName,
|
||
"remark": remark,
|
||
"created_at": now,
|
||
"updated_at": now,
|
||
}
|
||
|
||
if err := tx.Table("gw_model_account_recharge_records").Create(record).Error; err != nil {
|
||
tx.Rollback()
|
||
return fmt.Errorf("创建充值记录失败: %w", err)
|
||
}
|
||
|
||
amountDecimal := decimal.NewFromFloat(amount).Round(8)
|
||
remarkText := fmt.Sprintf("手动充值 $%s by %s", amountDecimal.StringFixed(2), operatorName)
|
||
|
||
_, _, err = applyModelAccountBalanceChange(tx, account, amountDecimal, remarkText)
|
||
if err != nil {
|
||
tx.Rollback()
|
||
return fmt.Errorf("创建余额记录失败: %w", err)
|
||
}
|
||
|
||
// 提交事务
|
||
if err := tx.Commit().Error; err != nil {
|
||
return fmt.Errorf("提交事务失败: %w", err)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// UpdateModelAccountRechargeRecord 更新模型账号充值记录
|
||
func UpdateModelAccountRechargeRecord(id string, amount *float64, rechargeDate *string, remark *string) error {
|
||
db := storage.DB
|
||
if db == nil {
|
||
return fmt.Errorf("MySQL未初始化")
|
||
}
|
||
|
||
recordID, err := strconv.ParseUint(id, 10, 64)
|
||
if err != nil {
|
||
return fmt.Errorf("无效的记录ID: %w", err)
|
||
}
|
||
|
||
updates := map[string]interface{}{}
|
||
if amount != nil {
|
||
updates["amount"] = *amount
|
||
}
|
||
if rechargeDate != nil {
|
||
date, err := time.Parse("2006-01-02", *rechargeDate)
|
||
if err != nil {
|
||
return fmt.Errorf("无效的日期格式: %w", err)
|
||
}
|
||
updates["recharge_date"] = date
|
||
}
|
||
if remark != nil {
|
||
updates["remark"] = *remark
|
||
}
|
||
|
||
if len(updates) == 0 {
|
||
return fmt.Errorf("没有需要更新的字段")
|
||
}
|
||
|
||
updates["updated_at"] = time.Now()
|
||
|
||
return db.Table("gw_model_account_recharge_records").
|
||
Where("id = ? AND deleted_at IS NULL", recordID).
|
||
Updates(updates).Error
|
||
}
|
||
|
||
// DeleteModelAccountRechargeRecord 删除模型账号充值记录(软删除)
|
||
func DeleteModelAccountRechargeRecord(id string) error {
|
||
db := storage.DB
|
||
if db == nil {
|
||
return fmt.Errorf("MySQL未初始化")
|
||
}
|
||
|
||
recordID, err := strconv.ParseUint(id, 10, 64)
|
||
if err != nil {
|
||
return fmt.Errorf("无效的记录ID: %w", err)
|
||
}
|
||
|
||
return db.Table("gw_model_account_recharge_records").
|
||
Where("id = ?", recordID).
|
||
Update("deleted_at", time.Now()).Error
|
||
}
|
||
|
||
// GetModelConfigAccounts 获取模型配置列表(用于下拉选择)
|
||
func GetModelConfigAccounts(enabled *bool) ([]map[string]interface{}, error) {
|
||
db := storage.DB
|
||
if db == nil {
|
||
return nil, fmt.Errorf("MySQL未初始化")
|
||
}
|
||
|
||
q := db.Table("gw_providers").
|
||
Select("name, api_type, account, status, priority").
|
||
Where("deleted_at IS NULL")
|
||
|
||
if enabled != nil {
|
||
if *enabled {
|
||
q = q.Where("status = ?", "active")
|
||
} else {
|
||
q = q.Where("status <> ?", "active")
|
||
}
|
||
}
|
||
|
||
var accounts []map[string]interface{}
|
||
rows, err := q.Order("name, account").Rows()
|
||
if err != nil {
|
||
return nil, fmt.Errorf("query model accounts: %w", err)
|
||
}
|
||
defer rows.Close()
|
||
|
||
cols, _ := rows.Columns()
|
||
for rows.Next() {
|
||
vals := make([]interface{}, len(cols))
|
||
valPtrs := make([]interface{}, len(cols))
|
||
for i := range vals {
|
||
valPtrs[i] = &vals[i]
|
||
}
|
||
if err := rows.Scan(valPtrs...); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
m := make(map[string]interface{})
|
||
for i, c := range cols {
|
||
switch strings.ToLower(c) {
|
||
case "priority":
|
||
if val, ok := toInt(vals[i]); ok {
|
||
m[c] = val
|
||
} else {
|
||
m[c] = nil
|
||
}
|
||
default:
|
||
m[c] = toString(vals[i])
|
||
}
|
||
}
|
||
accounts = append(accounts, m)
|
||
}
|
||
|
||
return accounts, nil
|
||
}
|
||
|
||
// GetModelAccountLatestBalance 获取模型账号最新余额
|
||
func GetModelAccountLatestBalance(account string) (map[string]interface{}, error) {
|
||
db := storage.DB
|
||
if db == nil {
|
||
return nil, fmt.Errorf("MySQL未初始化")
|
||
}
|
||
|
||
account = strings.TrimSpace(account)
|
||
if account == "" {
|
||
return nil, fmt.Errorf("账号不能为空")
|
||
}
|
||
|
||
var balance map[string]interface{}
|
||
err := db.Table("gw_model_account_balances b").
|
||
Select(`
|
||
b.*,
|
||
gp.name AS provider,
|
||
gp.api_type
|
||
`).
|
||
Joins("LEFT JOIN gw_providers gp ON b.account = gp.account").
|
||
Where("b.account = ?", account).
|
||
Order("b.created_at DESC").
|
||
Limit(1).
|
||
Find(&balance).Error
|
||
|
||
if err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return map[string]interface{}{}, nil
|
||
}
|
||
return nil, err
|
||
}
|
||
|
||
// 处理数据类型
|
||
result := make(map[string]interface{})
|
||
for k, v := range balance {
|
||
if k == "balance" {
|
||
if f, ok := toFloat(v); ok {
|
||
result[k] = f
|
||
} else {
|
||
result[k] = v
|
||
}
|
||
} else if k == "created_at" {
|
||
if iso, ok := toTimeISO(v); ok {
|
||
result[k] = iso
|
||
} else {
|
||
result[k] = toString(v)
|
||
}
|
||
} else {
|
||
result[k] = toString(v)
|
||
}
|
||
}
|
||
|
||
return result, nil
|
||
}
|
||
|
||
// GetModelAccountBalances 获取所有模型账号最新余额列表
|
||
func GetModelAccountBalances() ([]map[string]interface{}, error) {
|
||
db := storage.DB
|
||
if db == nil {
|
||
return nil, fmt.Errorf("MySQL未初始化")
|
||
}
|
||
|
||
// 先获取每个账号最新的创建时间
|
||
var latestRecords []struct {
|
||
Account string
|
||
MaxCreated time.Time
|
||
}
|
||
|
||
err := db.Table("gw_model_account_balances").
|
||
Select("account, MAX(created_at) as max_created").
|
||
Group("account").
|
||
Find(&latestRecords).Error
|
||
|
||
if err != nil {
|
||
return nil, fmt.Errorf("查询余额失败: %w", err)
|
||
}
|
||
|
||
// 为每个账号获取最新余额
|
||
balances := make([]map[string]interface{}, 0)
|
||
for _, lr := range latestRecords {
|
||
var balance map[string]interface{}
|
||
err = db.Table("gw_model_account_balances b").
|
||
Select(`
|
||
b.*,
|
||
gp.name AS provider,
|
||
gp.api_type
|
||
`).
|
||
Joins("LEFT JOIN gw_providers gp ON b.account = gp.account").
|
||
Where("b.account = ? AND b.created_at = ?", lr.Account, lr.MaxCreated).
|
||
Limit(1).
|
||
Find(&balance).Error
|
||
|
||
if err == nil && len(balance) > 0 {
|
||
balances = append(balances, balance)
|
||
}
|
||
}
|
||
|
||
// 处理数据类型
|
||
result := make([]map[string]interface{}, 0)
|
||
for _, bal := range balances {
|
||
m := make(map[string]interface{})
|
||
for k, v := range bal {
|
||
if k == "balance" {
|
||
if f, ok := toFloat(v); ok {
|
||
m[k] = f
|
||
} else {
|
||
m[k] = v
|
||
}
|
||
} else if k == "created_at" {
|
||
if iso, ok := toTimeISO(v); ok {
|
||
m[k] = iso
|
||
} else {
|
||
m[k] = toString(v)
|
||
}
|
||
} else {
|
||
m[k] = toString(v)
|
||
}
|
||
}
|
||
result = append(result, m)
|
||
}
|
||
|
||
return result, nil
|
||
}
|
||
|
||
// GetModelAccountBalanceHistory 获取模型账号余额历史
|
||
func GetModelAccountBalanceHistory(account string, start, end string) ([]map[string]interface{}, error) {
|
||
db := storage.DB
|
||
if db == nil {
|
||
return nil, fmt.Errorf("MySQL未初始化")
|
||
}
|
||
|
||
q := db.Table("gw_model_account_balances").
|
||
Where("account = ?", account)
|
||
|
||
if start != "" {
|
||
q = q.Where("created_at >= ?", start)
|
||
}
|
||
if end != "" {
|
||
q = q.Where("created_at <= ?", end)
|
||
}
|
||
|
||
rows, err := q.Order("created_at DESC").Rows()
|
||
if err != nil {
|
||
return nil, fmt.Errorf("query balance history: %w", err)
|
||
}
|
||
defer rows.Close()
|
||
|
||
cols, _ := rows.Columns()
|
||
list := make([]map[string]interface{}, 0)
|
||
intCols := []string{"id"}
|
||
floatCols := []string{"balance"}
|
||
timeCols := []string{"created_at"}
|
||
|
||
for rows.Next() {
|
||
vals := make([]interface{}, len(cols))
|
||
valPtrs := make([]interface{}, len(cols))
|
||
for i := range vals {
|
||
valPtrs[i] = &vals[i]
|
||
}
|
||
if err := rows.Scan(valPtrs...); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
m := normalizeGenericInts(cols, vals, intCols, timeCols)
|
||
// 手动处理float字段
|
||
for i, c := range cols {
|
||
for _, fc := range floatCols {
|
||
if c == fc {
|
||
if f, ok := toFloat(vals[i]); ok {
|
||
m[c] = f
|
||
}
|
||
break
|
||
}
|
||
}
|
||
}
|
||
list = append(list, m)
|
||
}
|
||
|
||
return list, nil
|
||
}
|