package handlers import ( "goalfymax-admin/pkg/utils" "strconv" "github.com/gin-gonic/gin" "gorm.io/gorm" ) // VendorModelPricingHandler 供应商模型价格配置处理器 type VendorModelPricingHandler struct { db *gorm.DB response *utils.Response } // NewVendorModelPricingHandler 创建处理器 func NewVendorModelPricingHandler(db *gorm.DB) *VendorModelPricingHandler { return &VendorModelPricingHandler{ db: db, response: utils.NewResponse(), } } // VendorModelPricingResponse 供应商模型价格配置响应 type VendorModelPricingResponse struct { ID uint `json:"id"` Provider string `json:"provider"` Account string `json:"account"` ModelName string `json:"model_name"` InputPrice float64 `json:"input_price"` OutputPrice float64 `json:"output_price"` CacheReadPrice float64 `json:"cache_read_price"` CacheCreatePrice float64 `json:"cache_create_price"` PriceRatio float64 `json:"price_ratio"` Enabled bool `json:"enabled"` CreatedAt string `json:"created_at"` UpdatedAt string `json:"updated_at"` } // PriceUpdateRequest 价格更新请求 type PriceUpdateRequest struct { InputPrice *float64 `json:"input_price"` OutputPrice *float64 `json:"output_price"` CacheReadPrice *float64 `json:"cache_read_price"` CacheCreatePrice *float64 `json:"cache_create_price"` Enabled *bool `json:"enabled"` PriceRatio *float64 `json:"price_ratio"` } // GetVendorModelPricing 获取供应商模型价格配置列表 func (h *VendorModelPricingHandler) GetVendorModelPricing(c *gin.Context) { var models []VendorModelPricingResponse // 构建查询条件 // 使用 v2 表:gw_model_config_v2,直接返回数据,无需联查 // 为兼容前端字段类型,account 字段返回空串 query := h.db.Table("gw_model_config_v2 mc"). Select("mc.id, mc.provider, '' as account, mc.model_name, " + "mc.prompt_price as input_price, " + "mc.output_price as output_price, " + "mc.cache_read_price, mc.cache_create_price, " + "mc.price_ratio, " + "mc.enabled, mc.created_at, mc.updated_at") // 添加筛选条件 if provider := c.Query("provider"); provider != "" { query = query.Where("mc.provider = ?", provider) } if model := c.Query("model"); model != "" { query = query.Where("mc.model_name LIKE ?", "%"+model+"%") } if status := c.Query("status"); status != "" { if status == "enabled" { query = query.Where("mc.enabled = ?", true) } else if status == "disabled" { query = query.Where("mc.enabled = ?", false) } } // 分页 page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) size, _ := strconv.Atoi(c.DefaultQuery("size", "20")) offset := (page - 1) * size // 先获取总数(在应用分页之前) var total int64 err := query.Count(&total).Error if err != nil { h.response.InternalServerError(c, err.Error()) return } // 然后获取分页数据 err = query.Offset(offset).Limit(size).Find(&models).Error if err != nil { h.response.InternalServerError(c, err.Error()) return } h.response.Page(c, models, total, page, size) } // UpdateModelPricing 更新单个模型价格 func (h *VendorModelPricingHandler) UpdateModelPricing(c *gin.Context) { id := c.Param("id") var req PriceUpdateRequest if err := c.ShouldBindJSON(&req); err != nil { h.response.BadRequest(c, err.Error()) return } // 构建更新字段(v2 列) updates := make(map[string]interface{}) if req.InputPrice != nil { updates["prompt_price"] = *req.InputPrice } if req.OutputPrice != nil { updates["output_price"] = *req.OutputPrice } if req.CacheReadPrice != nil { updates["cache_read_price"] = *req.CacheReadPrice } if req.CacheCreatePrice != nil { updates["cache_create_price"] = *req.CacheCreatePrice } if req.Enabled != nil { updates["enabled"] = *req.Enabled } if req.PriceRatio != nil { updates["price_ratio"] = *req.PriceRatio } // 检查模型是否存在(v2 表) var count int64 err := h.db.Table("gw_model_config_v2").Where("id = ?", id).Count(&count).Error if err != nil { h.response.InternalServerError(c, err.Error()) return } if count == 0 { h.response.NotFound(c, "模型不存在") return } // 更新模型价格(v2 表) err = h.db.Table("gw_model_config_v2").Where("id = ?", id).Updates(updates).Error if err != nil { h.response.InternalServerError(c, err.Error()) return } h.response.Success(c, "价格更新成功") } // GetProviders 获取供应商列表(用于筛选) func (h *VendorModelPricingHandler) GetProviders(c *gin.Context) { var providers []string err := h.db.Table("gw_providers"). Select("DISTINCT name"). Where("status = ?", "active"). Find(&providers).Error if err != nil { h.response.InternalServerError(c, err.Error()) return } h.response.Success(c, providers) }