package middleware

import (
	"context"
	"crypto/rand"
	"crypto/rsa"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"math/big"
	"net/http"
	"net/url"
	"strings"
	"time"

	"goalfymax-admin/internal/models"
	"goalfymax-admin/pkg/utils"

	"github.com/golang-jwt/jwt/v5"
	"go.uber.org/zap"
	"golang.org/x/oauth2"
)

const (
	// ssoDefaultTimeout is the HTTP timeout used when the config leaves Timeout unset.
	ssoDefaultTimeout = 30 * time.Second
	// ssoMaxResponseBytes caps how much of any SSO server response is read into memory.
	ssoMaxResponseBytes = 1 << 20
	// ssoErrorSnippetBytes caps how much of a non-JSON error body is copied into an error message.
	ssoErrorSnippetBytes = 512
	// ssoMaxRSAExponent is the largest public exponent crypto/rsa accepts (2^31-1).
	ssoMaxRSAExponent = 1<<31 - 1
)

// SSOClient is a client for the SSO server's OAuth2 / OpenID Connect endpoints.
type SSOClient struct {
	config *models.SSOConfig
	http   *http.Client
	oauth2 *oauth2.Config
	logger *utils.Logger
}

// NewSSOClient creates an SSO client from config.
//
// If config.Timeout is zero it is set to 30 seconds; note that the default is
// written back into the caller's config. config.Scope is split on whitespace,
// so an empty Scope or repeated spaces produce no empty scope entries.
func NewSSOClient(config *models.SSOConfig, logger *utils.Logger) *SSOClient {
	if config.Timeout == 0 {
		config.Timeout = ssoDefaultTimeout
	}

	// Tolerate a trailing slash on the configured base URL.
	serverURL := strings.TrimSuffix(config.SSOServerURL, "/")

	return &SSOClient{
		config: config,
		http:   &http.Client{Timeout: config.Timeout},
		oauth2: &oauth2.Config{
			ClientID:     config.ClientID,
			ClientSecret: config.ClientSecret,
			RedirectURL:  config.RedirectURI,
			Scopes:       strings.Fields(config.Scope),
			Endpoint: oauth2.Endpoint{
				AuthURL:  serverURL + "/oauth2/authorize",
				TokenURL: serverURL + "/oauth2/token",
			},
		},
		logger: logger,
	}
}

// GetServerUrl returns the configured SSO server base URL exactly as configured.
func (c *SSOClient) GetServerUrl() string {
	return c.config.SSOServerURL
}

// GetRedirectUrl returns the configured OAuth2 redirect URI.
func (c *SSOClient) GetRedirectUrl() string {
	return c.config.RedirectURI
}

// GetAuthorizationURL builds the authorization-code URL for state.
//
// A fresh PKCE code verifier is generated, and its S256 challenge is added to
// the URL together with access_type=offline. It returns the URL and the
// verifier; the caller must keep the verifier and later pass it to
// ExchangeCodeForToken.
func (c *SSOClient) GetAuthorizationURL(state string) (string, string, error) {
	codeVerifier, err := generateCodeVerifier()
	if err != nil {
		return "", "", fmt.Errorf("failed to generate code verifier: %w", err)
	}

	authURL := c.oauth2.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(codeVerifier))
	return authURL, codeVerifier, nil
}

// ExchangeCodeForToken exchanges an authorization code for tokens.
//
// It sends the client credentials, the redirect URI and the PKCE codeVerifier
// in a form POST to the token endpoint. A non-200 response is returned as an
// error prefixed with "token exchange failed".
func (c *SSOClient) ExchangeCodeForToken(ctx context.Context, code, codeVerifier string) (*models.TokenResponse, error) {
	form := url.Values{}
	form.Set("grant_type", "authorization_code")
	form.Set("client_id", c.config.ClientID)
	form.Set("client_secret", c.config.ClientSecret)
	form.Set("code", code)
	form.Set("redirect_uri", c.config.RedirectURI)
	form.Set("code_verifier", codeVerifier)

	return c.requestToken(ctx, form, "token exchange")
}

// RefreshToken obtains new tokens using refreshToken.
//
// A non-200 response is returned as an error prefixed with
// "token refresh failed".
func (c *SSOClient) RefreshToken(ctx context.Context, refreshToken string) (*models.TokenResponse, error) {
	form := url.Values{}
	form.Set("grant_type", "refresh_token")
	form.Set("client_id", c.config.ClientID)
	form.Set("client_secret", c.config.ClientSecret)
	form.Set("refresh_token", refreshToken)

	return c.requestToken(ctx, form, "token refresh")
}

// requestToken POSTs form (URL-encoded) to the token endpoint and decodes the
// JSON token response. op names the operation in error messages.
func (c *SSOClient) requestToken(ctx context.Context, form url.Values, op string) (*models.TokenResponse, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.endpointURL("/oauth2/token"), strings.NewReader(form.Encode()))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

	resp, body, err := c.doRequest(req)
	if err != nil {
		return nil, err
	}
	if resp.StatusCode != http.StatusOK {
		return nil, ssoResponseError(op, resp.Status, body)
	}

	var tokenResp models.TokenResponse
	if err := json.Unmarshal(body, &tokenResp); err != nil {
		return nil, fmt.Errorf("failed to parse token response: %w", err)
	}
	return &tokenResp, nil
}

// GetUserInfo fetches the user-info document for accessToken and converts it
// to the application's UserInfo.
//
// The roles claim is normalized by ssoRolesFromClaim, so a single string, a
// string array, or an array with non-string entries (which are skipped) are
// all accepted.
func (c *SSOClient) GetUserInfo(ctx context.Context, accessToken string) (*models.UserInfo, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.endpointURL("/oauth2/userinfo"), nil)
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+accessToken)

	resp, body, err := c.doRequest(req)
	if err != nil {
		return nil, err
	}
	if resp.StatusCode != http.StatusOK {
		return nil, ssoResponseError("get user info", resp.Status, body)
	}

	// Decode into SSOUserInfo first because its roles field accepts any JSON
	// type; it is normalized below.
	var ssoUserInfo models.SSOUserInfo
	if err := json.Unmarshal(body, &ssoUserInfo); err != nil {
		return nil, fmt.Errorf("failed to parse user info response: %w", err)
	}

	return &models.UserInfo{
		Sub:               ssoUserInfo.Sub,
		Name:              ssoUserInfo.Name,
		Email:             ssoUserInfo.Email,
		PreferredUsername: ssoUserInfo.PreferredUsername,
		Roles:             ssoRolesFromClaim(ssoUserInfo.Roles),
	}, nil
}

// ssoRolesFromClaim converts a raw roles claim into roles.
//
// It accepts a string, a []string, a []any (non-string elements are skipped),
// or any other value that round-trips through JSON as a string array. Empty
// names are dropped. nil or an unrecognized shape yields no roles.
func ssoRolesFromClaim(raw any) []models.Role {
	var names []string
	switch v := raw.(type) {
	case nil:
		return nil
	case string:
		names = []string{v}
	case []string:
		names = v
	case []any:
		for _, r := range v {
			if name, ok := r.(string); ok {
				names = append(names, name)
			}
		}
	default:
		// Last resort: re-encode and try to read the value as a string array.
		b, err := json.Marshal(v)
		if err != nil {
			return nil
		}
		if err := json.Unmarshal(b, &names); err != nil {
			return nil
		}
	}

	var roles []models.Role
	for _, name := range names {
		if name != "" {
			roles = append(roles, models.Role{Name: name})
		}
	}
	return roles
}

// ValidateTokenWithSSO validates accessToken by calling the SSO server's
// user-info endpoint; any failure of that call is reported as invalid.
func (c *SSOClient) ValidateTokenWithSSO(ctx context.Context, accessToken string) error {
	if _, err := c.GetUserInfo(ctx, accessToken); err != nil {
		return fmt.Errorf("token validation failed: %w", err)
	}
	return nil
}

// ValidateToken verifies accessToken locally as an RSA-signed JWT.
//
// The verification key is looked up by the token's "kid" in the server's
// JWKS, and then the iss/aud/exp/nbf claims are checked by
// validateTokenClaims. The JWKS is fetched on every call.
//
// Kept for compatibility and not recommended for new code; prefer
// ValidateTokenWithSSO.
func (c *SSOClient) ValidateToken(ctx context.Context, accessToken string) (*jwt.Token, error) {
	jwks, err := c.GetJWKS(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to get JWKS: %w", err)
	}

	token, err := jwt.Parse(accessToken, func(token *jwt.Token) (any, error) {
		// Only RSA signatures are accepted; this rejects "none" and HMAC
		// tokens before any key lookup.
		if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
		}

		kid, ok := token.Header["kid"].(string)
		if !ok {
			return nil, fmt.Errorf("kid not found in token header")
		}

		publicKey, err := c.findPublicKey(jwks, kid)
		if err != nil {
			return nil, fmt.Errorf("failed to find public key: %w", err)
		}
		return publicKey, nil
	})
	if err != nil {
		return nil, fmt.Errorf("failed to parse token: %w", err)
	}
	if !token.Valid {
		return nil, fmt.Errorf("invalid token")
	}
	if err := c.validateTokenClaims(token); err != nil {
		return nil, fmt.Errorf("invalid token claims: %w", err)
	}
	return token, nil
}

// findPublicKey returns the RSA public key for the first JWKS entry whose
// "kid" equals kid and that carries string "n" and "e" members.
func (c *SSOClient) findPublicKey(jwks *models.JWKS, kid string) (any, error) {
	for _, key := range jwks.Keys {
		keyKid, ok := key["kid"].(string)
		if !ok || keyKid != kid {
			continue
		}
		n, nOK := key["n"].(string)
		e, eOK := key["e"].(string)
		if nOK && eOK {
			return c.parseRSAPublicKey(n, e)
		}
	}
	return nil, fmt.Errorf("public key not found for kid: %s", kid)
}

// parseRSAPublicKey builds an *rsa.PublicKey from the base64url-encoded JWK
// modulus n and exponent e.
//
// Optional "=" padding is tolerated. The modulus must be non-zero, and the
// exponent must lie in [2, 2^31-1], the range crypto/rsa accepts. Larger
// values are rejected instead of being silently truncated.
func (c *SSOClient) parseRSAPublicKey(n, e string) (any, error) {
	nBytes, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(n, "="))
	if err != nil {
		return nil, fmt.Errorf("failed to decode modulus: %w", err)
	}
	eBytes, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(e, "="))
	if err != nil {
		return nil, fmt.Errorf("failed to decode exponent: %w", err)
	}

	modulus := new(big.Int).SetBytes(nBytes)
	if modulus.Sign() == 0 {
		return nil, fmt.Errorf("invalid modulus: empty or zero")
	}

	exponent := new(big.Int).SetBytes(eBytes)
	if !exponent.IsInt64() || exponent.Int64() < 2 || exponent.Int64() > ssoMaxRSAExponent {
		return nil, fmt.Errorf("invalid exponent: %v", exponent)
	}

	return &rsa.PublicKey{
		N: modulus,
		E: int(exponent.Int64()),
	}, nil
}

// validateTokenClaims checks the registered claims of a parsed token.
//
// Each check runs only when its claim is present with the expected JSON type:
//   - iss must equal the configured SSOServerURL.
//   - aud must be, or (as an array) contain, the configured ResourceAud.
//   - exp must not be in the past.
//   - nbf must not be in the future.
func (c *SSOClient) validateTokenClaims(token *jwt.Token) error {
	claims, ok := token.Claims.(jwt.MapClaims)
	if !ok {
		return fmt.Errorf("invalid token claims")
	}

	if iss, ok := claims["iss"].(string); ok && iss != c.config.SSOServerURL {
		return fmt.Errorf("invalid issuer: %s", iss)
	}

	// RFC 7519 allows "aud" to be a single string or an array of strings.
	// Both forms are checked so an array-valued audience cannot bypass
	// validation.
	switch aud := claims["aud"].(type) {
	case string:
		if aud != c.config.ResourceAud {
			return fmt.Errorf("invalid audience: %s", aud)
		}
	case []string:
		if !ssoAudienceContains(len(aud), func(i int) any { return aud[i] }, c.config.ResourceAud) {
			return fmt.Errorf("invalid audience: %v", aud)
		}
	case []any:
		if !ssoAudienceContains(len(aud), func(i int) any { return aud[i] }, c.config.ResourceAud) {
			return fmt.Errorf("invalid audience: %v", aud)
		}
	}

	now := time.Now()
	if exp, ok := claims["exp"].(float64); ok && time.Unix(int64(exp), 0).Before(now) {
		return fmt.Errorf("token expired")
	}
	if nbf, ok := claims["nbf"].(float64); ok && time.Unix(int64(nbf), 0).After(now) {
		return fmt.Errorf("token not yet valid")
	}
	return nil
}

// ssoAudienceContains reports whether any of the n elements returned by at
// is a string equal to want.
func ssoAudienceContains(n int, at func(int) any, want string) bool {
	for i := 0; i < n; i++ {
		if s, ok := at(i).(string); ok && s == want {
			return true
		}
	}
	return false
}

// GetOpenIDConfiguration fetches and decodes the server's
// /.well-known/openid-configuration document.
func (c *SSOClient) GetOpenIDConfiguration(ctx context.Context) (*models.OpenIDConfiguration, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.endpointURL("/.well-known/openid-configuration"), nil)
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}

	resp, body, err := c.doRequest(req)
	if err != nil {
		return nil, err
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("failed to get OpenID configuration: %s", resp.Status)
	}

	var cfg models.OpenIDConfiguration
	if err := json.Unmarshal(body, &cfg); err != nil {
		return nil, fmt.Errorf("failed to parse OpenID configuration: %w", err)
	}
	return &cfg, nil
}

// GetJWKS fetches and decodes the server's JSON Web Key Set from
// /oauth2/jwks.json.
func (c *SSOClient) GetJWKS(ctx context.Context) (*models.JWKS, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.endpointURL("/oauth2/jwks.json"), nil)
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}

	resp, body, err := c.doRequest(req)
	if err != nil {
		return nil, err
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("failed to get JWKS: %s", resp.Status)
	}

	var jwks models.JWKS
	if err := json.Unmarshal(body, &jwks); err != nil {
		return nil, fmt.Errorf("failed to parse JWKS: %w", err)
	}
	return &jwks, nil
}

// Logout POSTs to the server's /oauth2/logout-pre endpoint with accessToken as
// a bearer token. 200 and 204 count as success. Transport failures and other
// statuses are logged and returned as errors.
func (c *SSOClient) Logout(ctx context.Context, accessToken string) error {
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.endpointURL("/oauth2/logout-pre"), nil)
	if err != nil {
		return fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+accessToken)

	resp, err := c.http.Do(req)
	if err != nil {
		c.logger.Error("failed to make request", zap.Error(err))
		return fmt.Errorf("failed to make request: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode == http.StatusOK || resp.StatusCode == http.StatusNoContent {
		return nil
	}

	// The body only enriches the error, so a read failure is deliberately
	// ignored.
	body, _ := io.ReadAll(io.LimitReader(resp.Body, ssoErrorSnippetBytes))
	c.logger.Error("failed to logout", zap.String("status", resp.Status), zap.String("body", string(body)))
	return fmt.Errorf("logout failed: %s - %s", resp.Status, string(body))
}

// endpointURL appends path to the configured SSO server URL, tolerating a
// trailing slash on the base. The base is read on every call, so changes to
// the shared config are picked up.
func (c *SSOClient) endpointURL(path string) string {
	return strings.TrimSuffix(c.config.SSOServerURL, "/") + path
}

// doRequest executes req and returns the response together with its body.
//
// The body has already been fully read and closed when doRequest returns. It
// is capped at ssoMaxResponseBytes; a larger body is an error rather than
// being silently truncated.
func (c *SSOClient) doRequest(req *http.Request) (*http.Response, []byte, error) {
	resp, err := c.http.Do(req)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to make request: %w", err)
	}
	defer resp.Body.Close()

	// Read one byte past the cap to detect oversized bodies.
	body, err := io.ReadAll(io.LimitReader(resp.Body, ssoMaxResponseBytes+1))
	if err != nil {
		return nil, nil, fmt.Errorf("failed to read response body: %w", err)
	}
	if len(body) > ssoMaxResponseBytes {
		return nil, nil, fmt.Errorf("response body exceeds %d bytes", ssoMaxResponseBytes)
	}
	return resp, body, nil
}

// ssoResponseError builds the error for a non-200 OAuth response.
//
// When body decodes as a models.ErrorResponse, its Error and Message fields
// are reported. Otherwise (e.g. an HTML gateway page) the HTTP status and a
// bounded snippet of the raw body are reported, so the failure is not masked
// by a JSON parse error.
func ssoResponseError(op, status string, body []byte) error {
	var errorResp models.ErrorResponse
	if err := json.Unmarshal(body, &errorResp); err == nil {
		return fmt.Errorf("%s failed: %s - %s", op, errorResp.Error, errorResp.Message)
	}

	snippet := strings.TrimSpace(string(body))
	if len(snippet) > ssoErrorSnippetBytes {
		// Drop any rune split by the byte cut.
		snippet = strings.ToValidUTF8(snippet[:ssoErrorSnippetBytes], "") + "..."
	}
	return fmt.Errorf("%s failed: %s - %s", op, status, snippet)
}

// generateCodeVerifier returns a PKCE code verifier.
//
// The verifier is 32 bytes from crypto/rand, encoded as unpadded base64url
// (43 characters), which is within RFC 7636's 43-128 character range.
func generateCodeVerifier() (string, error) {
	buf := make([]byte, 32)
	if _, err := rand.Read(buf); err != nil {
		return "", err
	}
	return base64.RawURLEncoding.EncodeToString(buf), nil
}