328 lines
8.2 KiB
Go
328 lines
8.2 KiB
Go
package notifier
|
||
|
||
import (
|
||
"bytes"
|
||
"context"
|
||
"crypto/hmac"
|
||
"crypto/sha256"
|
||
"encoding/base64"
|
||
"encoding/json"
|
||
"fmt"
|
||
"net"
|
||
"net/http"
|
||
"net/url"
|
||
"os"
|
||
"strings"
|
||
"sync"
|
||
"text/template"
|
||
"time"
|
||
|
||
"goalfymax-admin/internal/config"
|
||
"goalfymax-admin/pkg/utils"
|
||
|
||
"github.com/shopspring/decimal"
|
||
"go.uber.org/zap"
|
||
)
|
||
|
||
// accountType labels the kind of account an alert refers to. Its value is
// rendered verbatim into the alert title ("<label>余额告警").
type accountType string

// Account-type labels used by the low-balance notifications.
const (
	// accountTypeMCP is the label for MCP accounts.
	accountTypeMCP accountType = "MCP账号"
	// accountTypeModel is the label for model accounts.
	accountTypeModel accountType = "模型账号"
)
|
||
|
||
var (
|
||
notifierMu sync.RWMutex
|
||
dingTalkService *dingTalkNotifier
|
||
)
|
||
|
||
// Init 初始化通知器
|
||
func Init(alertCfg config.AlertConfig, env string, logger *utils.Logger) {
|
||
notifierMu.Lock()
|
||
defer notifierMu.Unlock()
|
||
|
||
if alertCfg.DingTalk.TimeoutSeconds <= 0 {
|
||
alertCfg.DingTalk.TimeoutSeconds = 5
|
||
}
|
||
|
||
if !alertCfg.DingTalk.Enabled {
|
||
logger.Info("DingTalk notifier disabled")
|
||
dingTalkService = nil
|
||
return
|
||
}
|
||
|
||
service := newDingTalkNotifier(alertCfg.DingTalk, env, logger)
|
||
dingTalkService = service
|
||
logger.Info("DingTalk notifier initialized",
|
||
zap.String("env", env),
|
||
zap.String("webhook_hint", maskWebhook(alertCfg.DingTalk.Webhook)))
|
||
}
|
||
|
||
// NotifyMcpLowBalance MCP账号余额不足通知
|
||
func NotifyMcpLowBalance(provider, account string, balance, threshold decimal.Decimal) {
|
||
notifyLowBalance(accountTypeMCP, provider, account, "", balance, threshold)
|
||
}
|
||
|
||
// NotifyModelLowBalance 模型账号余额不足通知
|
||
func NotifyModelLowBalance(provider, account, model string, balance, threshold decimal.Decimal) {
|
||
notifyLowBalance(accountTypeModel, provider, account, model, balance, threshold)
|
||
}
|
||
|
||
func notifyLowBalance(accType accountType, provider, account, model string, balance, threshold decimal.Decimal) {
|
||
notifierMu.RLock()
|
||
service := dingTalkService
|
||
notifierMu.RUnlock()
|
||
|
||
if service == nil {
|
||
return
|
||
}
|
||
|
||
if err := service.sendLowBalanceAlert(accType, provider, account, model, balance, threshold); err != nil {
|
||
service.logger.Error("发送钉钉余额告警失败", zap.Error(err),
|
||
zap.String("account_type", string(accType)),
|
||
zap.String("provider", provider),
|
||
zap.String("account", account),
|
||
zap.String("model", model))
|
||
}
|
||
}
|
||
|
||
type dingTalkNotifier struct {
|
||
webhook string
|
||
secret string
|
||
env string
|
||
host string
|
||
client *http.Client
|
||
keyword string
|
||
logger *utils.Logger
|
||
tpl *template.Template
|
||
}
|
||
|
||
// dingTalkMarkdownPayload is the JSON request body for a markdown robot
// message, serialised as {"msgtype": ..., "markdown": {"title": ..., "text": ...}}.
type dingTalkMarkdownPayload struct {
	MsgType string `json:"msgtype"` // sendMarkdown sets this to "markdown"

	// Markdown carries the message title and rendered markdown body.
	Markdown struct {
		Title string `json:"title"`
		Text  string `json:"text"`
	} `json:"markdown"`
}
|
||
|
||
// dingTalkResponse is the JSON reply from the webhook. sendMarkdown treats
// any non-zero ErrCode as a failure.
type dingTalkResponse struct {
	ErrCode int    `json:"errcode"` // non-zero signals an error
	ErrMsg  string `json:"errmsg"`  // accompanying error description
}
|
||
|
||
// lowBalanceMarkdownTemplate is the text/template source for the low-balance
// alert body. It is executed with a lowBalanceTemplateData value.
//
// The model line is optional. The `{{-` trim markers remove the newline in
// front of the `if` and `end` actions, so the list has no blank line whether
// or not .Model is set. The literal must contain only the message lines; any
// extra line inside the backquotes (e.g. stray `|` separators) is sent to
// DingTalk verbatim.
const lowBalanceMarkdownTemplate = `[自动生成]

{{.Emoji}} **{{.DisplayTitle}}**
> **级别**: {{.SeverityLabel}}
> **环境**: {{.Environment}}
> **主机**: {{.Host}}
>
> **服务**: {{.Service}}

- **时间**：{{.Timestamp}}
- **Provider**：{{.Provider}}
- **账号**：{{.Account}}
{{- if .Model }}
- **模型**：{{.Model}}
{{- end }}
- **当前余额**：${{.Balance}}
- **阈值**：${{.Threshold}}
`
|
||
|
||
// lowBalanceTemplateData holds the pre-formatted values substituted into
// lowBalanceMarkdownTemplate. Field names must match the template's {{.Field}}
// references.
type lowBalanceTemplateData struct {
	// Heading: icon, title and severity label.
	Emoji, DisplayTitle, SeverityLabel string
	// Where the alert originated.
	Environment, Host, Service string
	// Account details; an empty Model is left out of the rendered message.
	Provider, Account, Model string
	// Amounts as formatted strings, plus the time the alert was built.
	Balance, Threshold, Timestamp string
}
|
||
|
||
// severityMetaEntry pairs the icon and the human-readable label that are
// displayed for one alert severity.
type severityMetaEntry struct {
	Emoji, Label string
}
|
||
|
||
var (
|
||
severityCritical = severityMetaEntry{Emoji: "🚨", Label: "严重告警"}
|
||
defaultService = "balance-monitor"
|
||
)
|
||
|
||
func newDingTalkNotifier(cfg config.DingTalkConfig, env string, logger *utils.Logger) *dingTalkNotifier {
|
||
timeout := time.Duration(cfg.TimeoutSeconds) * time.Second
|
||
if timeout <= 0 {
|
||
timeout = 5 * time.Second
|
||
}
|
||
|
||
return &dingTalkNotifier{
|
||
webhook: strings.TrimSpace(cfg.Webhook),
|
||
secret: strings.TrimSpace(cfg.Secret),
|
||
env: strings.TrimSpace(env),
|
||
host: detectHost(),
|
||
client: &http.Client{
|
||
Timeout: timeout,
|
||
},
|
||
keyword: strings.TrimSpace(cfg.Keyword),
|
||
logger: logger,
|
||
tpl: template.Must(template.New("lowBalance").Parse(lowBalanceMarkdownTemplate)),
|
||
}
|
||
}
|
||
|
||
func (n *dingTalkNotifier) sendLowBalanceAlert(accType accountType, provider, account, model string, balance, threshold decimal.Decimal) error {
|
||
if strings.TrimSpace(n.webhook) == "" {
|
||
return fmt.Errorf("dingtalk webhook 未配置")
|
||
}
|
||
|
||
accountLabel := strings.TrimSpace(account)
|
||
if accountLabel == "" {
|
||
accountLabel = "(空)"
|
||
}
|
||
|
||
title := fmt.Sprintf("%s余额告警", accType)
|
||
displayTitle := title
|
||
|
||
if keyword := strings.TrimSpace(n.keyword); keyword != "" {
|
||
if !strings.Contains(title, keyword) {
|
||
title = fmt.Sprintf("%s %s", keyword, title)
|
||
}
|
||
if !strings.Contains(displayTitle, keyword) {
|
||
displayTitle = fmt.Sprintf("%s %s", keyword, displayTitle)
|
||
}
|
||
}
|
||
|
||
data := lowBalanceTemplateData{
|
||
Emoji: severityCritical.Emoji,
|
||
DisplayTitle: displayTitle,
|
||
SeverityLabel: severityCritical.Label,
|
||
Environment: valueOrFallback(n.env, "unknown"),
|
||
Host: valueOrFallback(n.host, "unknown-host"),
|
||
Service: defaultService,
|
||
Provider: valueOrFallback(provider, "unknown"),
|
||
Account: accountLabel,
|
||
Model: strings.TrimSpace(model),
|
||
Balance: balance.StringFixed(4),
|
||
Threshold: threshold.StringFixed(4),
|
||
Timestamp: time.Now().Format("2006-01-02 15:04:05"),
|
||
}
|
||
|
||
var buf bytes.Buffer
|
||
if err := n.tpl.Execute(&buf, data); err != nil {
|
||
return fmt.Errorf("渲染模板失败: %w", err)
|
||
}
|
||
|
||
return n.sendMarkdown(context.Background(), title, buf.String())
|
||
}
|
||
|
||
func (n *dingTalkNotifier) sendMarkdown(ctx context.Context, title, markdown string) error {
|
||
webhookURL, err := n.webhookWithSignature()
|
||
if err != nil {
|
||
return fmt.Errorf("生成签名失败: %w", err)
|
||
}
|
||
|
||
payload := dingTalkMarkdownPayload{MsgType: "markdown"}
|
||
payload.Markdown.Title = title
|
||
payload.Markdown.Text = markdown
|
||
|
||
body, err := json.Marshal(payload)
|
||
if err != nil {
|
||
return fmt.Errorf("序列化Payload失败: %w", err)
|
||
}
|
||
|
||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, webhookURL, bytes.NewReader(body))
|
||
if err != nil {
|
||
return fmt.Errorf("创建请求失败: %w", err)
|
||
}
|
||
req.Header.Set("Content-Type", "application/json; charset=utf-8")
|
||
|
||
resp, err := n.client.Do(req)
|
||
if err != nil {
|
||
return fmt.Errorf("发送请求失败: %w", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
var result dingTalkResponse
|
||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||
return fmt.Errorf("解析响应失败: %w", err)
|
||
}
|
||
if resp.StatusCode != http.StatusOK || result.ErrCode != 0 {
|
||
return fmt.Errorf("钉钉返回错误: status=%d code=%d msg=%s", resp.StatusCode, result.ErrCode, result.ErrMsg)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func (n *dingTalkNotifier) webhookWithSignature() (string, error) {
|
||
if strings.TrimSpace(n.secret) == "" {
|
||
return n.webhook, nil
|
||
}
|
||
|
||
ts := time.Now().UnixMilli()
|
||
strToSign := fmt.Sprintf("%d\n%s", ts, n.secret)
|
||
mac := hmac.New(sha256.New, []byte(n.secret))
|
||
if _, err := mac.Write([]byte(strToSign)); err != nil {
|
||
return "", fmt.Errorf("计算签名失败: %w", err)
|
||
}
|
||
signature := url.QueryEscape(base64.StdEncoding.EncodeToString(mac.Sum(nil)))
|
||
params := fmt.Sprintf("timestamp=%d&sign=%s", ts, signature)
|
||
return appendQuery(n.webhook, params), nil
|
||
}
|
||
|
||
// appendQuery appends a pre-encoded query fragment to base. It uses "?" when
// base has no query yet, and "&" when it already has one. No separator is
// added when base already ends with "?" or "&" after a query marker.
func appendQuery(base, query string) string {
	sep := "?"
	if strings.Contains(base, "?") {
		sep = "&"
		if strings.HasSuffix(base, "?") || strings.HasSuffix(base, "&") {
			sep = ""
		}
	}
	return base + sep + query
}
|
||
|
||
func detectHost() string {
|
||
if ip := getLocalIP(); ip != "" {
|
||
return ip
|
||
}
|
||
if name, err := os.Hostname(); err == nil && strings.TrimSpace(name) != "" {
|
||
return name
|
||
}
|
||
return "unknown-host"
|
||
}
|
||
|
||
// getLocalIP returns the first non-loopback IPv4 address, in the order
// net.InterfaceAddrs reports them. It returns "" when none exists or the
// lookup fails.
func getLocalIP() string {
	addrs, err := net.InterfaceAddrs()
	if err != nil {
		return ""
	}

	for _, a := range addrs {
		ipNet, isIPNet := a.(*net.IPNet)
		if !isIPNet || ipNet.IP.IsLoopback() {
			continue
		}
		if v4 := ipNet.IP.To4(); v4 != nil {
			return v4.String()
		}
	}
	return ""
}
|
||
|
||
// maskWebhook returns a log-safe hint for a webhook URL so the access token is
// not written to logs in full.
//
//   - blank input             -> ""
//   - shorter than 3 bytes    -> "***" (previously panicked on trimmed[:3])
//   - up to 12 bytes          -> first 3 bytes + "***"
//   - longer                  -> first 6 bytes + "..." + last 4 bytes
//
// Lengths are in bytes, after trimming surrounding whitespace.
func maskWebhook(webhook string) string {
	trimmed := strings.TrimSpace(webhook)
	switch {
	case trimmed == "":
		return ""
	case len(trimmed) < 3:
		return "***"
	case len(trimmed) <= 12:
		return trimmed[:3] + "***"
	default:
		return trimmed[:6] + "..." + trimmed[len(trimmed)-4:]
	}
}
|
||
|
||
// valueOrFallback returns value with surrounding whitespace removed. If that
// leaves nothing, it returns fallback unchanged.
func valueOrFallback(value, fallback string) string {
	if v := strings.TrimSpace(value); v != "" {
		return v
	}
	return fallback
}
|