Files
goalfylearning-admin/internal/api/handlers/vendor_model_pricing_handler.go

171 lines
4.8 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package handlers
import (
"goalfymax-admin/pkg/utils"
"strconv"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
// VendorModelPricingHandler serves the vendor model pricing configuration
// endpoints.
type VendorModelPricingHandler struct {
	// db is the GORM handle every query in this handler runs against.
	db *gorm.DB
	// response writes the API replies (success, paged, and error responses).
	response *utils.Response
}
// NewVendorModelPricingHandler returns a handler that runs its queries on db
// and replies through a freshly created utils.Response.
func NewVendorModelPricingHandler(db *gorm.DB) *VendorModelPricingHandler {
	handler := new(VendorModelPricingHandler)
	handler.db = db
	handler.response = utils.NewResponse()
	return handler
}
// VendorModelPricingResponse is one row of the vendor model pricing list as
// it is serialized to the client.
type VendorModelPricingResponse struct {
	ID       uint   `json:"id"`
	Provider string `json:"provider"`
	// Account is kept for front-end compatibility; the list query always
	// selects it as an empty string.
	Account   string `json:"account"`
	ModelName string `json:"model_name"`
	// InputPrice is read from the prompt_price column by the list query.
	InputPrice       float64 `json:"input_price"`
	OutputPrice      float64 `json:"output_price"`
	CacheReadPrice   float64 `json:"cache_read_price"`
	CacheCreatePrice float64 `json:"cache_create_price"`
	PriceRatio       float64 `json:"price_ratio"`
	Enabled          bool    `json:"enabled"`
	// CreatedAt and UpdatedAt carry the timestamp columns as strings.
	CreatedAt string `json:"created_at"`
	UpdatedAt string `json:"updated_at"`
}
// PriceUpdateRequest is the JSON body accepted by UpdateModelPricing.
//
// Every field is a pointer so that an omitted field (nil) can be told apart
// from an explicit zero value; nil fields are left out of the update.
type PriceUpdateRequest struct {
	// InputPrice is written to the prompt_price column.
	InputPrice       *float64 `json:"input_price"`
	OutputPrice      *float64 `json:"output_price"`
	CacheReadPrice   *float64 `json:"cache_read_price"`
	CacheCreatePrice *float64 `json:"cache_create_price"`
	Enabled          *bool    `json:"enabled"`
	PriceRatio       *float64 `json:"price_ratio"`
}
// Pagination bounds applied by GetVendorModelPricing.
const (
	vendorPricingDefaultPageSize = 20
	vendorPricingMaxPageSize     = 1000
)

// vendorPricingPagination turns the raw "page" and "size" query values into a
// usable page number and page size. A missing, non-numeric or non-positive
// page becomes 1; a missing, non-numeric or non-positive size becomes
// vendorPricingDefaultPageSize; sizes above vendorPricingMaxPageSize are
// clamped to it.
func vendorPricingPagination(rawPage, rawSize string) (int, int) {
	page, err := strconv.Atoi(rawPage)
	if err != nil || page < 1 {
		page = 1
	}

	size, err := strconv.Atoi(rawSize)
	if err != nil || size < 1 {
		size = vendorPricingDefaultPageSize
	}
	if size > vendorPricingMaxPageSize {
		size = vendorPricingMaxPageSize
	}

	return page, size
}

// GetVendorModelPricing returns a page of vendor model pricing rows.
//
// Query parameters:
//   - provider: exact match on the provider column.
//   - model: substring match on the model_name column.
//   - status: "enabled" or "disabled"; any other value applies no filter.
//   - page, size: pagination, normalised by vendorPricingPagination.
//
// It replies with a paged response holding the rows, the total number of
// matching rows and the effective page and size, or with an internal server
// error when a query fails.
func (h *VendorModelPricingHandler) GetVendorModelPricing(c *gin.Context) {
	var models []VendorModelPricingResponse

	// Rows come straight from the v2 table gw_model_config_v2, no join needed.
	// account is selected as an empty string to stay compatible with the field
	// types the front end expects.
	query := h.db.Table("gw_model_config_v2 mc").
		Select("mc.id, mc.provider, '' as account, mc.model_name, " +
			"mc.prompt_price as input_price, " +
			"mc.output_price as output_price, " +
			"mc.cache_read_price, mc.cache_create_price, " +
			"mc.price_ratio, " +
			"mc.enabled, mc.created_at, mc.updated_at")

	// Optional filters.
	if provider := c.Query("provider"); provider != "" {
		query = query.Where("mc.provider = ?", provider)
	}
	if model := c.Query("model"); model != "" {
		query = query.Where("mc.model_name LIKE ?", "%"+model+"%")
	}
	switch c.Query("status") {
	case "enabled":
		query = query.Where("mc.enabled = ?", true)
	case "disabled":
		query = query.Where("mc.enabled = ?", false)
	}

	// Normalise pagination so malformed input cannot produce a negative
	// offset, a zero/negative limit, or an unbounded page.
	page, size := vendorPricingPagination(c.Query("page"), c.Query("size"))
	offset := (page - 1) * size

	// Count the matching rows before pagination is applied.
	var total int64
	if err := query.Count(&total).Error; err != nil {
		h.response.InternalServerError(c, err.Error())
		return
	}

	// Fetch the requested page. The explicit ordering keeps pages stable:
	// without it the database may return rows in any order, so rows could
	// repeat or be skipped between pages.
	if err := query.Order("mc.id").Offset(offset).Limit(size).Find(&models).Error; err != nil {
		h.response.InternalServerError(c, err.Error())
		return
	}

	h.response.Page(c, models, total, page, size)
}
// UpdateModelPricing updates the pricing fields of a single model.
//
// The model is identified by the ":id" path parameter, which must be an
// unsigned integer. The JSON body is a PriceUpdateRequest; only the fields
// present in the body are written.
//
// It replies with:
//   - bad request when the id is not numeric, the body is not valid JSON, or
//     the body contains no updatable field;
//   - not found when no row with that id exists;
//   - internal server error when a query fails;
//   - success otherwise.
func (h *VendorModelPricingHandler) UpdateModelPricing(c *gin.Context) {
	// Reject non-numeric ids up front instead of handing an arbitrary string
	// to the database, where implicit type coercion could match another row.
	id, err := strconv.ParseUint(c.Param("id"), 10, 64)
	if err != nil {
		h.response.BadRequest(c, "无效的模型ID")
		return
	}

	var req PriceUpdateRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		h.response.BadRequest(c, err.Error())
		return
	}

	// Collect the columns to update (v2 column names); nil request fields
	// were not supplied and are left untouched.
	updates := make(map[string]interface{})
	if req.InputPrice != nil {
		updates["prompt_price"] = *req.InputPrice
	}
	if req.OutputPrice != nil {
		updates["output_price"] = *req.OutputPrice
	}
	if req.CacheReadPrice != nil {
		updates["cache_read_price"] = *req.CacheReadPrice
	}
	if req.CacheCreatePrice != nil {
		updates["cache_create_price"] = *req.CacheCreatePrice
	}
	if req.Enabled != nil {
		updates["enabled"] = *req.Enabled
	}
	if req.PriceRatio != nil {
		updates["price_ratio"] = *req.PriceRatio
	}

	// With nothing to write there is no update to perform; report that
	// instead of claiming the price was updated.
	if len(updates) == 0 {
		h.response.BadRequest(c, "未提供任何更新字段")
		return
	}

	// Make sure the model exists in the v2 table.
	var count int64
	err = h.db.Table("gw_model_config_v2").Where("id = ?", id).Count(&count).Error
	if err != nil {
		h.response.InternalServerError(c, err.Error())
		return
	}
	if count == 0 {
		h.response.NotFound(c, "模型不存在")
		return
	}

	// Apply the update to the v2 table.
	err = h.db.Table("gw_model_config_v2").Where("id = ?", id).Updates(updates).Error
	if err != nil {
		h.response.InternalServerError(c, err.Error())
		return
	}

	h.response.Success(c, "价格更新成功")
}
// GetProviders replies with the distinct names of the rows in gw_providers
// whose status is "active"; the list is meant to feed the provider filter.
// A failed query is reported as an internal server error.
func (h *VendorModelPricingHandler) GetProviders(c *gin.Context) {
	var names []string

	lookup := h.db.Table("gw_providers").
		Select("DISTINCT name").
		Where("status = ?", "active")

	if result := lookup.Find(&names); result.Error != nil {
		h.response.InternalServerError(c, result.Error.Error())
		return
	}

	h.response.Success(c, names)
}