Files
goalfylearning-admin/internal/services/finance_service.go

2514 lines
66 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package services
import (
"bytes"
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"sync"
"time"
"goalfymax-admin/internal/config"
"goalfymax-admin/internal/notifier"
"goalfymax-admin/internal/storage"
pkgredis "goalfymax-admin/pkg/redis"
mysqldriver "github.com/go-sql-driver/mysql"
"github.com/google/uuid"
"github.com/jackc/pgconn"
"github.com/shopspring/decimal"
"gorm.io/gorm"
)
// PageResult is the paginated response returned by the List* query helpers in
// this file.
type PageResult struct {
	// List holds the current page, one map per row keyed by column name.
	List []map[string]interface{} `json:"list"`
	// Total is the number of rows matching the filters, ignoring offset/limit.
	Total int64 `json:"total"`
}
var (
	// usdScalingFactor is 1e8. It is not used in this part of the file;
	// presumably it converts between integer-scaled USD amounts and decimal USD
	// (balances here are rounded to 8 decimal places). Confirm at its call sites.
	usdScalingFactor = decimal.NewFromInt(100000000)
	// lowBalanceThreshold is the current low-balance threshold (default: 30).
	// Read it through GetLowBalanceThreshold; RefreshLowBalanceThreshold replaces it.
	lowBalanceThreshold = decimal.NewFromInt(30) // default: 30
	// lowBalanceThresholdMutex guards lowBalanceThreshold.
	lowBalanceThresholdMutex sync.RWMutex
	// Returned by createMcpUsageJobRun / createModelTokenJobRun when a run with
	// the same job name and window_start already exists with status "success"
	// (processed) or "pending" (in progress).
	ErrMcpUsageJobAlreadyProcessed    = errors.New("mcp usage job already processed")
	ErrMcpUsageJobAlreadyInProgress   = errors.New("mcp usage job already in progress")
	ErrModelTokenJobAlreadyProcessed  = errors.New("model token job already processed")
	ErrModelTokenJobAlreadyInProgress = errors.New("model token job already in progress")
)
// GetLowBalanceThreshold returns the current low-balance threshold.
// The read is guarded by lowBalanceThresholdMutex, so it is safe to call while
// RefreshLowBalanceThreshold is replacing the value.
func GetLowBalanceThreshold() decimal.Decimal {
	lowBalanceThresholdMutex.RLock()
	current := lowBalanceThreshold
	lowBalanceThresholdMutex.RUnlock()
	return current
}
// RefreshLowBalanceThreshold 从系统配置刷新低余额阈值(定时任务执行前调用)
func RefreshLowBalanceThreshold() {
const configKey = "low_balance_threshold"
const defaultValue = 30
// 创建系统配置服务实例
systemConfigService := NewSystemConfigService(
storage.NewSystemConfigStorage(),
nil, // 不使用 logger
)
config, err := systemConfigService.GetByKey(configKey)
if err != nil {
// 配置不存在,使用默认值
lowBalanceThresholdMutex.Lock()
lowBalanceThreshold = decimal.NewFromInt(defaultValue)
lowBalanceThresholdMutex.Unlock()
return
}
// 解析配置值
valueStr := strings.TrimSpace(config.Value)
if valueStr == "" {
// 值为空,使用默认值
lowBalanceThresholdMutex.Lock()
lowBalanceThreshold = decimal.NewFromInt(defaultValue)
lowBalanceThresholdMutex.Unlock()
return
}
// 尝试解析为整数
valueInt, err := strconv.ParseInt(valueStr, 10, 64)
if err != nil {
// 解析失败,使用默认值
lowBalanceThresholdMutex.Lock()
lowBalanceThreshold = decimal.NewFromInt(defaultValue)
lowBalanceThresholdMutex.Unlock()
return
}
// 设置配置值
lowBalanceThresholdMutex.Lock()
lowBalanceThreshold = decimal.NewFromInt(valueInt)
lowBalanceThresholdMutex.Unlock()
}
const (
	// Status values stored in the mcp_usage_balance_job_runs and
	// model_token_balance_job_runs tables.
	jobStatusPending = "pending"
	jobStatusSuccess = "success"
	jobStatusFailed  = "failed"
	// job_name values for the hourly settlement jobs; presumably passed to the
	// create*JobRun helpers. Confirm at the call sites.
	mcpUsageJobName   = "mcp_usage_hourly_settlement"
	modelTokenJobName = "model_token_usage_hourly_settlement"
)
// toString renders a value scanned by database/sql as a string.
//
// []byte and string are returned as-is. nil (SQL NULL) becomes "" rather than
// fmt's "<nil>", so NULL columns never surface as the literal text "<nil>".
// The empty result also lets toInt/toFloat/toTimeISO treat NULL as "no value".
// Any other value is formatted with %v.
func toString(v interface{}) string {
	switch t := v.(type) {
	case nil:
		return ""
	case []byte:
		return string(t)
	case string:
		return t
	default:
		return fmt.Sprintf("%v", v)
	}
}
// toInt parses the textual form of v (see toString) as a base-10 int64.
// The boolean is false when that text is empty or is not a valid int64.
func toInt(v interface{}) (int64, bool) {
	text := toString(v)
	if text == "" {
		return 0, false
	}
	if n, err := strconv.ParseInt(text, 10, 64); err == nil {
		return n, true
	}
	return 0, false
}
// toFloat parses the textual form of v (see toString) as a float64.
// The boolean is false when that text is empty or is not a valid float.
func toFloat(v interface{}) (float64, bool) {
	text := toString(v)
	if text == "" {
		return 0, false
	}
	parsed, err := strconv.ParseFloat(text, 64)
	if err != nil {
		return 0, false
	}
	return parsed, true
}
// toTimeISO converts a scanned timestamp value into an RFC 3339 string.
//
// time.Time and non-nil *time.Time values are formatted directly. Drivers
// return these for timestamp columns (pgx always; MySQL with parseTime), and
// their %v form matches none of the textual layouts.
//
// Textual input ([]byte, string, or the %v form of anything else) is tried
// against RFC 3339, "2006-01-02 15:04:05" and "2006-01-02T15:04:05". The last
// two are interpreted in time.Local. Text matching none of the layouts is
// returned unchanged.
//
// The boolean is false only for nil / a nil *time.Time / empty text, i.e.
// when there is no value at all.
func toTimeISO(s interface{}) (string, bool) {
	var str string
	switch v := s.(type) {
	case nil:
		return "", false
	case time.Time:
		return v.Format(time.RFC3339), true
	case *time.Time:
		if v == nil {
			return "", false
		}
		return v.Format(time.RFC3339), true
	case []byte:
		str = string(v)
	case string:
		str = v
	default:
		str = fmt.Sprintf("%v", v)
	}
	if str == "" {
		return "", false
	}
	layouts := []string{
		time.RFC3339,
		"2006-01-02 15:04:05",
		"2006-01-02T15:04:05",
	}
	for _, layout := range layouts {
		if tm, err := time.ParseInLocation(layout, str, time.Local); err == nil {
			return tm.Format(time.RFC3339), true
		}
	}
	return str, true
}
// normalizeSandboxRow turns one raw sb_sandbox_record row into a map keyed by
// column name. Columns are converted as follows:
//   - duration_minutes, unit_price_usd, total_cost_usd: float64 (nil if unparseable)
//   - total_cost_balance: int64 (nil if unparseable)
//   - created_at, updated_at, started_at, released_at, last_billed_at: passed
//     through toTimeISO (nil when it reports no value)
//   - info: decoded as JSON when valid, otherwise kept as the raw string
//   - everything else: stringified with toString
func normalizeSandboxRow(cols []string, vals []interface{}) map[string]interface{} {
	row := map[string]interface{}{}
	for idx, name := range cols {
		raw := vals[idx]
		switch name {
		case "duration_minutes", "unit_price_usd", "total_cost_usd":
			row[name] = nil
			if f, ok := toFloat(raw); ok {
				row[name] = f
			}
		case "total_cost_balance":
			row[name] = nil
			if n, ok := toInt(raw); ok {
				row[name] = n
			}
		case "created_at", "updated_at", "started_at", "released_at", "last_billed_at":
			row[name] = nil
			if iso, ok := toTimeISO(raw); ok {
				row[name] = iso
			}
		case "info":
			text := toString(raw)
			var decoded interface{}
			if json.Unmarshal([]byte(text), &decoded) != nil {
				row[name] = text
			} else {
				row[name] = decoded
			}
		default:
			row[name] = toString(raw)
		}
	}
	return row
}
// normalizeGenericInts turns one raw result row into a map keyed by column name.
// Columns are converted as follows:
//   - columns named in intCols: parsed with toInt (nil if unparseable)
//   - columns named in timeCols: passed through toTimeISO (nil when it reports
//     no value)
//   - all other columns: stringified with toString
//
// A column listed in both intCols and timeCols is treated as an integer.
func normalizeGenericInts(cols []string, vals []interface{}, intCols []string, timeCols []string) map[string]interface{} {
	asSet := func(names []string) map[string]bool {
		set := make(map[string]bool, len(names))
		for _, n := range names {
			set[n] = true
		}
		return set
	}
	ints, times := asSet(intCols), asSet(timeCols)

	out := map[string]interface{}{}
	for i, col := range cols {
		switch {
		case ints[col]:
			if n, ok := toInt(vals[i]); ok {
				out[col] = n
			} else {
				out[col] = nil
			}
		case times[col]:
			if iso, ok := toTimeISO(vals[i]); ok {
				out[col] = iso
			} else {
				out[col] = nil
			}
		default:
			out[col] = toString(vals[i])
		}
	}
	return out
}
// parseBool interprets a scanned column value as a boolean.
//
// Supported inputs:
//   - bool, and *bool (a nil pointer is false)
//   - any signed or unsigned integer type (non-zero is true)
//   - string or []byte whose trimmed, lower-cased text is "true", "t" or "1".
//     []byte is how database/sql drivers often hand back textual column data
//     when scanning into interface{}.
//
// Anything else, including nil and floats, yields false.
func parseBool(val interface{}) bool {
	isTrueText := func(s string) bool {
		s = strings.ToLower(strings.TrimSpace(s))
		return s == "true" || s == "1" || s == "t"
	}
	switch v := val.(type) {
	case bool:
		return v
	case *bool:
		return v != nil && *v
	case int:
		return v != 0
	case int8:
		return v != 0
	case int16:
		return v != 0
	case int32:
		return v != 0
	case int64:
		return v != 0
	case uint:
		return v != 0
	case uint8:
		return v != 0
	case uint16:
		return v != 0
	case uint32:
		return v != 0
	case uint64:
		return v != 0
	case string:
		return isTrueText(v)
	case []byte:
		return isTrueText(string(v))
	}
	return false
}
// mcpUsageAggregation holds one row of an aggregated MCP usage query, with the
// provider, account and cost_sum columns. It is not used in this part of the
// file; presumably the query groups usage by (provider, account). Confirm at
// its call site.
type mcpUsageAggregation struct {
	Provider string `gorm:"column:provider"`
	// Account is nullable in the source query.
	Account sql.NullString `gorm:"column:account"`
	CostSum int64          `gorm:"column:cost_sum"`
}
// modelConfigInfo is the in-memory form of a gw_model_config_v2 row as built
// by loadModelConfigsToMemory (Provider and ModelName are whitespace-trimmed).
type modelConfigInfo struct {
	ID         uint64
	Provider   string
	ModelName  string
	PriceRatio float64
}
// applyMcpBalanceChange appends a ledger row to mcp_account_balances whose
// balance is the provider's latest balance plus delta. It returns the previous
// and the new balance. A provider with no balance rows starts from 0.
//
// It must run inside a transaction (tx). Before reading the latest balance, it
// locks the provider's mcp_providers row (SELECT ... FOR UPDATE). This
// serializes balance changes for that provider until the transaction ends, so
// two concurrent callers cannot both read the same "latest" balance and append
// on top of it. An example is a manual adjustment racing a settlement job.
//
// NOTE(review): the latest row is picked by created_at DESC. If created_at is
// filled by a DB default such as now() (the transaction start time in
// PostgreSQL), two calls for the same provider within one transaction would
// tie. Confirm the column default.
func applyMcpBalanceChange(tx *gorm.DB, providerID uuid.UUID, delta decimal.Decimal, remark string) (decimal.Decimal, decimal.Decimal, error) {
	if err := tx.Exec("SELECT id FROM mcp_providers WHERE id = ? FOR UPDATE", providerID).Error; err != nil {
		return decimal.Zero, decimal.Zero, fmt.Errorf("lock mcp provider %s: %w", providerID, err)
	}

	prev := decimal.Zero
	var latest struct {
		Balance decimal.Decimal `gorm:"column:balance"`
	}
	err := tx.Table("mcp_account_balances").
		Select("balance").
		Where("provider_id = ?", providerID).
		Order("created_at DESC").
		Take(&latest).Error
	switch {
	case err == nil:
		prev = latest.Balance
	case !errors.Is(err, gorm.ErrRecordNotFound):
		return decimal.Zero, decimal.Zero, err
	}

	newBalance := prev.Add(delta)
	record := map[string]interface{}{
		"provider_id": providerID,
		"balance":     newBalance,
		"currency":    "USD",
		"remark":      remark,
	}
	if err := tx.Table("mcp_account_balances").Create(record).Error; err != nil {
		return decimal.Zero, decimal.Zero, err
	}
	return prev, newBalance, nil
}
// findMcpProviderID returns the id of the newest (created_at DESC) mcp_providers
// row with the given provider and account. An empty account matches rows whose
// account is NULL or ''. Lookup errors, including gorm.ErrRecordNotFound, are
// returned as-is together with uuid.Nil.
func findMcpProviderID(pg *gorm.DB, provider, account string) (uuid.UUID, error) {
	var row struct {
		ID uuid.UUID `gorm:"column:id"`
	}
	q := pg.Table("mcp_providers").Select("id").Where("provider = ?", provider)
	switch account {
	case "":
		q = q.Where("(account IS NULL OR account = '')")
	default:
		q = q.Where("account = ?", account)
	}
	if err := q.Order("created_at DESC").Take(&row).Error; err != nil {
		return uuid.Nil, err
	}
	return row.ID, nil
}
// mcpProviderInfo holds one mcp_providers row as read by
// loadMcpProvidersToMemory, for in-memory lookups.
type mcpProviderInfo struct {
	ID       uuid.UUID
	Provider string
	// Account is "" when the stored account is empty or NULL (the loading
	// query COALESCEs NULL to '').
	Account string
	// FloatingRatio is the raw floating_ratio column; the loading query
	// replaces NULL with 1.
	FloatingRatio float64
}
// mcpProviderCache is the per-account value stored by loadMcpProvidersToMemory.
type mcpProviderCache struct {
	ID uuid.UUID
	// FloatingRatio is the effective multiplier 1 + max(floating_ratio, 0),
	// not the raw column value.
	FloatingRatio float64
}
// loadMcpProvidersToMemory reads every mcp_providers row and indexes it by
// account, with a NULL account mapping to the "" key. Rows are read
// newest-first (created_at DESC), so when several rows share an account the
// newest one wins. The provider column is not part of the key.
//
// Each entry stores FloatingRatio as a multiplier: 1 + max(floating_ratio, 0).
//
// NOTE(review): the query substitutes 1 for a NULL floating_ratio. After the
// "1 +" step, that becomes a 2x multiplier. If floating_ratio is meant to be an
// additive markup, as the "1 +" suggests, the COALESCE default should probably
// be 0. Confirm the intended semantics before changing either side.
func loadMcpProvidersToMemory(pg *gorm.DB) (map[string]mcpProviderCache, error) {
	var rows []mcpProviderInfo
	if err := pg.Table("mcp_providers").
		Select("id, provider, COALESCE(account, '') AS account, COALESCE(floating_ratio, 1) AS floating_ratio").
		Order("created_at DESC").
		Find(&rows).Error; err != nil {
		return nil, err
	}

	byAccount := make(map[string]mcpProviderCache)
	for _, row := range rows {
		if _, seen := byAccount[row.Account]; seen {
			continue // an newer row for this account was already kept
		}
		markup := row.FloatingRatio
		if markup < 0 {
			markup = 0
		}
		byAccount[row.Account] = mcpProviderCache{
			ID:            row.ID,
			FloatingRatio: 1 + markup,
		}
	}
	return byAccount, nil
}
// findMcpProviderFromMemory looks up an entry in a map built by
// loadMcpProvidersToMemory. Only account is used as the key; provider is
// currently ignored. A missing account yields an error and a zero entry.
func findMcpProviderFromMemory(providerMap map[string]mcpProviderCache, provider, account string) (mcpProviderCache, error) {
	entry, found := providerMap[account]
	if !found {
		return mcpProviderCache{}, fmt.Errorf("未找到匹配的账号: account=%s", account)
	}
	return entry, nil
}
// getLastProcessedWindowEnd returns window_end of the newest (highest id)
// successful run of jobName in mcp_usage_balance_job_runs. A zero time.Time
// with a nil error means no successful run has been recorded yet.
func getLastProcessedWindowEnd(pg *gorm.DB, jobName string) (time.Time, error) {
	var last struct {
		WindowEnd time.Time `gorm:"column:window_end"`
	}
	err := pg.Table("mcp_usage_balance_job_runs").
		Select("window_end").
		Where("job_name = ? AND status = ?", jobName, jobStatusSuccess).
		Order("id DESC").
		Limit(1).
		Take(&last).Error
	switch {
	case err == nil:
		return last.WindowEnd, nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		return time.Time{}, nil
	default:
		return time.Time{}, err
	}
}
// createMcpUsageJobRun inserts a pending row into mcp_usage_balance_job_runs
// for the window windowStart–windowEnd and returns its id.
//
// On a PostgreSQL unique violation (SQLSTATE 23505), the existing run with the
// same job_name and window_start is looked up and returned with an error:
//   - pending: ErrMcpUsageJobAlreadyInProgress
//   - success: ErrMcpUsageJobAlreadyProcessed
//   - any other status: a generic error
//
// If that lookup fails, the original insert error is returned with id 0.
func createMcpUsageJobRun(pg *gorm.DB, jobName string, windowStart, windowEnd time.Time) (int64, error) {
	type JobRun struct {
		ID               int64           `gorm:"column:id;primaryKey;autoIncrement"`
		JobName          string          `gorm:"column:job_name"`
		WindowStart      time.Time       `gorm:"column:window_start"`
		WindowEnd        time.Time       `gorm:"column:window_end"`
		Status           string          `gorm:"column:status"`
		RecordsProcessed int             `gorm:"column:records_processed"`
		TotalCost        decimal.Decimal `gorm:"column:total_cost"`
		ErrorMessage     string          `gorm:"column:error_message"`
		CreatedAt        time.Time       `gorm:"column:created_at"`
		UpdatedAt        time.Time       `gorm:"column:updated_at"`
	}
	run := JobRun{
		JobName:          jobName,
		WindowStart:      windowStart,
		WindowEnd:        windowEnd,
		Status:           jobStatusPending,
		RecordsProcessed: 0,
		TotalCost:        decimal.Zero,
		ErrorMessage:     "",
		CreatedAt:        time.Now(),
		UpdatedAt:        time.Now(),
	}

	createErr := pg.Table("mcp_usage_balance_job_runs").Create(&run).Error
	if createErr == nil {
		return run.ID, nil
	}

	var pgErr *pgconn.PgError
	if !errors.As(createErr, &pgErr) || pgErr.Code != "23505" {
		return 0, createErr
	}

	// A run for this job/window already exists; report its state.
	var existing struct {
		ID     int64  `gorm:"column:id"`
		Status string `gorm:"column:status"`
	}
	lookupErr := pg.Table("mcp_usage_balance_job_runs").
		Select("id, status").
		Where("job_name = ? AND window_start = ?", jobName, windowStart).
		Take(&existing).Error
	if lookupErr != nil {
		return 0, createErr
	}
	switch existing.Status {
	case jobStatusPending:
		return existing.ID, ErrMcpUsageJobAlreadyInProgress
	case jobStatusSuccess:
		return existing.ID, ErrMcpUsageJobAlreadyProcessed
	default:
		return existing.ID, fmt.Errorf("job already exists with status %s", existing.Status)
	}
}
// updateMcpUsageJobRun sets the outcome fields (status, records_processed,
// total_cost, error_message) of run runID in mcp_usage_balance_job_runs and
// refreshes updated_at.
func updateMcpUsageJobRun(pg *gorm.DB, runID int64, status string, records int, total decimal.Decimal, errorMessage string) error {
	return pg.Table("mcp_usage_balance_job_runs").
		Where("id = ?", runID).
		Updates(map[string]interface{}{
			"status":            status,
			"records_processed": records,
			"total_cost":        total,
			"error_message":     errorMessage,
			"updated_at":        time.Now(),
		}).Error
}
// getLastProcessedModelTokenWindowEnd returns window_end of the newest
// (highest id) successful run of jobName in model_token_balance_job_runs.
// A zero time.Time with a nil error means no successful run exists yet.
func getLastProcessedModelTokenWindowEnd(db *gorm.DB, jobName string) (time.Time, error) {
	var last struct {
		WindowEnd time.Time `gorm:"column:window_end"`
	}
	err := db.Table("model_token_balance_job_runs").
		Select("window_end").
		Where("job_name = ? AND status = ?", jobName, jobStatusSuccess).
		Order("id DESC").
		Limit(1).
		Take(&last).Error
	switch {
	case err == nil:
		return last.WindowEnd, nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		return time.Time{}, nil
	default:
		return time.Time{}, err
	}
}
// createModelTokenJobRun inserts a pending row into model_token_balance_job_runs
// for the window windowStart–windowEnd and returns its id.
//
// On a MySQL duplicate-key error (1062), the newest existing run with the same
// job_name and window_start is looked up and returned with an error:
//   - pending: ErrModelTokenJobAlreadyInProgress
//   - success: ErrModelTokenJobAlreadyProcessed
//   - any other status: a generic error
//
// If that lookup itself fails, (0, ErrModelTokenJobAlreadyProcessed) is returned.
//
// NOTE(review): unlike createMcpUsageJobRun, a failed lookup is reported as
// "already processed" instead of returning the insert error. Confirm this is
// intended.
func createModelTokenJobRun(db *gorm.DB, jobName string, windowStart, windowEnd time.Time) (int64, error) {
	type jobRun struct {
		ID               int64     `gorm:"column:id;primaryKey;autoIncrement"`
		JobName          string    `gorm:"column:job_name"`
		WindowStart      time.Time `gorm:"column:window_start"`
		WindowEnd        time.Time `gorm:"column:window_end"`
		Status           string    `gorm:"column:status"`
		RecordsProcessed int       `gorm:"column:records_processed"`
		TotalCost        decimal.Decimal
		ErrorMessage     string    `gorm:"column:error_message"`
		CreatedAt        time.Time `gorm:"column:created_at"`
		UpdatedAt        time.Time `gorm:"column:updated_at"`
	}
	pending := jobRun{
		JobName:          jobName,
		WindowStart:      windowStart,
		WindowEnd:        windowEnd,
		Status:           jobStatusPending,
		RecordsProcessed: 0,
		TotalCost:        decimal.Zero,
		ErrorMessage:     "",
		CreatedAt:        time.Now(),
		UpdatedAt:        time.Now(),
	}

	insertErr := db.Table("model_token_balance_job_runs").Create(&pending).Error
	if insertErr == nil {
		return pending.ID, nil
	}

	var mysqlErr *mysqldriver.MySQLError
	if !errors.As(insertErr, &mysqlErr) || mysqlErr.Number != 1062 {
		return 0, insertErr
	}

	// 1062 = duplicate entry: a run for this job/window already exists.
	var existing struct {
		ID     int64  `gorm:"column:id"`
		Status string `gorm:"column:status"`
	}
	lookupErr := db.Table("model_token_balance_job_runs").
		Select("id, status").
		Where("job_name = ? AND window_start = ?", jobName, windowStart).
		Order("id DESC").
		Take(&existing).Error
	if lookupErr != nil {
		return 0, ErrModelTokenJobAlreadyProcessed
	}
	switch existing.Status {
	case jobStatusPending:
		return existing.ID, ErrModelTokenJobAlreadyInProgress
	case jobStatusSuccess:
		return existing.ID, ErrModelTokenJobAlreadyProcessed
	default:
		return existing.ID, fmt.Errorf("job already exists with status %s", existing.Status)
	}
}
// updateModelTokenJobRun sets the outcome fields (status, records_processed,
// total_cost, error_message) of run runID in model_token_balance_job_runs and
// refreshes updated_at.
func updateModelTokenJobRun(db *gorm.DB, runID int64, status string, records int, total decimal.Decimal, errorMessage string) error {
	return db.Table("model_token_balance_job_runs").
		Where("id = ?", runID).
		Updates(map[string]interface{}{
			"status":            status,
			"records_processed": records,
			"total_cost":        total,
			"error_message":     errorMessage,
			"updated_at":        time.Now(),
		}).Error
}
// loadModelConfigsToMemory loads the non-deleted rows of gw_model_config_v2 into
// a map keyed by makeModelConfigKey(provider, model_name), i.e. trimmed and
// lower-cased. Provider and ModelName in each value are whitespace-trimmed.
//
// Rows are read in ascending id order, so when several rows normalize to the
// same key the lowest id consistently wins. Without an ORDER BY, the database
// may return rows in any order, which would make the chosen config arbitrary.
func loadModelConfigsToMemory(db *gorm.DB) (map[string]modelConfigInfo, error) {
	var configs []struct {
		ID         uint64  `gorm:"column:id"`
		Provider   string  `gorm:"column:provider"`
		ModelName  string  `gorm:"column:model_name"`
		PriceRatio float64 `gorm:"column:price_ratio"`
	}
	err := db.Table("gw_model_config_v2").
		Select("id, provider, model_name, price_ratio").
		Where("deleted_at IS NULL").
		Order("id ASC").
		Find(&configs).Error
	if err != nil {
		return nil, err
	}

	result := make(map[string]modelConfigInfo, len(configs))
	for _, cfg := range configs {
		key := makeModelConfigKey(cfg.Provider, cfg.ModelName)
		if _, exists := result[key]; exists {
			continue // keep the lowest-id row for this key
		}
		result[key] = modelConfigInfo{
			ID:         cfg.ID,
			Provider:   strings.TrimSpace(cfg.Provider),
			ModelName:  strings.TrimSpace(cfg.ModelName),
			PriceRatio: cfg.PriceRatio,
		}
	}
	return result, nil
}
// makeModelConfigKey builds the case- and whitespace-insensitive lookup key
// "provider|model" used for the model-config map.
func makeModelConfigKey(provider, model string) string {
	normalize := func(s string) string {
		return strings.ToLower(strings.TrimSpace(s))
	}
	return normalize(provider) + "|" + normalize(model)
}
// AdjustMcpAccountBalance sets an MCP provider's balance to newBalance (USD,
// rounded to 8 decimal places). It appends a ledger row whose delta is the
// difference from the current balance, and does nothing when that difference
// is below 1e-8.
//
// An empty remark is replaced by a generated "手动调整余额至 ..." note;
// otherwise " (by operatorName)" is appended.
//
// The current balance is read inside the transaction that applies the change,
// after locking the provider's mcp_providers row. The delta is therefore
// computed against the same balance the adjustment builds on, and a concurrent
// change cannot leave the account at something other than newBalance.
func AdjustMcpAccountBalance(providerID string, newBalance float64, operatorName, remark string) error {
	db := storage.GetPG()
	if db == nil {
		return fmt.Errorf("PostgreSQL未初始化")
	}
	if newBalance < 0 {
		return fmt.Errorf("余额不能为负数")
	}
	uuidVal, err := uuid.Parse(providerID)
	if err != nil {
		return fmt.Errorf("无效的provider_id: %w", err)
	}
	target := decimal.NewFromFloat(newBalance).Round(8)

	tx := db.Begin()
	if tx.Error != nil {
		return fmt.Errorf("开启事务失败: %w", tx.Error)
	}
	committed := false
	// Roll back on every path that does not commit, including panics.
	defer func() {
		if !committed {
			tx.Rollback()
		}
	}()

	// Serialize with other balance changes for this provider.
	if err := tx.Exec("SELECT id FROM mcp_providers WHERE id = ? FOR UPDATE", uuidVal).Error; err != nil {
		return fmt.Errorf("锁定账户失败: %w", err)
	}

	current := decimal.Zero
	var latest struct {
		Balance decimal.Decimal `gorm:"column:balance"`
	}
	err = tx.Table("mcp_account_balances").
		Select("balance").
		Where("provider_id = ?", uuidVal).
		Order("created_at DESC").
		Take(&latest).Error
	switch {
	case err == nil:
		current = latest.Balance
	case !errors.Is(err, gorm.ErrRecordNotFound):
		return fmt.Errorf("获取当前余额失败: %w", err)
	}

	delta := target.Sub(current)
	if delta.Abs().LessThan(decimal.NewFromFloat(1e-8)) {
		// Already at the target: nothing to write; the deferred rollback ends the tx.
		return nil
	}

	fullRemark := fmt.Sprintf("%s (by %s)", remark, operatorName)
	if strings.TrimSpace(remark) == "" {
		fullRemark = fmt.Sprintf("手动调整余额至 $%s by %s", target.StringFixed(8), operatorName)
	}
	if _, _, err := applyMcpBalanceChange(tx, uuidVal, delta, fullRemark); err != nil {
		return err
	}
	if err := tx.Commit().Error; err != nil {
		return fmt.Errorf("提交事务失败: %w", err)
	}
	committed = true
	return nil
}
// CreateMcpAccountBalanceRecord writes the first balance ledger row for an MCP
// provider, using balance as the opening value (USD, rounded to 8 places).
// It refuses when the provider already has any balance row;
// AdjustMcpAccountBalance handles later changes.
//
// An empty remark is replaced by a generated "新增余额 ..." note; otherwise
// " (by operatorName)" is appended.
func CreateMcpAccountBalanceRecord(providerID string, balance float64, operatorName, remark string) error {
	db := storage.GetPG()
	if db == nil {
		return fmt.Errorf("PostgreSQL未初始化")
	}
	if balance < 0 {
		return fmt.Errorf("余额不能为负数")
	}
	providerUUID, parseErr := uuid.Parse(providerID)
	if parseErr != nil {
		return fmt.Errorf("无效的provider_id: %w", parseErr)
	}

	var existing int64
	countErr := db.Table("mcp_account_balances").
		Where("provider_id = ?", providerUUID).
		Limit(1).
		Count(&existing).Error
	if countErr != nil {
		return fmt.Errorf("查询余额失败: %w", countErr)
	}
	if existing > 0 {
		return fmt.Errorf("余额记录已存在,请使用调整余额功能")
	}

	opening := decimal.NewFromFloat(balance).Round(8)
	tx := db.Begin()
	if tx.Error != nil {
		return fmt.Errorf("开启事务失败: %w", tx.Error)
	}

	note := fmt.Sprintf("%s (by %s)", remark, operatorName)
	if strings.TrimSpace(remark) == "" {
		note = fmt.Sprintf("新增余额 $%s by %s", opening.StringFixed(8), operatorName)
	}
	if _, _, applyErr := applyMcpBalanceChange(tx, providerUUID, opening, note); applyErr != nil {
		tx.Rollback()
		return applyErr
	}
	if commitErr := tx.Commit().Error; commitErr != nil {
		return fmt.Errorf("提交事务失败: %w", commitErr)
	}
	return nil
}
// AdjustModelAccountBalance sets a model account's balance to newBalance (USD,
// rounded to 8 decimal places). It appends a ledger row to
// gw_model_account_balances whose delta is the difference from the current
// balance, and does nothing when that difference is below 1e-8.
//
// An empty remark is replaced by a generated "手动调整余额至 ..." note;
// otherwise " (by operatorName)" is appended.
//
// The current balance is read through the same transaction that applies the
// delta, not before Begin. The delta and the change made by
// applyModelAccountBalanceChange are therefore computed within one transaction
// rather than against a value that may have changed in between.
func AdjustModelAccountBalance(account string, newBalance float64, operatorName, remark string) error {
	db := storage.DB
	if db == nil {
		return fmt.Errorf("MySQL未初始化")
	}
	if newBalance < 0 {
		return fmt.Errorf("余额不能为负数")
	}
	target := decimal.NewFromFloat(newBalance).Round(8)

	tx := db.Begin()
	if tx.Error != nil {
		return fmt.Errorf("开启事务失败: %w", tx.Error)
	}
	committed := false
	// Roll back on every path that does not commit, including panics.
	defer func() {
		if !committed {
			tx.Rollback()
		}
	}()

	current := decimal.Zero
	var latest struct {
		Balance decimal.Decimal `gorm:"column:balance"`
	}
	err := tx.Table("gw_model_account_balances").
		Select("balance").
		Where("account = ?", account).
		Order("created_at DESC").
		Take(&latest).Error
	switch {
	case err == nil:
		current = latest.Balance
	case !errors.Is(err, gorm.ErrRecordNotFound):
		return fmt.Errorf("获取当前余额失败: %w", err)
	}

	delta := target.Sub(current)
	if delta.Abs().LessThan(decimal.NewFromFloat(1e-8)) {
		// Already at the target: nothing to write; the deferred rollback ends the tx.
		return nil
	}

	fullRemark := fmt.Sprintf("%s (by %s)", remark, operatorName)
	if strings.TrimSpace(remark) == "" {
		fullRemark = fmt.Sprintf("手动调整余额至 $%s by %s", target.StringFixed(8), operatorName)
	}
	if _, _, err := applyModelAccountBalanceChange(tx, account, delta, fullRemark); err != nil {
		return err
	}
	if err := tx.Commit().Error; err != nil {
		return fmt.Errorf("提交事务失败: %w", err)
	}
	committed = true
	return nil
}
// CreateModelAccountBalanceRecord writes the first gw_model_account_balances
// row for account, using balance as the opening value (USD, rounded to 8
// places). It refuses when the account already has any balance row;
// AdjustModelAccountBalance handles later changes.
//
// An empty remark is replaced by a generated "新增余额 ..." note; otherwise
// " (by operatorName)" is appended.
func CreateModelAccountBalanceRecord(account string, balance float64, operatorName, remark string) error {
	db := storage.DB
	if db == nil {
		return fmt.Errorf("MySQL未初始化")
	}
	if balance < 0 {
		return fmt.Errorf("余额不能为负数")
	}

	var existing int64
	countErr := db.Table("gw_model_account_balances").
		Where("account = ?", account).
		Limit(1).
		Count(&existing).Error
	if countErr != nil {
		return fmt.Errorf("查询余额失败: %w", countErr)
	}
	if existing > 0 {
		return fmt.Errorf("余额记录已存在,请使用调整余额功能")
	}

	opening := decimal.NewFromFloat(balance).Round(8)
	tx := db.Begin()
	if tx.Error != nil {
		return fmt.Errorf("开启事务失败: %w", tx.Error)
	}

	note := fmt.Sprintf("%s (by %s)", remark, operatorName)
	if strings.TrimSpace(remark) == "" {
		note = fmt.Sprintf("新增余额 $%s by %s", opening.StringFixed(8), operatorName)
	}
	if _, _, applyErr := applyModelAccountBalanceChange(tx, account, opening, note); applyErr != nil {
		tx.Rollback()
		return applyErr
	}
	if commitErr := tx.Commit().Error; commitErr != nil {
		return fmt.Errorf("提交事务失败: %w", commitErr)
	}
	return nil
}
// ListSandboxRecords returns one page (offset/limit) of sb_sandbox_record rows
// plus the total number of matching rows. Empty userID/projectID filters are
// ignored. The created_at BETWEEN range applies only when both start and end
// are set. Each row is converted with normalizeSandboxRow.
//
// NOTE(review): there is no ORDER BY, so page contents are not guaranteed to
// be stable across requests. Consider ordering by a unique column.
func ListSandboxRecords(offset, limit int, userID, projectID, start, end string) (*PageResult, error) {
	if storage.DB == nil {
		return nil, fmt.Errorf("MySQL未初始化")
	}
	q := storage.DB.Table("sb_sandbox_record")
	if userID != "" {
		q = q.Where("user_id = ?", userID)
	}
	if projectID != "" {
		q = q.Where("project_id = ?", projectID)
	}
	if start != "" && end != "" {
		q = q.Where("created_at BETWEEN ? AND ?", start, end)
	}

	var total int64
	if err := q.Count(&total).Error; err != nil {
		return nil, fmt.Errorf("count sandbox records: %w", err)
	}
	rows, err := q.Offset(offset).Limit(limit).Rows()
	if err != nil {
		return nil, fmt.Errorf("query sandbox records: %w", err)
	}
	defer rows.Close()

	cols, err := rows.Columns()
	if err != nil {
		return nil, fmt.Errorf("read sandbox record columns: %w", err)
	}
	// Scan buffers are reused across rows: Scan overwrites every slot (NULL
	// becomes nil), and normalizeSandboxRow builds a fresh map per row.
	vals := make([]interface{}, len(cols))
	ptrs := make([]interface{}, len(cols))
	for i := range vals {
		ptrs[i] = &vals[i]
	}
	list := make([]map[string]interface{}, 0)
	for rows.Next() {
		if err := rows.Scan(ptrs...); err != nil {
			return nil, fmt.Errorf("scan sandbox record: %w", err)
		}
		list = append(list, normalizeSandboxRow(cols, vals))
	}
	// rows.Next returning false may hide a driver/network error.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterate sandbox records: %w", err)
	}
	return &PageResult{List: list, Total: total}, nil
}
// ListTokenUsages returns one page (offset/limit) of gw_token_usages rows plus
// the total number of matching rows. Empty userID/projectID filters are
// ignored. The day BETWEEN range applies only when both startDay and endDay
// are set.
//
// id, hour, cost and the token-count columns are returned as int64 (nil when
// unparseable); every other column is stringified.
//
// NOTE(review): there is no ORDER BY, so pagination order is not guaranteed.
func ListTokenUsages(offset, limit int, userID, projectID, startDay, endDay string) (*PageResult, error) {
	if storage.DB == nil {
		return nil, fmt.Errorf("MySQL未初始化")
	}
	q := storage.DB.Table("gw_token_usages")
	if userID != "" {
		q = q.Where("user_id = ?", userID)
	}
	if projectID != "" {
		q = q.Where("project_id = ?", projectID)
	}
	if startDay != "" && endDay != "" {
		q = q.Where("day BETWEEN ? AND ?", startDay, endDay)
	}

	var total int64
	if err := q.Count(&total).Error; err != nil {
		return nil, fmt.Errorf("count token usages: %w", err)
	}
	rows, err := q.Offset(offset).Limit(limit).Rows()
	if err != nil {
		return nil, fmt.Errorf("query token usages: %w", err)
	}
	defer rows.Close()

	cols, err := rows.Columns()
	if err != nil {
		return nil, fmt.Errorf("read token usage columns: %w", err)
	}
	intCols := []string{"id", "hour", "cost", "prompt_token", "completion_token", "cache_create_token", "cache_read_token"}
	// Scan buffers are reused across rows: Scan overwrites every slot and
	// normalizeGenericInts builds a fresh map per row.
	vals := make([]interface{}, len(cols))
	ptrs := make([]interface{}, len(cols))
	for i := range vals {
		ptrs[i] = &vals[i]
	}
	list := make([]map[string]interface{}, 0)
	for rows.Next() {
		if err := rows.Scan(ptrs...); err != nil {
			return nil, fmt.Errorf("scan token usage: %w", err)
		}
		list = append(list, normalizeGenericInts(cols, vals, intCols, nil))
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterate token usages: %w", err)
	}
	return &PageResult{List: list, Total: total}, nil
}
// ListMCPUsages returns one page (offset/limit) of mcp_invoke_usages rows plus
// the total number of matching rows. Empty userID/projectID filters are
// ignored. The created_at BETWEEN range applies only when both start and end
// are set.
//
// id, hour, cost and call_count are returned as int64. created_at/updated_at
// are passed through toTimeISO. Every other column is stringified.
//
// NOTE(review): there is no ORDER BY, so pagination order is not guaranteed.
func ListMCPUsages(offset, limit int, userID, projectID, start, end string) (*PageResult, error) {
	if storage.DB == nil {
		return nil, fmt.Errorf("MySQL未初始化")
	}
	q := storage.DB.Table("mcp_invoke_usages")
	if userID != "" {
		q = q.Where("user_id = ?", userID)
	}
	if projectID != "" {
		q = q.Where("project_id = ?", projectID)
	}
	if start != "" && end != "" {
		q = q.Where("created_at BETWEEN ? AND ?", start, end)
	}

	var total int64
	if err := q.Count(&total).Error; err != nil {
		return nil, fmt.Errorf("count mcp usages: %w", err)
	}
	rows, err := q.Offset(offset).Limit(limit).Rows()
	if err != nil {
		return nil, fmt.Errorf("query mcp usages: %w", err)
	}
	defer rows.Close()

	cols, err := rows.Columns()
	if err != nil {
		return nil, fmt.Errorf("read mcp usage columns: %w", err)
	}
	intCols := []string{"id", "hour", "cost", "call_count"}
	timeCols := []string{"created_at", "updated_at"}
	// Scan buffers are reused across rows: Scan overwrites every slot and
	// normalizeGenericInts builds a fresh map per row.
	vals := make([]interface{}, len(cols))
	ptrs := make([]interface{}, len(cols))
	for i := range vals {
		ptrs[i] = &vals[i]
	}
	list := make([]map[string]interface{}, 0)
	for rows.Next() {
		if err := rows.Scan(ptrs...); err != nil {
			return nil, fmt.Errorf("scan mcp usage: %w", err)
		}
		list = append(list, normalizeGenericInts(cols, vals, intCols, timeCols))
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterate mcp usages: %w", err)
	}
	return &PageResult{List: list, Total: total}, nil
}
// ListTransactionLogs returns one page (offset/limit) of m_transaction_logs
// rows plus the total number of matching rows. Each non-empty filter
// (userID, orderID, txType → type, status) is an exact match. The created_at
// BETWEEN range applies only when both start and end are set.
//
// id, amount, balance_before and balance_after are returned as int64.
// created_at is passed through toTimeISO. Every other column is stringified.
//
// NOTE(review): there is no ORDER BY, so pagination order is not guaranteed.
func ListTransactionLogs(offset, limit int, userID, orderID, txType, status, start, end string) (*PageResult, error) {
	if storage.DB == nil {
		return nil, fmt.Errorf("MySQL未初始化")
	}
	q := storage.DB.Table("m_transaction_logs")
	if userID != "" {
		q = q.Where("user_id = ?", userID)
	}
	if orderID != "" {
		q = q.Where("order_id = ?", orderID)
	}
	if txType != "" {
		q = q.Where("type = ?", txType)
	}
	if status != "" {
		q = q.Where("status = ?", status)
	}
	if start != "" && end != "" {
		q = q.Where("created_at BETWEEN ? AND ?", start, end)
	}

	var total int64
	if err := q.Count(&total).Error; err != nil {
		return nil, fmt.Errorf("count transaction logs: %w", err)
	}
	rows, err := q.Offset(offset).Limit(limit).Rows()
	if err != nil {
		return nil, fmt.Errorf("query transaction logs: %w", err)
	}
	defer rows.Close()

	cols, err := rows.Columns()
	if err != nil {
		return nil, fmt.Errorf("read transaction log columns: %w", err)
	}
	intCols := []string{"id", "amount", "balance_before", "balance_after"}
	timeCols := []string{"created_at"}
	// Scan buffers are reused across rows: Scan overwrites every slot and
	// normalizeGenericInts builds a fresh map per row.
	vals := make([]interface{}, len(cols))
	ptrs := make([]interface{}, len(cols))
	for i := range vals {
		ptrs[i] = &vals[i]
	}
	list := make([]map[string]interface{}, 0)
	for rows.Next() {
		if err := rows.Scan(ptrs...); err != nil {
			return nil, fmt.Errorf("scan transaction log: %w", err)
		}
		list = append(list, normalizeGenericInts(cols, vals, intCols, timeCols))
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterate transaction logs: %w", err)
	}
	return &PageResult{List: list, Total: total}, nil
}
// ListPaymentRecords returns one page (offset/limit) of recharge payment rows
// from m_payment_records plus the total number of matching rows.
//
// Filters:
//   - userID, orderID, paypalOrderID, status, refundStatus: exact match when
//     non-empty
//   - payerEmail: substring match (LIKE %…%)
//   - start/end: created_at BETWEEN range, applied only when both are set
//
// id, amount and refunded_amount are returned as int64. created_at/updated_at
// are passed through toTimeISO. Every other column is stringified.
//
// NOTE(review): there is no ORDER BY, so pagination order is not guaranteed.
// '%' and '_' inside payerEmail act as LIKE wildcards.
func ListPaymentRecords(offset, limit int, userID, orderID, paypalOrderID, status, refundStatus, payerEmail, start, end string) (*PageResult, error) {
	if storage.DB == nil {
		return nil, fmt.Errorf("MySQL未初始化")
	}
	q := storage.DB.Table("m_payment_records")
	if userID != "" {
		q = q.Where("user_id = ?", userID)
	}
	if orderID != "" {
		q = q.Where("order_id = ?", orderID)
	}
	if paypalOrderID != "" {
		q = q.Where("paypal_order_id = ?", paypalOrderID)
	}
	if status != "" {
		q = q.Where("status = ?", status)
	}
	if refundStatus != "" {
		q = q.Where("refund_status = ?", refundStatus)
	}
	if payerEmail != "" {
		q = q.Where("payer_email LIKE ?", "%"+payerEmail+"%")
	}
	if start != "" && end != "" {
		q = q.Where("created_at BETWEEN ? AND ?", start, end)
	}

	var total int64
	if err := q.Count(&total).Error; err != nil {
		return nil, fmt.Errorf("count payment records: %w", err)
	}
	rows, err := q.Offset(offset).Limit(limit).Rows()
	if err != nil {
		return nil, fmt.Errorf("query payment records: %w", err)
	}
	defer rows.Close()

	cols, err := rows.Columns()
	if err != nil {
		return nil, fmt.Errorf("read payment record columns: %w", err)
	}
	intCols := []string{"id", "amount", "refunded_amount"}
	timeCols := []string{"created_at", "updated_at"}
	// Scan buffers are reused across rows: Scan overwrites every slot and
	// normalizeGenericInts builds a fresh map per row.
	vals := make([]interface{}, len(cols))
	ptrs := make([]interface{}, len(cols))
	for i := range vals {
		ptrs[i] = &vals[i]
	}
	list := make([]map[string]interface{}, 0)
	for rows.Next() {
		if err := rows.Scan(ptrs...); err != nil {
			return nil, fmt.Errorf("scan payment record: %w", err)
		}
		list = append(list, normalizeGenericInts(cols, vals, intCols, timeCols))
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterate payment records: %w", err)
	}
	return &PageResult{List: list, Total: total}, nil
}
// RefundPaymentRecord asks the payment service (POST {Pay.BaseURL}/api/refund)
// to refund a payment. On success, it then deducts the refunded amount from the
// user's quota key "GW:QU_<user_id>" in Redis on a best-effort basis.
//
// The payment is identified by orderID and/or paypalCaptureID. amount (cents)
// is optional: when nil, the request omits it (presumably a full refund;
// confirm with the payment service) and the full recorded payment amount is
// deducted locally.
//
// An error is returned only when the refund request fails or answers with a
// non-200 status. Failures in the local bookkeeping afterwards are ignored on
// purpose, so they never mask a refund that already happened.
//
// NOTE(review): with amount == nil, the full m_payment_records.amount is
// deducted even if part of the order was refunded earlier. Confirm.
func RefundPaymentRecord(orderID, paypalCaptureID string, amount *int64) error {
	cfg := config.GetConfig()
	payBaseURL := cfg.Pay.BaseURL
	timeout := time.Duration(cfg.Pay.Timeout) * time.Second
	if timeout <= 0 {
		// http.Client treats 0 as "no timeout"; never wait forever on the payment service.
		timeout = 30 * time.Second
	}

	reqBody := make(map[string]interface{})
	if orderID != "" {
		reqBody["order_id"] = orderID
	}
	if paypalCaptureID != "" {
		reqBody["paypal_capture_id"] = paypalCaptureID
	}
	if amount != nil {
		reqBody["amount"] = *amount
	}
	jsonData, err := json.Marshal(reqBody)
	if err != nil {
		return fmt.Errorf("序列化请求失败: %w", err)
	}

	req, err := http.NewRequest(http.MethodPost, payBaseURL+"/api/refund", bytes.NewBuffer(jsonData))
	if err != nil {
		return fmt.Errorf("创建请求失败: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	client := &http.Client{Timeout: timeout}
	resp, err := client.Do(req)
	if err != nil {
		return fmt.Errorf("发送退款请求失败: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return fmt.Errorf("读取响应失败: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("退款接口返回错误: %d, 响应: %s", resp.StatusCode, string(body))
	}

	// Everything below is best-effort bookkeeping; the refund already succeeded.

	// The payment record is looked up by order_id; without one we cannot tell
	// whose quota to reduce.
	if orderID == "" {
		return nil
	}
	// Fetch user_id and amount in a single query.
	var rec struct {
		UserID string `json:"user_id"`
		Amount int64  `json:"amount"`
	}
	if err := storage.DB.Table("m_payment_records").Select("user_id, amount").Where("order_id = ?", orderID).Scan(&rec).Error; err != nil {
		return nil // do not block the main flow
	}
	if rec.UserID == "" {
		// No matching record: decrementing "GW:QU_" would hit an unrelated key.
		return nil
	}

	// Refund amount to deduct, in cents.
	refundCents := rec.Amount
	if amount != nil {
		refundCents = *amount
	}
	if refundCents <= 0 {
		return nil
	}
	decrBy := refundCents * 1_000_000 // 1 cent = 1e6 in Redis units

	rclient, err := pkgredis.NewClient(cfg.Redis)
	if err != nil {
		return nil // Redis failures are deliberately ignored
	}
	defer rclient.Close()
	ctx := context.Background()
	key := fmt.Sprintf("GW:QU_%s", rec.UserID)
	// Result deliberately ignored: best-effort deduction.
	_, _ = rclient.Rdb.DecrBy(ctx, key, decrBy).Result()
	return nil
}
// ListMcpAccountRechargeRecords returns one page of MCP account recharge
// records, newest first (created_at DESC), plus the total match count. Each
// record is joined with mcp_providers for its provider and account names.
//
// Filters:
//   - provider, account: case-insensitive substring matches (ILIKE)
//   - start, end: independent inclusive bounds on recharge_date
//
// In each row, amount is a float64 (nil if unparseable), created_at/updated_at
// are passed through toTimeISO, and every other column is stringified.
func ListMcpAccountRechargeRecords(offset, limit int, provider, account, start, end string) (*PageResult, error) {
	db := storage.GetPG()
	if db == nil {
		return nil, fmt.Errorf("PostgreSQL未初始化")
	}
	q := db.Table("mcp_account_recharge_records").
		Select(`
			mcp_account_recharge_records.*,
			mcp_providers.provider,
			mcp_providers.account
		`).
		Joins("LEFT JOIN mcp_providers ON mcp_account_recharge_records.provider_id = mcp_providers.id")
	if provider != "" {
		q = q.Where("mcp_providers.provider ILIKE ?", "%"+provider+"%")
	}
	if account != "" {
		q = q.Where("mcp_providers.account ILIKE ?", "%"+account+"%")
	}
	if start != "" {
		q = q.Where("mcp_account_recharge_records.recharge_date >= ?", start)
	}
	if end != "" {
		q = q.Where("mcp_account_recharge_records.recharge_date <= ?", end)
	}

	var total int64
	if err := q.Count(&total).Error; err != nil {
		return nil, fmt.Errorf("count mcp account recharge records: %w", err)
	}
	rows, err := q.Order("mcp_account_recharge_records.created_at DESC").
		Offset(offset).Limit(limit).Rows()
	if err != nil {
		return nil, fmt.Errorf("query mcp account recharge records: %w", err)
	}
	defer rows.Close()

	cols, err := rows.Columns()
	if err != nil {
		return nil, fmt.Errorf("read mcp account recharge record columns: %w", err)
	}
	// Column classification is the same for every row, so build it once.
	floatCols := map[string]struct{}{"amount": {}}
	timeCols := map[string]struct{}{"created_at": {}, "updated_at": {}}

	// Scan buffers are reused across rows: Scan overwrites every slot and each
	// row gets a fresh map below.
	vals := make([]interface{}, len(cols))
	ptrs := make([]interface{}, len(cols))
	for i := range vals {
		ptrs[i] = &vals[i]
	}
	list := make([]map[string]interface{}, 0)
	for rows.Next() {
		if err := rows.Scan(ptrs...); err != nil {
			return nil, fmt.Errorf("scan mcp account recharge record: %w", err)
		}
		m := make(map[string]interface{}, len(cols))
		for i, c := range cols {
			if _, ok := floatCols[c]; ok {
				if f, ok2 := toFloat(vals[i]); ok2 {
					m[c] = f
				} else {
					m[c] = nil
				}
				continue
			}
			if _, ok := timeCols[c]; ok {
				if iso, ok2 := toTimeISO(vals[i]); ok2 {
					m[c] = iso
				} else {
					m[c] = nil
				}
				continue
			}
			m[c] = toString(vals[i])
		}
		list = append(list, m)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterate mcp account recharge records: %w", err)
	}
	return &PageResult{List: list, Total: total}, nil
}
// CreateMcpAccountRechargeRecord records a manual recharge for an MCP provider
// account and credits the same amount to its balance ledger. Both writes
// happen in one transaction.
//
// Arguments:
//   - providerID: must be a UUID
//   - rechargeDate: must be formatted "2006-01-02"
//   - operatorID: a UUID string or a uuid.UUID. Any other value, including a
//     non-UUID string such as a numeric id, is stored as NULL.
//
// The transaction is rolled back on every failure path, including a panic.
// Panics are not recovered here, so a failed write can never be reported to
// the caller as success.
func CreateMcpAccountRechargeRecord(providerID string, amount float64, rechargeDate string, operatorID interface{}, operatorName, remark string) error {
	db := storage.GetPG()
	if db == nil {
		return fmt.Errorf("PostgreSQL未初始化")
	}
	providerUUID, err := uuid.Parse(providerID)
	if err != nil {
		return fmt.Errorf("无效的provider_id: %w", err)
	}
	date, err := time.Parse("2006-01-02", rechargeDate)
	if err != nil {
		return fmt.Errorf("无效的日期格式: %w", err)
	}

	var operatorUUID *uuid.UUID
	switch v := operatorID.(type) {
	case string:
		// Non-UUID strings (e.g. numeric ids) are stored as NULL.
		if opID, parseErr := uuid.Parse(v); parseErr == nil {
			operatorUUID = &opID
		}
	case uuid.UUID:
		operatorUUID = &v
	}

	tx := db.Begin()
	if tx.Error != nil {
		return fmt.Errorf("开始事务失败: %w", tx.Error)
	}
	committed := false
	// Roll back on every exit that does not commit. A panic still propagates
	// after the rollback instead of being turned into a nil error.
	defer func() {
		if !committed {
			tx.Rollback()
		}
	}()

	// 1. Insert the recharge record.
	record := map[string]interface{}{
		"provider_id":   providerUUID,
		"amount":        amount,
		"currency":      "USD",
		"recharge_date": date,
		"operator_id":   operatorUUID,
		"operator_name": operatorName,
		"remark":        remark,
	}
	if err := tx.Table("mcp_account_recharge_records").Create(record).Error; err != nil {
		return fmt.Errorf("创建充值记录失败: %w", err)
	}

	// 2. Credit the same amount to the balance ledger.
	amountDecimal := decimal.NewFromFloat(amount).Round(8)
	remarkText := fmt.Sprintf("手动充值 $%s by %s", amountDecimal.StringFixed(2), operatorName)
	if _, _, err := applyMcpBalanceChange(tx, providerUUID, amountDecimal, remarkText); err != nil {
		return fmt.Errorf("创建余额记录失败: %w", err)
	}

	if err := tx.Commit().Error; err != nil {
		return fmt.Errorf("提交事务失败: %w", err)
	}
	committed = true
	return nil
}
// UpdateMcpAccountRechargeRecord patches an MCP recharge record by UUID. Only
// the non-nil arguments are written (rechargeDate must be "2006-01-02"), and
// updated_at is always refreshed. It fails when no field was supplied.
//
// NOTE(review): changing amount here does not touch mcp_account_balances, so
// the ledger is not re-synced with the edited recharge. An id that matches no
// row is also not reported as an error. Confirm both are intended.
func UpdateMcpAccountRechargeRecord(id string, amount *float64, rechargeDate *string, remark *string) error {
	db := storage.GetPG()
	if db == nil {
		return fmt.Errorf("PostgreSQL未初始化")
	}
	recordID, err := uuid.Parse(id)
	if err != nil {
		return fmt.Errorf("无效的记录ID: %w", err)
	}

	changes := make(map[string]interface{}, 4)
	if amount != nil {
		changes["amount"] = *amount
	}
	if rechargeDate != nil {
		parsed, parseErr := time.Parse("2006-01-02", *rechargeDate)
		if parseErr != nil {
			return fmt.Errorf("无效的日期格式: %w", parseErr)
		}
		changes["recharge_date"] = parsed
	}
	if remark != nil {
		changes["remark"] = *remark
	}
	if len(changes) == 0 {
		return fmt.Errorf("没有需要更新的字段")
	}
	changes["updated_at"] = time.Now()

	result := db.Table("mcp_account_recharge_records").
		Where("id = ?", recordID).
		Updates(changes)
	return result.Error
}
// DeleteMcpAccountRechargeRecord deletes the MCP recharge record with the given
// UUID.
//
// NOTE(review): the balance credited by the recharge is not reversed in
// mcp_account_balances, and an id matching no row is not an error. Confirm.
func DeleteMcpAccountRechargeRecord(id string) error {
	db := storage.GetPG()
	if db == nil {
		return fmt.Errorf("PostgreSQL未初始化")
	}
	recordID, parseErr := uuid.Parse(id)
	if parseErr != nil {
		return fmt.Errorf("无效的记录ID: %w", parseErr)
	}
	result := db.Table("mcp_account_recharge_records").
		Where("id = ?", recordID).
		Delete(nil)
	return result.Error
}
// GetMcpProviderAccounts lists MCP provider accounts for drop-down selection.
//
// status, when non-empty, filters on mcp_providers.status. isUsed, when
// non-nil, filters on mcp_providers.is_used. Every column value is rendered
// with toString, and rows are ordered by provider then account. The returned
// slice is nil when nothing matches.
func GetMcpProviderAccounts(status string, isUsed *bool) ([]map[string]interface{}, error) {
	db := storage.GetPG()
	if db == nil {
		return nil, fmt.Errorf("PostgreSQL未初始化")
	}
	q := db.Table("mcp_providers").
		Select("id, provider, account, status, is_used")
	if status != "" {
		q = q.Where("status = ?", status)
	}
	if isUsed != nil {
		q = q.Where("is_used = ?", *isUsed)
	}
	rows, err := q.Order("provider, account").Rows()
	if err != nil {
		return nil, fmt.Errorf("query mcp providers: %w", err)
	}
	defer rows.Close()
	cols, err := rows.Columns()
	if err != nil {
		return nil, fmt.Errorf("read mcp provider columns: %w", err)
	}
	var list []map[string]interface{}
	for rows.Next() {
		vals := make([]interface{}, len(cols))
		valPtrs := make([]interface{}, len(cols))
		for i := range vals {
			valPtrs[i] = &vals[i]
		}
		if err := rows.Scan(valPtrs...); err != nil {
			return nil, err
		}
		m := make(map[string]interface{}, len(cols))
		for i, c := range cols {
			m[c] = toString(vals[i])
		}
		list = append(list, m)
	}
	// rows.Next returns false both at end of data and on a read failure;
	// rows.Err distinguishes the two so a truncated list is not reported as success.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterate mcp providers: %w", err)
	}
	return list, nil
}
// GetMcpAccountLatestBalance returns the most recent mcp_account_balances row
// for one provider, joined with the provider/account names from mcp_providers.
//
// providerID must be a UUID string. Column values are returned exactly as
// gorm scanned them (no float/ISO-time normalization). When the provider has
// no balance rows, an empty non-nil map is returned. Database failures are
// returned as errors instead of being masked as "no balance".
func GetMcpAccountLatestBalance(providerID string) (map[string]interface{}, error) {
	db := storage.GetPG()
	if db == nil {
		return nil, fmt.Errorf("PostgreSQL未初始化")
	}
	providerUUID, err := uuid.Parse(providerID)
	if err != nil {
		return nil, fmt.Errorf("无效的provider_id: %w", err)
	}
	var balance map[string]interface{}
	err = db.Table("mcp_account_balances").
		Select(`
			mcp_account_balances.*,
			mcp_providers.provider,
			mcp_providers.account
		`).
		Joins("LEFT JOIN mcp_providers ON mcp_account_balances.provider_id = mcp_providers.id").
		Where("mcp_account_balances.provider_id = ?", providerUUID).
		Order("mcp_account_balances.created_at DESC").
		Limit(1).
		Find(&balance).Error
	if err != nil {
		// Find does not treat "no rows" as an error, so anything here is a
		// real query failure.
		return nil, fmt.Errorf("查询余额失败: %w", err)
	}
	if len(balance) == 0 {
		// No balance row yet: return an empty object rather than nil.
		return map[string]interface{}{}, nil
	}
	return balance, nil
}
// GetMcpAccountBalances returns the latest balance row of every MCP provider
// that has one, joined with the provider/account names from mcp_providers.
//
// Each result map is normalized as follows:
//   - "balance" becomes float64 when toFloat accepts it, otherwise it is kept raw.
//   - "created_at" becomes an ISO string via toTimeISO, falling back to toString.
//   - Every other column goes through toString.
//
// Rows come back ordered by provider_id.
func GetMcpAccountBalances() ([]map[string]interface{}, error) {
	db := storage.GetPG()
	if db == nil {
		return nil, fmt.Errorf("PostgreSQL未初始化")
	}
	// DISTINCT ON keeps only the first row per provider_id under the ORDER BY
	// below (newest created_at first). This replaces one query per provider,
	// avoids round-tripping MAX(created_at) through a Go time.Time, and
	// surfaces errors instead of silently skipping providers.
	var balances []map[string]interface{}
	err := db.Table("mcp_account_balances").
		Select(`
			DISTINCT ON (mcp_account_balances.provider_id)
			mcp_account_balances.*,
			mcp_providers.provider,
			mcp_providers.account
		`).
		Joins("LEFT JOIN mcp_providers ON mcp_account_balances.provider_id = mcp_providers.id").
		Order("mcp_account_balances.provider_id, mcp_account_balances.created_at DESC").
		Find(&balances).Error
	if err != nil {
		return nil, fmt.Errorf("查询余额失败: %w", err)
	}
	result := make([]map[string]interface{}, 0, len(balances))
	for _, bal := range balances {
		m := make(map[string]interface{}, len(bal))
		for k, v := range bal {
			switch k {
			case "balance":
				if f, ok := toFloat(v); ok {
					m[k] = f
				} else {
					m[k] = v
				}
			case "created_at":
				if iso, ok := toTimeISO(v); ok {
					m[k] = iso
				} else {
					m[k] = toString(v)
				}
			default:
				m[k] = toString(v)
			}
		}
		result = append(result, m)
	}
	return result, nil
}
// RunMcpUsageBalanceJob aggregates MCP invocation costs and deducts them from
// each MCP account's balance.
//
// Processing window:
//   - It starts at the last successfully processed window end, or one hour
//     before the current UTC hour boundary on the first run.
//   - It ends at the current UTC hour boundary.
//
// Data flow:
//   - Costs are read from MySQL mcp_invoke_usages, grouped by provider/account.
//   - Balances, the job run and per-account details are written to PostgreSQL.
//
// Each account is settled in its own transaction, so one failing account does
// not block the others. If any account fails, the run is marked failed and an
// error is returned. Returns nil without work when the window is empty, or when
// the window was already processed or is being processed by another run.
func RunMcpUsageBalanceJob() error {
	// Refresh the low-balance alert threshold from system config before running.
	RefreshLowBalanceThreshold()
	mysqlDB := storage.DB
	if mysqlDB == nil {
		return fmt.Errorf("MySQL未初始化")
	}
	pg := storage.GetPG()
	if pg == nil {
		return fmt.Errorf("PostgreSQL未初始化")
	}
	// End of the last successfully processed window (zero value when none).
	lastWindowEnd, err := getLastProcessedWindowEnd(pg, mcpUsageJobName)
	if err != nil {
		return fmt.Errorf("获取上次处理时间失败: %w", err)
	}
	// The window always ends at the most recent whole UTC hour.
	now := time.Now().UTC()
	currentWindowEnd := now.Truncate(time.Hour)
	var windowStart, windowEnd time.Time
	if lastWindowEnd.IsZero() {
		// No history: process only the previous hour.
		windowEnd = currentWindowEnd
		windowStart = windowEnd.Add(-time.Hour)
	} else {
		// Catch up from the last processed point to the current hour boundary.
		windowStart = lastWindowEnd
		windowEnd = currentWindowEnd
	}
	// Empty window (e.g. the last run ended at 10:00 and it is still 10:xx).
	if !windowStart.Before(windowEnd) {
		return nil
	}
	windowStart = windowStart.UTC()
	windowEnd = windowEnd.UTC()
	// Load every mcp_providers row into memory for provider/account lookup.
	providerMap, err := loadMcpProvidersToMemory(pg)
	if err != nil {
		return fmt.Errorf("加载账号信息失败: %w", err)
	}
	runID, err := createMcpUsageJobRun(pg, mcpUsageJobName, windowStart, windowEnd)
	if err != nil {
		// The window is already done or owned by another run: nothing to do.
		if errors.Is(err, ErrMcpUsageJobAlreadyProcessed) || errors.Is(err, ErrMcpUsageJobAlreadyInProgress) {
			return nil
		}
		return fmt.Errorf("创建任务记录失败: %w", err)
	}
	var usageRows []mcpUsageAggregation
	err = mysqlDB.Table("mcp_invoke_usages").
		Select("provider, account, SUM(cost) AS cost_sum").
		Where("updated_at >= ? AND updated_at < ?", windowStart, windowEnd).
		Group("provider, account").
		Having("SUM(cost) <> 0").
		Find(&usageRows).Error
	if err != nil {
		updateErr := updateMcpUsageJobRun(pg, runID, jobStatusFailed, 0, decimal.Zero, err.Error())
		if updateErr != nil {
			return fmt.Errorf("更新任务记录失败: %v; 原始错误: %w", updateErr, err)
		}
		return fmt.Errorf("聚合调用费用失败: %w", err)
	}
	if len(usageRows) == 0 {
		return updateMcpUsageJobRun(pg, runID, jobStatusSuccess, 0, decimal.Zero, "")
	}
	var (
		recordsProcessed int
		totalCost        = decimal.Zero
		hadFailure       bool
		errorMessages    []string
	)
	for _, row := range usageRows {
		provider := strings.TrimSpace(row.Provider)
		account := ""
		if row.Account.Valid {
			account = strings.TrimSpace(row.Account.String)
		}
		// The summed cost is an integer scaled by usdScalingFactor.
		costCents := row.CostSum
		costUSD := decimal.NewFromInt(costCents).Div(usdScalingFactor)
		detailStatus := jobStatusSuccess
		detailError := ""
		var (
			providerEntry mcpProviderCache
			providerID    uuid.UUID
			actualCostUSD         = costUSD
			floatingRatio float64 = 1
		)
		if account == "" {
			detailStatus = jobStatusFailed
			detailError = "account为空"
		} else {
			if providerEntry, err = findMcpProviderFromMemory(providerMap, provider, account); err != nil {
				detailStatus = jobStatusFailed
				detailError = fmt.Sprintf("查找账号失败: %v", err)
			} else {
				providerID = providerEntry.ID
				floatingRatio = providerEntry.FloatingRatio
			}
		}
		if detailStatus == jobStatusSuccess {
			// A non-positive ratio is treated as "no adjustment".
			if floatingRatio <= 0 {
				floatingRatio = 1
			}
			if floatingRatio != 1 {
				actualCostUSD = costUSD.Div(decimal.NewFromFloat(floatingRatio))
			}
			actualCostUSD = actualCostUSD.Round(8)
			if actualCostUSD.LessThanOrEqual(decimal.Zero) {
				detailStatus = jobStatusFailed
				detailError = "结算金额为0"
			}
		}
		var previousBalance decimal.Decimal
		var newBalance decimal.Decimal
		if detailStatus == jobStatusSuccess {
			// Each account is settled in its own transaction.
			tx := pg.Begin()
			if tx.Error != nil {
				detailStatus = jobStatusFailed
				detailError = fmt.Sprintf("开启事务失败: %v", tx.Error)
			} else {
				remark := fmt.Sprintf(
					"自动扣费 %s - %s (原始$%s, 上浮%0.2f, 实际$%s)",
					windowStart.Format("2006-01-02 15:04"),
					windowEnd.Format("15:04"),
					costUSD.StringFixed(8),
					floatingRatio,
					actualCostUSD.StringFixed(8),
				)
				previousBalance, newBalance, err = applyMcpBalanceChange(tx, providerID, actualCostUSD.Neg(), remark)
				if err != nil {
					tx.Rollback()
					detailStatus = jobStatusFailed
					detailError = fmt.Sprintf("扣减余额失败: %v", err)
				} else {
					if commitErr := tx.Commit().Error; commitErr != nil {
						detailStatus = jobStatusFailed
						detailError = fmt.Sprintf("提交事务失败: %v", commitErr)
					}
				}
			}
		}
		if detailStatus == jobStatusSuccess {
			recordsProcessed++
			totalCost = totalCost.Add(actualCostUSD)
			threshold := GetLowBalanceThreshold()
			if newBalance.LessThan(threshold) {
				notifier.NotifyMcpLowBalance(provider, account, newBalance, threshold)
			}
		} else {
			hadFailure = true
			errorMessages = append(errorMessages, fmt.Sprintf("%s/%s: %s", provider, account, detailError))
		}
		detailData := map[string]interface{}{
			"run_id":        runID,
			"provider":      provider,
			"account":       account,
			"cost_cents":    costCents,
			"cost_usd":      actualCostUSD,
			"status":        detailStatus,
			"created_at":    time.Now(),
			"error_message": detailError,
		}
		if providerID != uuid.Nil {
			detailData["provider_id"] = providerID
		}
		if detailStatus == jobStatusSuccess {
			detailData["previous_balance"] = previousBalance
			detailData["new_balance"] = newBalance
		}
		if err := pg.Table("mcp_usage_balance_job_details").Create(detailData).Error; err != nil {
			hadFailure = true
			errorMessages = append(errorMessages, fmt.Sprintf("记录明细失败 %s/%s: %v", provider, account, err))
		}
	}
	finalStatus := jobStatusSuccess
	finalError := ""
	if hadFailure {
		finalStatus = jobStatusFailed
		finalError = strings.Join(errorMessages, "; ")
		if len(finalError) > 1000 {
			// Byte-slicing can cut a multi-byte rune in half (the messages are
			// largely Chinese). Drop any partial trailing rune so the text stays
			// valid UTF-8. Assumes a UTF-8 database encoding, which rejects
			// invalid byte sequences.
			finalError = strings.ToValidUTF8(finalError[:1000], "")
		}
	}
	if err := updateMcpUsageJobRun(pg, runID, finalStatus, recordsProcessed, totalCost, finalError); err != nil {
		return fmt.Errorf("更新任务记录失败: %w", err)
	}
	if hadFailure {
		return fmt.Errorf("部分账户结算失败: %s", finalError)
	}
	return nil
}
// RunModelTokenBalanceJob aggregates model token spend and deducts it from each
// model account's balance.
//
// Processing window:
//   - It starts at the last processed window end, or one hour before the
//     current UTC hour boundary on the first run.
//   - It ends at the current UTC hour boundary.
//
// Data flow:
//   - Spend is read from gw_token_usages (bucketed by day/hour), grouped by
//     provider/account/model.
//   - The matching model config's PriceRatio divides the raw cost.
//   - Job runs, balances and per-row details are all stored in MySQL.
//
// Each row is settled in its own transaction. If any row fails, the run is
// marked failed and an error is returned. Returns nil without work when the
// window is empty, or when the window was already processed or is being
// processed by another run.
func RunModelTokenBalanceJob() error {
	// Refresh the low-balance alert threshold from system config before running.
	RefreshLowBalanceThreshold()
	mysqlDB := storage.DB
	if mysqlDB == nil {
		return fmt.Errorf("MySQL未初始化")
	}
	lastWindowEnd, err := getLastProcessedModelTokenWindowEnd(mysqlDB, modelTokenJobName)
	if err != nil {
		return fmt.Errorf("获取上次处理时间失败: %w", err)
	}
	now := time.Now().UTC()
	currentWindowEnd := now.Truncate(time.Hour)
	var windowStart, windowEnd time.Time
	if lastWindowEnd.IsZero() {
		windowEnd = currentWindowEnd
		windowStart = windowEnd.Add(-time.Hour)
	} else {
		windowStart = lastWindowEnd
		windowEnd = currentWindowEnd
	}
	if !windowStart.Before(windowEnd) {
		return nil
	}
	windowStart = windowStart.UTC()
	windowEnd = windowEnd.UTC()
	runID, err := createModelTokenJobRun(mysqlDB, modelTokenJobName, windowStart, windowEnd)
	if err != nil {
		if errors.Is(err, ErrModelTokenJobAlreadyProcessed) || errors.Is(err, ErrModelTokenJobAlreadyInProgress) {
			return nil
		}
		return fmt.Errorf("创建任务记录失败: %w", err)
	}
	configMap, err := loadModelConfigsToMemory(mysqlDB)
	if err != nil {
		updateErr := updateModelTokenJobRun(mysqlDB, runID, jobStatusFailed, 0, decimal.Zero, err.Error())
		if updateErr != nil {
			return fmt.Errorf("加载模型配置失败: %v; 原始错误: %w", updateErr, err)
		}
		return fmt.Errorf("加载模型配置失败: %w", err)
	}
	// gw_token_usages is bucketed by (day, hour). Translate [windowStart,
	// windowEnd) into an inclusive lower and exclusive upper bucket bound.
	startDay := windowStart.Format("2006-01-02")
	startHour := windowStart.Hour()
	endDay := windowEnd.Format("2006-01-02")
	endHour := windowEnd.Hour()
	query := mysqlDB.Table("gw_token_usages").
		Select("provider, account, model, SUM(prompt_cost + completion_cost + cache_create_cost + cache_read_cost) AS total_cost").
		Where("(day > ?) OR (day = ? AND hour >= ?)", startDay, startDay, startHour).
		Where("(day < ?) OR (day = ? AND hour < ?)", endDay, endDay, endHour).
		Group("provider, account, model").
		Having("SUM(prompt_cost + completion_cost + cache_create_cost + cache_read_cost) <> 0")
	rows, err := query.Rows()
	if err != nil {
		updateErr := updateModelTokenJobRun(mysqlDB, runID, jobStatusFailed, 0, decimal.Zero, err.Error())
		if updateErr != nil {
			return fmt.Errorf("query error update failed: %v; 原始错误: %w", updateErr, err)
		}
		return fmt.Errorf("聚合Token消费失败: %w", err)
	}
	defer rows.Close()

	// Drain the whole aggregation before touching any balance. This releases
	// the streaming connection before the per-row transactions below borrow
	// connections from the same pool. It also ensures that a scan or read error
	// mid-stream fails the run before any balance is deducted, instead of
	// leaving part of the window settled.
	type tokenUsageRow struct {
		provider  string
		account   sql.NullString
		model     string
		totalCost sql.NullInt64
	}
	var usageRows []tokenUsageRow
	for rows.Next() {
		var r tokenUsageRow
		if scanErr := rows.Scan(&r.provider, &r.account, &r.model, &r.totalCost); scanErr != nil {
			errorMsg := fmt.Sprintf("扫描聚合结果失败: %v", scanErr)
			updateErr := updateModelTokenJobRun(mysqlDB, runID, jobStatusFailed, 0, decimal.Zero, errorMsg)
			if updateErr != nil {
				return fmt.Errorf("更新任务记录失败: %v; 原始错误: %w", updateErr, scanErr)
			}
			return fmt.Errorf("扫描聚合结果失败: %w", scanErr)
		}
		usageRows = append(usageRows, r)
	}
	if err := rows.Err(); err != nil {
		updateErr := updateModelTokenJobRun(mysqlDB, runID, jobStatusFailed, 0, decimal.Zero, err.Error())
		if updateErr != nil {
			return fmt.Errorf("更新任务记录失败: %v; 原始错误: %w", updateErr, err)
		}
		return fmt.Errorf("遍历聚合结果失败: %w", err)
	}
	// Release the result set now; Close is idempotent, so the deferred call is harmless.
	rows.Close()
	if len(usageRows) == 0 {
		return updateModelTokenJobRun(mysqlDB, runID, jobStatusSuccess, 0, decimal.Zero, "")
	}

	var (
		recordsProcessed int
		totalCostUSD     = decimal.Zero
		hadFailure       bool
		errorMessages    []string
	)
	for _, usage := range usageRows {
		if !usage.totalCost.Valid {
			continue
		}
		costCents := usage.totalCost.Int64
		if costCents == 0 {
			continue
		}
		provider := strings.TrimSpace(usage.provider)
		model := strings.TrimSpace(usage.model)
		accountStr := ""
		if usage.account.Valid {
			accountStr = strings.TrimSpace(usage.account.String)
		}
		// The summed cost is an integer scaled by usdScalingFactor.
		costUSD := decimal.NewFromInt(costCents).Div(usdScalingFactor)
		detailStatus := jobStatusSuccess
		detailError := ""
		var (
			modelConfigID   uint64
			previousBalance decimal.Decimal
			newBalance      decimal.Decimal
		)
		configInfo, exists := configMap[makeModelConfigKey(provider, model)]
		if !exists {
			detailStatus = jobStatusFailed
			detailError = "未找到匹配的模型配置"
		} else {
			modelConfigID = configInfo.ID
		}
		// A non-positive ratio is treated as "no adjustment".
		priceRatio := configInfo.PriceRatio
		if priceRatio <= 0 {
			priceRatio = 1
		}
		actualCostUSD := costUSD
		if detailStatus == jobStatusSuccess {
			if priceRatio != 1 {
				actualCostUSD = costUSD.Div(decimal.NewFromFloat(priceRatio))
			}
			actualCostUSD = actualCostUSD.Round(8)
			if actualCostUSD.LessThanOrEqual(decimal.Zero) {
				detailStatus = jobStatusFailed
				detailError = "结算金额为0"
			}
		}
		if detailStatus == jobStatusSuccess {
			// Each row is settled in its own transaction.
			tx := mysqlDB.Begin()
			if tx.Error != nil {
				detailStatus = jobStatusFailed
				detailError = fmt.Sprintf("开启事务失败: %v", tx.Error)
			} else {
				remark := fmt.Sprintf(
					"自动扣费 %s - %s (原始$%s, 上浮%0.2f, 实际$%s)",
					windowStart.Format("2006-01-02 15:04"),
					windowEnd.Format("15:04"),
					costUSD.StringFixed(8),
					priceRatio,
					actualCostUSD.StringFixed(8),
				)
				previousBalance, newBalance, err = applyModelAccountBalanceChange(tx, accountStr, actualCostUSD.Neg(), remark)
				if err != nil {
					tx.Rollback()
					detailStatus = jobStatusFailed
					detailError = fmt.Sprintf("扣减余额失败: %v", err)
				} else if commitErr := tx.Commit().Error; commitErr != nil {
					detailStatus = jobStatusFailed
					detailError = fmt.Sprintf("提交事务失败: %v", commitErr)
				}
			}
		}
		if detailStatus == jobStatusSuccess {
			recordsProcessed++
			totalCostUSD = totalCostUSD.Add(actualCostUSD)
			threshold := GetLowBalanceThreshold()
			if newBalance.LessThan(threshold) {
				notifier.NotifyModelLowBalance(provider, accountStr, model, newBalance, threshold)
			}
		} else {
			hadFailure = true
			errorMessages = append(errorMessages, fmt.Sprintf("%s/%s/%s: %s", provider, accountStr, model, detailError))
		}
		detailData := map[string]interface{}{
			"run_id":           runID,
			"provider":         provider,
			"account":          accountStr,
			"model":            model,
			"total_cost_cents": costCents,
			"total_cost_usd":   actualCostUSD,
			"status":           detailStatus,
			"created_at":       time.Now(),
			"error_message":    detailError,
		}
		if modelConfigID != 0 {
			detailData["model_config_id"] = modelConfigID
		}
		if detailStatus == jobStatusSuccess {
			detailData["previous_balance"] = previousBalance
			detailData["new_balance"] = newBalance
		}
		if detailErr := mysqlDB.Table("model_token_balance_job_details").Create(detailData).Error; detailErr != nil {
			hadFailure = true
			errorMessages = append(errorMessages, fmt.Sprintf("记录明细失败 %s/%s/%s: %v", provider, accountStr, model, detailErr))
		}
	}
	finalStatus := jobStatusSuccess
	finalError := ""
	if hadFailure {
		finalStatus = jobStatusFailed
		finalError = strings.Join(errorMessages, "; ")
		if len(finalError) > 1000 {
			// Byte-slicing can cut a multi-byte rune in half (the messages are
			// largely Chinese). Drop any partial trailing rune so a utf8/utf8mb4
			// column in strict mode does not reject the update.
			finalError = strings.ToValidUTF8(finalError[:1000], "")
		}
	}
	if err := updateModelTokenJobRun(mysqlDB, runID, finalStatus, recordsProcessed, totalCostUSD, finalError); err != nil {
		return fmt.Errorf("更新任务记录失败: %w", err)
	}
	if hadFailure {
		return fmt.Errorf("部分模型结算失败: %s", finalError)
	}
	return nil
}
// GetMcpAccountBalanceHistory returns the balance history of one MCP provider,
// newest first, joined with the provider/account names from mcp_providers.
//
// providerID must be a UUID string. start and end, when non-empty, bound
// created_at inclusively and are passed to the database unchanged.
//
// Each row is normalized as follows:
//   - "balance" becomes float64, or nil when it does not convert.
//   - "created_at" becomes an ISO string via toTimeISO, falling back to toString.
//   - Every other column goes through toString.
func GetMcpAccountBalanceHistory(providerID string, start, end string) ([]map[string]interface{}, error) {
	db := storage.GetPG()
	if db == nil {
		return nil, fmt.Errorf("PostgreSQL未初始化")
	}
	providerUUID, err := uuid.Parse(providerID)
	if err != nil {
		return nil, fmt.Errorf("无效的provider_id: %w", err)
	}
	q := db.Table("mcp_account_balances").
		Select(`
			mcp_account_balances.*,
			mcp_providers.provider,
			mcp_providers.account
		`).
		Joins("LEFT JOIN mcp_providers ON mcp_account_balances.provider_id = mcp_providers.id").
		Where("mcp_account_balances.provider_id = ?", providerUUID)
	if start != "" {
		q = q.Where("mcp_account_balances.created_at >= ?", start)
	}
	if end != "" {
		q = q.Where("mcp_account_balances.created_at <= ?", end)
	}
	rows, err := q.Order("mcp_account_balances.created_at DESC").Rows()
	if err != nil {
		return nil, fmt.Errorf("查询余额历史失败: %w", err)
	}
	defer rows.Close()
	cols, err := rows.Columns()
	if err != nil {
		return nil, fmt.Errorf("读取余额历史列失败: %w", err)
	}
	list := make([]map[string]interface{}, 0)
	for rows.Next() {
		vals := make([]interface{}, len(cols))
		valPtrs := make([]interface{}, len(cols))
		for i := range vals {
			valPtrs[i] = &vals[i]
		}
		if err := rows.Scan(valPtrs...); err != nil {
			return nil, err
		}
		m := make(map[string]interface{}, len(cols))
		for i, c := range cols {
			switch c {
			case "balance":
				if f, ok := toFloat(vals[i]); ok {
					m[c] = f
				} else {
					m[c] = nil
				}
			case "created_at":
				if iso, ok := toTimeISO(vals[i]); ok {
					m[c] = iso
				} else {
					m[c] = toString(vals[i])
				}
			default:
				m[c] = toString(vals[i])
			}
		}
		list = append(list, m)
	}
	// Distinguish a clean end of data from a read failure mid-iteration.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("遍历余额历史失败: %w", err)
	}
	return list, nil
}
// ========== 模型账号充值记录和余额管理 ==========
// applyModelAccountBalanceChange appends a new gw_model_account_balances row
// whose balance is the account's latest balance plus delta.
//
// The table is used as an append-only ledger: the newest row holds the current
// balance, and an account with no rows starts from zero. tx is expected to be
// an open transaction; the caller commits or rolls back. Returns the previous
// balance and the new balance. No row locking is done here, so concurrent
// callers on the same account are not serialized by this function.
func applyModelAccountBalanceChange(tx *gorm.DB, account string, delta decimal.Decimal, remark string) (decimal.Decimal, decimal.Decimal, error) {
	var latest struct {
		Balance decimal.Decimal `gorm:"column:balance"`
	}
	prev := decimal.Zero
	// created_at alone is not unique. RunModelTokenBalanceJob settles every
	// model of the same account back-to-back, so several ledger rows can share a
	// created_at value (especially at DATETIME second precision). Break ties on
	// id so the most recently inserted row is the base. This assumes id is
	// monotonically increasing (e.g. AUTO_INCREMENT); confirm against the schema.
	err := tx.Table("gw_model_account_balances").
		Select("balance").
		Where("account = ?", account).
		Order("created_at DESC, id DESC").
		Take(&latest).Error
	if err != nil {
		if !errors.Is(err, gorm.ErrRecordNotFound) {
			return decimal.Zero, decimal.Zero, err
		}
	} else {
		prev = latest.Balance
	}
	newBalance := prev.Add(delta)
	record := map[string]interface{}{
		"account":    account,
		"balance":    newBalance,
		"currency":   "USD",
		"remark":     remark,
		"created_at": time.Now(),
	}
	if err := tx.Table("gw_model_account_balances").Create(record).Error; err != nil {
		return decimal.Zero, decimal.Zero, err
	}
	return prev, newBalance, nil
}
// ListModelAccountRechargeRecords pages through non-deleted model account
// recharge records, newest first. The provider name is joined from
// gw_providers by account.
//
// Filters:
//   - provider filters on gw_providers.name with a substring LIKE.
//   - start and end bound recharge_date inclusively and are passed through
//     unchanged.
//   - modelName is accepted for API compatibility but is currently not used as
//     a filter.
//
// NOTE(review): if several gw_providers rows share one account, the LEFT JOIN
// yields one row per match, duplicating records and inflating Total. Confirm
// whether account is unique in gw_providers.
func ListModelAccountRechargeRecords(offset, limit int, provider, modelName, start, end string) (*PageResult, error) {
	db := storage.DB
	if db == nil {
		return nil, fmt.Errorf("MySQL未初始化")
	}
	var total int64
	q := db.Table("gw_model_account_recharge_records r").
		Select(`
			r.*,
			gp.name AS provider
		`).
		Joins("LEFT JOIN gw_providers gp ON r.account = gp.account").
		Where("r.deleted_at IS NULL")
	if provider != "" {
		q = q.Where("gp.name LIKE ?", "%"+provider+"%")
	}
	if start != "" {
		q = q.Where("r.recharge_date >= ?", start)
	}
	if end != "" {
		q = q.Where("r.recharge_date <= ?", end)
	}
	if err := q.Count(&total).Error; err != nil {
		return nil, fmt.Errorf("count model account recharge records: %w", err)
	}
	// r.id breaks ties between records sharing created_at, so OFFSET/LIMIT
	// pages are stable across requests.
	rows, err := q.Order("r.created_at DESC, r.id DESC").
		Offset(offset).Limit(limit).Rows()
	if err != nil {
		return nil, fmt.Errorf("query model account recharge records: %w", err)
	}
	defer rows.Close()
	cols, err := rows.Columns()
	if err != nil {
		return nil, fmt.Errorf("read model account recharge record columns: %w", err)
	}
	list := make([]map[string]interface{}, 0)
	intCols := []string{"id", "operator_id"}
	timeCols := []string{"created_at", "updated_at"}
	for rows.Next() {
		vals := make([]interface{}, len(cols))
		valPtrs := make([]interface{}, len(cols))
		for i := range vals {
			valPtrs[i] = &vals[i]
		}
		if err := rows.Scan(valPtrs...); err != nil {
			return nil, err
		}
		m := normalizeGenericInts(cols, vals, intCols, timeCols)
		// normalizeGenericInts does not handle floats: expose amount as
		// float64 when it converts.
		for i, c := range cols {
			if c != "amount" {
				continue
			}
			if f, ok := toFloat(vals[i]); ok {
				m[c] = f
			}
		}
		list = append(list, m)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterate model account recharge records: %w", err)
	}
	return &PageResult{List: list, Total: total}, nil
}
// CreateModelAccountRechargeRecord records a manual top-up for a model account
// and appends the matching credit to its balance ledger in one transaction.
//
// Inputs:
//   - account must be non-empty after trimming and must appear in gw_providers.
//   - rechargeDate uses the YYYY-MM-DD layout.
//   - operatorID may be a base-10 string, uint, uint64, int or int64. Any other
//     type, an unparsable string, or a negative integer is stored as a NULL
//     operator_id.
func CreateModelAccountRechargeRecord(account string, amount float64, rechargeDate string, operatorID interface{}, operatorName, remark string) error {
	db := storage.DB
	if db == nil {
		return fmt.Errorf("MySQL未初始化")
	}
	account = strings.TrimSpace(account)
	if account == "" {
		return fmt.Errorf("账号不能为空")
	}
	// Make sure the account exists.
	var count int64
	if err := db.Table("gw_providers").Where("account = ?", account).Count(&count).Error; err != nil {
		return fmt.Errorf("验证账号失败: %w", err)
	}
	if count == 0 {
		return fmt.Errorf("账号不存在: %s", account)
	}
	date, err := time.Parse("2006-01-02", rechargeDate)
	if err != nil {
		return fmt.Errorf("无效的日期格式: %w", err)
	}
	// Normalize operatorID to a nullable unsigned id.
	var operatorIDUint *uint64
	switch v := operatorID.(type) {
	case string:
		if parsed, parseErr := strconv.ParseUint(v, 10, 64); parseErr == nil {
			operatorIDUint = &parsed
		}
	case uint64:
		operatorIDUint = &v
	case uint:
		u := uint64(v)
		operatorIDUint = &u
	case int64:
		// Negative ids would wrap to huge uint64 values; treat them as unknown.
		if v >= 0 {
			u := uint64(v)
			operatorIDUint = &u
		}
	case int:
		if v >= 0 {
			u := uint64(v)
			operatorIDUint = &u
		}
	}
	tx := db.Begin()
	if tx.Error != nil {
		return fmt.Errorf("开始事务失败: %w", tx.Error)
	}
	defer func() {
		if r := recover(); r != nil {
			tx.Rollback()
			// Re-raise: recovering without re-panicking would make this function
			// return a nil error even though nothing was committed.
			panic(r)
		}
	}()
	// 1. Insert the recharge record.
	now := time.Now()
	record := map[string]interface{}{
		"account":       account,
		"amount":        amount,
		"currency":      "USD",
		"recharge_date": date,
		"operator_id":   operatorIDUint,
		"operator_name": operatorName,
		"remark":        remark,
		"created_at":    now,
		"updated_at":    now,
	}
	if err := tx.Table("gw_model_account_recharge_records").Create(record).Error; err != nil {
		tx.Rollback()
		return fmt.Errorf("创建充值记录失败: %w", err)
	}
	// 2. Append the credit to the balance ledger.
	amountDecimal := decimal.NewFromFloat(amount).Round(8)
	remarkText := fmt.Sprintf("手动充值 $%s by %s", amountDecimal.StringFixed(2), operatorName)
	_, _, err = applyModelAccountBalanceChange(tx, account, amountDecimal, remarkText)
	if err != nil {
		tx.Rollback()
		return fmt.Errorf("创建余额记录失败: %w", err)
	}
	if err := tx.Commit().Error; err != nil {
		return fmt.Errorf("提交事务失败: %w", err)
	}
	return nil
}
// UpdateModelAccountRechargeRecord patches one non-deleted model account
// recharge record.
//
// id must be a base-10 unsigned integer. Only the non-nil arguments are
// written; when at least one is present, updated_at is refreshed as well.
// rechargeDate uses the YYYY-MM-DD layout. Soft-deleted records
// (deleted_at set) are not matched. The balance ledger is not adjusted.
func UpdateModelAccountRechargeRecord(id string, amount *float64, rechargeDate *string, remark *string) error {
	mysqlDB := storage.DB
	if mysqlDB == nil {
		return fmt.Errorf("MySQL未初始化")
	}
	numericID, parseErr := strconv.ParseUint(id, 10, 64)
	if parseErr != nil {
		return fmt.Errorf("无效的记录ID: %w", parseErr)
	}
	fields := make(map[string]interface{}, 4)
	if amount != nil {
		fields["amount"] = *amount
	}
	if rechargeDate != nil {
		parsedDate, dateErr := time.Parse("2006-01-02", *rechargeDate)
		if dateErr != nil {
			return fmt.Errorf("无效的日期格式: %w", dateErr)
		}
		fields["recharge_date"] = parsedDate
	}
	if remark != nil {
		fields["remark"] = *remark
	}
	if len(fields) == 0 {
		return fmt.Errorf("没有需要更新的字段")
	}
	fields["updated_at"] = time.Now()
	result := mysqlDB.Table("gw_model_account_recharge_records").
		Where("id = ? AND deleted_at IS NULL", numericID).
		Updates(fields)
	return result.Error
}
// DeleteModelAccountRechargeRecord soft-deletes a model account recharge record
// by setting deleted_at to the current time.
//
// id must be a base-10 unsigned integer. Records that are already soft-deleted
// are left untouched, so their original deletion time is preserved; this
// matches the deleted_at IS NULL guard used by UpdateModelAccountRechargeRecord.
// No compensating balance entry is written.
func DeleteModelAccountRechargeRecord(id string) error {
	db := storage.DB
	if db == nil {
		return fmt.Errorf("MySQL未初始化")
	}
	recordID, err := strconv.ParseUint(id, 10, 64)
	if err != nil {
		return fmt.Errorf("无效的记录ID: %w", err)
	}
	return db.Table("gw_model_account_recharge_records").
		Where("id = ? AND deleted_at IS NULL", recordID).
		Update("deleted_at", time.Now()).Error
}
// GetModelConfigAccounts lists non-deleted gw_providers rows (name, api_type,
// account, status, priority) for drop-down selection, ordered by name then
// account.
//
// enabled, when non-nil, keeps only status = 'active' (true) or every other
// status (false). priority is converted with toInt (nil when that fails); all
// other columns go through toString. Returns a nil slice when nothing matches.
func GetModelConfigAccounts(enabled *bool) ([]map[string]interface{}, error) {
	db := storage.DB
	if db == nil {
		return nil, fmt.Errorf("MySQL未初始化")
	}
	q := db.Table("gw_providers").
		Select("name, api_type, account, status, priority").
		Where("deleted_at IS NULL")
	if enabled != nil {
		if *enabled {
			q = q.Where("status = ?", "active")
		} else {
			q = q.Where("status <> ?", "active")
		}
	}
	rows, err := q.Order("name, account").Rows()
	if err != nil {
		return nil, fmt.Errorf("query model accounts: %w", err)
	}
	defer rows.Close()
	cols, err := rows.Columns()
	if err != nil {
		return nil, fmt.Errorf("read model account columns: %w", err)
	}
	var accounts []map[string]interface{}
	for rows.Next() {
		vals := make([]interface{}, len(cols))
		valPtrs := make([]interface{}, len(cols))
		for i := range vals {
			valPtrs[i] = &vals[i]
		}
		if err := rows.Scan(valPtrs...); err != nil {
			return nil, err
		}
		m := make(map[string]interface{}, len(cols))
		for i, c := range cols {
			switch strings.ToLower(c) {
			case "priority":
				if val, ok := toInt(vals[i]); ok {
					m[c] = val
				} else {
					m[c] = nil
				}
			default:
				m[c] = toString(vals[i])
			}
		}
		accounts = append(accounts, m)
	}
	// Distinguish a clean end of data from a read failure mid-iteration.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterate model accounts: %w", err)
	}
	return accounts, nil
}
// GetModelAccountLatestBalance returns the newest gw_model_account_balances row
// for account, joined with the provider name and api_type from gw_providers.
//
// account is trimmed and must be non-empty. The result is normalized:
//   - "balance" becomes float64 when toFloat accepts it, otherwise it is kept raw.
//   - "created_at" becomes an ISO string via toTimeISO, falling back to toString.
//   - Every other column goes through toString.
//
// An account without balance rows yields an empty map.
func GetModelAccountLatestBalance(account string) (map[string]interface{}, error) {
	db := storage.DB
	if db == nil {
		return nil, fmt.Errorf("MySQL未初始化")
	}
	account = strings.TrimSpace(account)
	if account == "" {
		return nil, fmt.Errorf("账号不能为空")
	}
	var balance map[string]interface{}
	// b.id breaks ties between ledger rows sharing created_at (several models
	// of one account are settled back-to-back), so the latest insert wins.
	// This assumes id is monotonically increasing; confirm against the schema.
	err := db.Table("gw_model_account_balances b").
		Select(`
			b.*,
			gp.name AS provider,
			gp.api_type
		`).
		Joins("LEFT JOIN gw_providers gp ON b.account = gp.account").
		Where("b.account = ?", account).
		Order("b.created_at DESC, b.id DESC").
		Limit(1).
		Find(&balance).Error
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return map[string]interface{}{}, nil
		}
		return nil, err
	}
	result := make(map[string]interface{}, len(balance))
	for k, v := range balance {
		switch k {
		case "balance":
			if f, ok := toFloat(v); ok {
				result[k] = f
			} else {
				result[k] = v
			}
		case "created_at":
			if iso, ok := toTimeISO(v); ok {
				result[k] = iso
			} else {
				result[k] = toString(v)
			}
		default:
			result[k] = toString(v)
		}
	}
	return result, nil
}
// GetModelAccountBalances returns the latest gw_model_account_balances row of
// every account, joined with the provider name and api_type from gw_providers.
//
// Each result map is normalized as follows:
//   - "balance" becomes float64 when toFloat accepts it, otherwise it is kept raw.
//   - "created_at" becomes an ISO string via toTimeISO, falling back to toString.
//   - Every other column goes through toString.
//
// Rows come back ordered by account.
func GetModelAccountBalances() ([]map[string]interface{}, error) {
	db := storage.DB
	if db == nil {
		return nil, fmt.Errorf("MySQL未初始化")
	}
	// A single query replaces one query per account. Each account's rows are
	// joined at that account's latest created_at. Several rows can survive the
	// join (ledger rows written in the same instant, or several gw_providers
	// rows sharing the account). Rows are therefore ordered by account, then
	// id DESC, and only the first row per account is kept below: the most
	// recently inserted ledger row. This assumes id is monotonically
	// increasing; confirm against the schema.
	var joined []map[string]interface{}
	err := db.Table("gw_model_account_balances b").
		Select(`
			b.*,
			gp.name AS provider,
			gp.api_type
		`).
		Joins(`JOIN (
			SELECT account, MAX(created_at) AS max_created
			FROM gw_model_account_balances
			GROUP BY account
		) latest ON b.account = latest.account AND b.created_at = latest.max_created`).
		Joins("LEFT JOIN gw_providers gp ON b.account = gp.account").
		Order("b.account, b.id DESC").
		Find(&joined).Error
	if err != nil {
		return nil, fmt.Errorf("查询余额失败: %w", err)
	}
	result := make([]map[string]interface{}, 0, len(joined))
	seen := make(map[string]struct{}, len(joined))
	for _, bal := range joined {
		key := toString(bal["account"])
		if _, dup := seen[key]; dup {
			continue
		}
		seen[key] = struct{}{}
		m := make(map[string]interface{}, len(bal))
		for k, v := range bal {
			switch k {
			case "balance":
				if f, ok := toFloat(v); ok {
					m[k] = f
				} else {
					m[k] = v
				}
			case "created_at":
				if iso, ok := toTimeISO(v); ok {
					m[k] = iso
				} else {
					m[k] = toString(v)
				}
			default:
				m[k] = toString(v)
			}
		}
		result = append(result, m)
	}
	return result, nil
}
// GetModelAccountBalanceHistory returns the gw_model_account_balances rows of
// one account, newest first.
//
// start and end, when non-empty, bound created_at inclusively and are passed
// to the database unchanged. "id" and "created_at" are normalized by
// normalizeGenericInts; "balance" becomes float64 when toFloat accepts it.
func GetModelAccountBalanceHistory(account string, start, end string) ([]map[string]interface{}, error) {
	db := storage.DB
	if db == nil {
		return nil, fmt.Errorf("MySQL未初始化")
	}
	q := db.Table("gw_model_account_balances").
		Where("account = ?", account)
	if start != "" {
		q = q.Where("created_at >= ?", start)
	}
	if end != "" {
		q = q.Where("created_at <= ?", end)
	}
	// id gives a deterministic order among ledger rows sharing created_at.
	rows, err := q.Order("created_at DESC, id DESC").Rows()
	if err != nil {
		return nil, fmt.Errorf("query balance history: %w", err)
	}
	defer rows.Close()
	cols, err := rows.Columns()
	if err != nil {
		return nil, fmt.Errorf("read balance history columns: %w", err)
	}
	list := make([]map[string]interface{}, 0)
	intCols := []string{"id"}
	timeCols := []string{"created_at"}
	for rows.Next() {
		vals := make([]interface{}, len(cols))
		valPtrs := make([]interface{}, len(cols))
		for i := range vals {
			valPtrs[i] = &vals[i]
		}
		if err := rows.Scan(valPtrs...); err != nil {
			return nil, err
		}
		m := normalizeGenericInts(cols, vals, intCols, timeCols)
		// normalizeGenericInts does not handle floats: expose balance as
		// float64 when it converts.
		for i, c := range cols {
			if c != "balance" {
				continue
			}
			if f, ok := toFloat(vals[i]); ok {
				m[c] = f
			}
		}
		list = append(list, m)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterate balance history: %w", err)
	}
	return list, nil
}