171 lines
4.8 KiB
Go
171 lines
4.8 KiB
Go
package handlers
|
||
|
||
import (
|
||
"goalfymax-admin/pkg/utils"
|
||
"strconv"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
// VendorModelPricingHandler 供应商模型价格配置处理器
|
||
type VendorModelPricingHandler struct {
|
||
db *gorm.DB
|
||
response *utils.Response
|
||
}
|
||
|
||
// NewVendorModelPricingHandler 创建处理器
|
||
func NewVendorModelPricingHandler(db *gorm.DB) *VendorModelPricingHandler {
|
||
return &VendorModelPricingHandler{
|
||
db: db,
|
||
response: utils.NewResponse(),
|
||
}
|
||
}
|
||
|
||
// VendorModelPricingResponse is one row of the vendor model pricing list.
// Field order is significant: encoding/json emits keys in declaration order.
type VendorModelPricingResponse struct {
	// Identity of the configured model.
	ID        uint   `json:"id"`
	Provider  string `json:"provider"`
	Account   string `json:"account"` // kept for frontend compatibility; the list query fills it with ''
	ModelName string `json:"model_name"`

	// Pricing. InputPrice is selected from the prompt_price column.
	InputPrice       float64 `json:"input_price"`
	OutputPrice      float64 `json:"output_price"`
	CacheReadPrice   float64 `json:"cache_read_price"`
	CacheCreatePrice float64 `json:"cache_create_price"`
	PriceRatio       float64 `json:"price_ratio"`

	// State and bookkeeping timestamps (as strings).
	Enabled   bool   `json:"enabled"`
	CreatedAt string `json:"created_at"`
	UpdatedAt string `json:"updated_at"`
}
|
||
|
||
// PriceUpdateRequest is the JSON body for a partial price update. Every field
// is a pointer so an omitted field (nil) can be told apart from an explicit
// zero; nil fields are left unchanged by the update handler.
type PriceUpdateRequest struct {
	InputPrice       *float64 `json:"input_price"` // persisted to the prompt_price column
	OutputPrice      *float64 `json:"output_price"`
	CacheReadPrice   *float64 `json:"cache_read_price"`
	CacheCreatePrice *float64 `json:"cache_create_price"`
	Enabled          *bool    `json:"enabled"`
	PriceRatio       *float64 `json:"price_ratio"`
}
|
||
|
||
// GetVendorModelPricing 获取供应商模型价格配置列表
|
||
func (h *VendorModelPricingHandler) GetVendorModelPricing(c *gin.Context) {
|
||
var models []VendorModelPricingResponse
|
||
|
||
// 构建查询条件
|
||
// 使用 v2 表:gw_model_config_v2,直接返回数据,无需联查
|
||
// 为兼容前端字段类型,account 字段返回空串
|
||
query := h.db.Table("gw_model_config_v2 mc").
|
||
Select("mc.id, mc.provider, '' as account, mc.model_name, " +
|
||
"mc.prompt_price as input_price, " +
|
||
"mc.output_price as output_price, " +
|
||
"mc.cache_read_price, mc.cache_create_price, " +
|
||
"mc.price_ratio, " +
|
||
"mc.enabled, mc.created_at, mc.updated_at")
|
||
|
||
// 添加筛选条件
|
||
if provider := c.Query("provider"); provider != "" {
|
||
query = query.Where("mc.provider = ?", provider)
|
||
}
|
||
if model := c.Query("model"); model != "" {
|
||
query = query.Where("mc.model_name LIKE ?", "%"+model+"%")
|
||
}
|
||
if status := c.Query("status"); status != "" {
|
||
if status == "enabled" {
|
||
query = query.Where("mc.enabled = ?", true)
|
||
} else if status == "disabled" {
|
||
query = query.Where("mc.enabled = ?", false)
|
||
}
|
||
}
|
||
|
||
// 分页
|
||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||
size, _ := strconv.Atoi(c.DefaultQuery("size", "20"))
|
||
offset := (page - 1) * size
|
||
|
||
// 先获取总数(在应用分页之前)
|
||
var total int64
|
||
err := query.Count(&total).Error
|
||
if err != nil {
|
||
h.response.InternalServerError(c, err.Error())
|
||
return
|
||
}
|
||
|
||
// 然后获取分页数据
|
||
err = query.Offset(offset).Limit(size).Find(&models).Error
|
||
if err != nil {
|
||
h.response.InternalServerError(c, err.Error())
|
||
return
|
||
}
|
||
|
||
h.response.Page(c, models, total, page, size)
|
||
}
|
||
|
||
// UpdateModelPricing 更新单个模型价格
|
||
func (h *VendorModelPricingHandler) UpdateModelPricing(c *gin.Context) {
|
||
id := c.Param("id")
|
||
|
||
var req PriceUpdateRequest
|
||
if err := c.ShouldBindJSON(&req); err != nil {
|
||
h.response.BadRequest(c, err.Error())
|
||
return
|
||
}
|
||
|
||
// 构建更新字段(v2 列)
|
||
updates := make(map[string]interface{})
|
||
if req.InputPrice != nil {
|
||
updates["prompt_price"] = *req.InputPrice
|
||
}
|
||
if req.OutputPrice != nil {
|
||
updates["output_price"] = *req.OutputPrice
|
||
}
|
||
if req.CacheReadPrice != nil {
|
||
updates["cache_read_price"] = *req.CacheReadPrice
|
||
}
|
||
if req.CacheCreatePrice != nil {
|
||
updates["cache_create_price"] = *req.CacheCreatePrice
|
||
}
|
||
if req.Enabled != nil {
|
||
updates["enabled"] = *req.Enabled
|
||
}
|
||
if req.PriceRatio != nil {
|
||
updates["price_ratio"] = *req.PriceRatio
|
||
}
|
||
|
||
// 检查模型是否存在(v2 表)
|
||
var count int64
|
||
err := h.db.Table("gw_model_config_v2").Where("id = ?", id).Count(&count).Error
|
||
if err != nil {
|
||
h.response.InternalServerError(c, err.Error())
|
||
return
|
||
}
|
||
if count == 0 {
|
||
h.response.NotFound(c, "模型不存在")
|
||
return
|
||
}
|
||
|
||
// 更新模型价格(v2 表)
|
||
err = h.db.Table("gw_model_config_v2").Where("id = ?", id).Updates(updates).Error
|
||
if err != nil {
|
||
h.response.InternalServerError(c, err.Error())
|
||
return
|
||
}
|
||
|
||
h.response.Success(c, "价格更新成功")
|
||
}
|
||
|
||
// GetProviders 获取供应商列表(用于筛选)
|
||
func (h *VendorModelPricingHandler) GetProviders(c *gin.Context) {
|
||
var providers []string
|
||
err := h.db.Table("gw_providers").
|
||
Select("DISTINCT name").
|
||
Where("status = ?", "active").
|
||
Find(&providers).Error
|
||
if err != nil {
|
||
h.response.InternalServerError(c, err.Error())
|
||
return
|
||
}
|
||
|
||
h.response.Success(c, providers)
|
||
}
|