package notifier import ( "bytes" "context" "crypto/hmac" "crypto/sha256" "encoding/base64" "encoding/json" "fmt" "net" "net/http" "net/url" "os" "strings" "sync" "text/template" "time" "goalfymax-admin/internal/config" "goalfymax-admin/pkg/utils" "github.com/shopspring/decimal" "go.uber.org/zap" ) type accountType string const ( accountTypeMCP accountType = "MCP账号" accountTypeModel accountType = "模型账号" ) var ( notifierMu sync.RWMutex dingTalkService *dingTalkNotifier ) // Init 初始化通知器 func Init(alertCfg config.AlertConfig, env string, logger *utils.Logger) { notifierMu.Lock() defer notifierMu.Unlock() if alertCfg.DingTalk.TimeoutSeconds <= 0 { alertCfg.DingTalk.TimeoutSeconds = 5 } if !alertCfg.DingTalk.Enabled { logger.Info("DingTalk notifier disabled") dingTalkService = nil return } service := newDingTalkNotifier(alertCfg.DingTalk, env, logger) dingTalkService = service logger.Info("DingTalk notifier initialized", zap.String("env", env), zap.String("webhook_hint", maskWebhook(alertCfg.DingTalk.Webhook))) } // NotifyMcpLowBalance MCP账号余额不足通知 func NotifyMcpLowBalance(provider, account string, balance, threshold decimal.Decimal) { notifyLowBalance(accountTypeMCP, provider, account, "", balance, threshold) } // NotifyModelLowBalance 模型账号余额不足通知 func NotifyModelLowBalance(provider, account, model string, balance, threshold decimal.Decimal) { notifyLowBalance(accountTypeModel, provider, account, model, balance, threshold) } func notifyLowBalance(accType accountType, provider, account, model string, balance, threshold decimal.Decimal) { notifierMu.RLock() service := dingTalkService notifierMu.RUnlock() if service == nil { return } if err := service.sendLowBalanceAlert(accType, provider, account, model, balance, threshold); err != nil { service.logger.Error("发送钉钉余额告警失败", zap.Error(err), zap.String("account_type", string(accType)), zap.String("provider", provider), zap.String("account", account), zap.String("model", model)) } } type dingTalkNotifier struct { webhook string secret string env string host string client *http.Client keyword string logger *utils.Logger tpl *template.Template } type dingTalkMarkdownPayload struct { MsgType string `json:"msgtype"` Markdown struct { Title string `json:"title"` Text string `json:"text"` } `json:"markdown"` } type dingTalkResponse struct { ErrCode int `json:"errcode"` ErrMsg string `json:"errmsg"` } const lowBalanceMarkdownTemplate = `[自动生成] {{.Emoji}} **{{.DisplayTitle}}** > **级别**: {{.SeverityLabel}} > **环境**: {{.Environment}} > **主机**: {{.Host}} > > **服务**: {{.Service}} - **时间**:{{.Timestamp}} - **Provider**:{{.Provider}} - **账号**:{{.Account}} {{- if .Model }} - **模型**:{{.Model}} {{- end }} - **当前余额**:${{.Balance}} - **阈值**:${{.Threshold}} ` type lowBalanceTemplateData struct { Emoji string DisplayTitle string SeverityLabel string Environment string Host string Service string Provider string Account string Model string Balance string Threshold string Timestamp string } type severityMetaEntry struct { Emoji string Label string } var ( severityCritical = severityMetaEntry{Emoji: "🚨", Label: "严重告警"} defaultService = "balance-monitor" ) func newDingTalkNotifier(cfg config.DingTalkConfig, env string, logger *utils.Logger) *dingTalkNotifier { timeout := time.Duration(cfg.TimeoutSeconds) * time.Second if timeout <= 0 { timeout = 5 * time.Second } return &dingTalkNotifier{ webhook: strings.TrimSpace(cfg.Webhook), secret: strings.TrimSpace(cfg.Secret), env: strings.TrimSpace(env), host: detectHost(), client: &http.Client{ Timeout: timeout, }, keyword: strings.TrimSpace(cfg.Keyword), logger: logger, tpl: template.Must(template.New("lowBalance").Parse(lowBalanceMarkdownTemplate)), } } func (n *dingTalkNotifier) sendLowBalanceAlert(accType accountType, provider, account, model string, balance, threshold decimal.Decimal) error { if strings.TrimSpace(n.webhook) == "" { return fmt.Errorf("dingtalk webhook 未配置") } accountLabel := strings.TrimSpace(account) if accountLabel == "" { accountLabel = "(空)" } title := fmt.Sprintf("%s余额告警", accType) displayTitle := title if keyword := strings.TrimSpace(n.keyword); keyword != "" { if !strings.Contains(title, keyword) { title = fmt.Sprintf("%s %s", keyword, title) } if !strings.Contains(displayTitle, keyword) { displayTitle = fmt.Sprintf("%s %s", keyword, displayTitle) } } data := lowBalanceTemplateData{ Emoji: severityCritical.Emoji, DisplayTitle: displayTitle, SeverityLabel: severityCritical.Label, Environment: valueOrFallback(n.env, "unknown"), Host: valueOrFallback(n.host, "unknown-host"), Service: defaultService, Provider: valueOrFallback(provider, "unknown"), Account: accountLabel, Model: strings.TrimSpace(model), Balance: balance.StringFixed(4), Threshold: threshold.StringFixed(4), Timestamp: time.Now().Format("2006-01-02 15:04:05"), } var buf bytes.Buffer if err := n.tpl.Execute(&buf, data); err != nil { return fmt.Errorf("渲染模板失败: %w", err) } return n.sendMarkdown(context.Background(), title, buf.String()) } func (n *dingTalkNotifier) sendMarkdown(ctx context.Context, title, markdown string) error { webhookURL, err := n.webhookWithSignature() if err != nil { return fmt.Errorf("生成签名失败: %w", err) } payload := dingTalkMarkdownPayload{MsgType: "markdown"} payload.Markdown.Title = title payload.Markdown.Text = markdown body, err := json.Marshal(payload) if err != nil { return fmt.Errorf("序列化Payload失败: %w", err) } req, err := http.NewRequestWithContext(ctx, http.MethodPost, webhookURL, bytes.NewReader(body)) if err != nil { return fmt.Errorf("创建请求失败: %w", err) } req.Header.Set("Content-Type", "application/json; charset=utf-8") resp, err := n.client.Do(req) if err != nil { return fmt.Errorf("发送请求失败: %w", err) } defer resp.Body.Close() var result dingTalkResponse if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { return fmt.Errorf("解析响应失败: %w", err) } if resp.StatusCode != http.StatusOK || result.ErrCode != 0 { return fmt.Errorf("钉钉返回错误: status=%d code=%d msg=%s", resp.StatusCode, result.ErrCode, result.ErrMsg) } return nil } func (n *dingTalkNotifier) webhookWithSignature() (string, error) { if strings.TrimSpace(n.secret) == "" { return n.webhook, nil } ts := time.Now().UnixMilli() strToSign := fmt.Sprintf("%d\n%s", ts, n.secret) mac := hmac.New(sha256.New, []byte(n.secret)) if _, err := mac.Write([]byte(strToSign)); err != nil { return "", fmt.Errorf("计算签名失败: %w", err) } signature := url.QueryEscape(base64.StdEncoding.EncodeToString(mac.Sum(nil))) params := fmt.Sprintf("timestamp=%d&sign=%s", ts, signature) return appendQuery(n.webhook, params), nil } func appendQuery(base, query string) string { if strings.Contains(base, "?") { if strings.HasSuffix(base, "?") || strings.HasSuffix(base, "&") { return base + query } return base + "&" + query } return base + "?" + query } func detectHost() string { if ip := getLocalIP(); ip != "" { return ip } if name, err := os.Hostname(); err == nil && strings.TrimSpace(name) != "" { return name } return "unknown-host" } func getLocalIP() string { addrs, err := net.InterfaceAddrs() if err != nil { return "" } for _, addr := range addrs { if ipnet, ok := addr.(*net.IPNet); ok && !ipnet.IP.IsLoopback() { if ip := ipnet.IP.To4(); ip != nil { return ip.String() } } } return "" } func maskWebhook(webhook string) string { trimmed := strings.TrimSpace(webhook) if trimmed == "" { return "" } if len(trimmed) <= 12 { return trimmed[:3] + "***" } return trimmed[:6] + "..." + trimmed[len(trimmed)-4:] } func valueOrFallback(value, fallback string) string { if strings.TrimSpace(value) == "" { return fallback } return strings.TrimSpace(value) }