Files

477 lines
13 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package middleware
import (
"context"
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"fmt"
"goalfymax-admin/internal/models"
"goalfymax-admin/pkg/utils"
"io"
"math/big"
"net/http"
"net/url"
"strings"
"time"
"github.com/golang-jwt/jwt/v5"
"go.uber.org/zap"
"golang.org/x/oauth2"
)
// SSOClient is the client for the OAuth2 / OpenID Connect SSO server located
// at config.SSOServerURL.
//
// Fields:
//   - config: SSO settings (server URL, client credentials, redirect URI,
//     scope, resource audience, timeout) read by every method.
//   - http:   HTTP client used for the hand-built requests to the token,
//     userinfo, discovery, JWKS and logout endpoints.
//   - oauth2: golang.org/x/oauth2 configuration; GetAuthorizationURL uses it
//     to build the authorization URL.
//   - logger: used for error logging (e.g. by Logout).
type SSOClient struct {
	config *models.SSOConfig
	http   *http.Client
	oauth2 *oauth2.Config
	logger *utils.Logger
}
// NewSSOClient builds an SSOClient from config.
//
// If config.Timeout is zero it is set to 30 seconds. This writes the default
// back into the caller's *models.SSOConfig. The resulting timeout is applied
// to the HTTP client the SSOClient uses for its requests.
//
// The OAuth2 authorization and token endpoints are derived as
// <SSOServerURL>/oauth2/authorize and <SSOServerURL>/oauth2/token.
//
// config.Scope is split on runs of whitespace. An empty scope therefore yields
// no scopes (rather than a single empty scope that oauth2 would send as
// "scope="), and repeated spaces do not create empty entries.
//
// NOTE(review): config must be non-nil; it is dereferenced unconditionally.
func NewSSOClient(config *models.SSOConfig, logger *utils.Logger) *SSOClient {
	if config.Timeout == 0 {
		config.Timeout = 30 * time.Second
	}

	return &SSOClient{
		config: config,
		http:   &http.Client{Timeout: config.Timeout},
		oauth2: &oauth2.Config{
			ClientID:     config.ClientID,
			ClientSecret: config.ClientSecret,
			RedirectURL:  config.RedirectURI,
			Scopes:       strings.Fields(config.Scope),
			Endpoint: oauth2.Endpoint{
				AuthURL:  fmt.Sprintf("%s/oauth2/authorize", config.SSOServerURL),
				TokenURL: fmt.Sprintf("%s/oauth2/token", config.SSOServerURL),
			},
		},
		logger: logger,
	}
}
// GetServerUrl returns the configured SSO server base URL
// (config.SSOServerURL) exactly as stored, without any normalization.
func (c *SSOClient) GetServerUrl() string {
	serverURL := c.config.SSOServerURL
	return serverURL
}
// GetRedirectUrl returns the OAuth2 redirect URI from the client
// configuration (config.RedirectURI) exactly as stored.
func (c *SSOClient) GetRedirectUrl() string {
	redirectURI := c.config.RedirectURI
	return redirectURI
}
// GetAuthorizationURL builds the authorization-endpoint URL for state.
//
// The URL uses PKCE with the S256 challenge method and requests offline
// access (access_type=offline). It returns the URL and the freshly generated
// code verifier. The caller must keep the verifier for the token exchange
// (ExchangeCodeForToken sends it as code_verifier). An error is returned only
// if the verifier cannot be generated.
func (c *SSOClient) GetAuthorizationURL(state string) (string, string, error) {
	verifier, err := generateCodeVerifier()
	if err != nil {
		return "", "", fmt.Errorf("failed to generate code verifier: %w", err)
	}

	options := []oauth2.AuthCodeOption{
		oauth2.AccessTypeOffline,
		oauth2.S256ChallengeOption(verifier),
	}
	return c.oauth2.AuthCodeURL(state, options...), verifier, nil
}
// ExchangeCodeForToken redeems an authorization code at
// <SSOServerURL>/oauth2/token using the authorization_code grant.
//
// The form-encoded request carries the client credentials, the redirect URI
// and the PKCE codeVerifier produced by GetAuthorizationURL. On 200 OK the
// body is decoded into models.TokenResponse.
//
// Any other status yields an error that always includes the HTTP status. If
// the error body is not JSON (e.g. an HTML page from a proxy), the raw body
// is reported instead of a JSON syntax error that would hide the real failure.
func (c *SSOClient) ExchangeCodeForToken(ctx context.Context, code, codeVerifier string) (*models.TokenResponse, error) {
	data := url.Values{}
	data.Set("grant_type", "authorization_code")
	data.Set("client_id", c.config.ClientID)
	data.Set("client_secret", c.config.ClientSecret)
	data.Set("code", code)
	data.Set("redirect_uri", c.config.RedirectURI)
	data.Set("code_verifier", codeVerifier)

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("%s/oauth2/token", c.config.SSOServerURL), strings.NewReader(data.Encode()))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

	resp, err := c.http.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to make request: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response body: %w", err)
	}

	if resp.StatusCode != http.StatusOK {
		var errorResp models.ErrorResponse
		if err := json.Unmarshal(body, &errorResp); err != nil {
			// Non-JSON error body: keep the status and raw body so the real
			// cause is not replaced by a JSON syntax error.
			return nil, fmt.Errorf("token exchange failed: %s - %s", resp.Status, string(body))
		}
		return nil, fmt.Errorf("token exchange failed: %s - %s (status %s)", errorResp.Error, errorResp.Message, resp.Status)
	}

	var tokenResp models.TokenResponse
	if err := json.Unmarshal(body, &tokenResp); err != nil {
		return nil, fmt.Errorf("failed to parse token response: %w", err)
	}
	return &tokenResp, nil
}
// RefreshToken obtains a new token set from <SSOServerURL>/oauth2/token
// using the refresh_token grant and the client credentials.
//
// On 200 OK the body is decoded into models.TokenResponse. Any other status
// yields an error that always includes the HTTP status. If the error body is
// not JSON, the raw body is reported instead of a JSON syntax error that would
// hide the real failure.
func (c *SSOClient) RefreshToken(ctx context.Context, refreshToken string) (*models.TokenResponse, error) {
	data := url.Values{}
	data.Set("grant_type", "refresh_token")
	data.Set("client_id", c.config.ClientID)
	data.Set("client_secret", c.config.ClientSecret)
	data.Set("refresh_token", refreshToken)

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("%s/oauth2/token", c.config.SSOServerURL), strings.NewReader(data.Encode()))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

	resp, err := c.http.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to make request: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response body: %w", err)
	}

	if resp.StatusCode != http.StatusOK {
		var errorResp models.ErrorResponse
		if err := json.Unmarshal(body, &errorResp); err != nil {
			// Non-JSON error body: keep the status and raw body so the real
			// cause is not replaced by a JSON syntax error.
			return nil, fmt.Errorf("token refresh failed: %s - %s", resp.Status, string(body))
		}
		return nil, fmt.Errorf("token refresh failed: %s - %s (status %s)", errorResp.Error, errorResp.Message, resp.Status)
	}

	var tokenResp models.TokenResponse
	if err := json.Unmarshal(body, &tokenResp); err != nil {
		return nil, fmt.Errorf("failed to parse token response: %w", err)
	}
	return &tokenResp, nil
}
// GetUserInfo fetches the profile of the user owning accessToken from
// <SSOServerURL>/oauth2/userinfo, sending the token as a Bearer credential.
//
// The response is first decoded into models.SSOUserInfo, whose Roles field
// may hold any JSON value. It is then converted to the application's
// models.UserInfo. Roles are normalized from a single string, a []string, a
// []interface{} of strings, or any other value that re-encodes to a JSON
// string array. Empty role names and non-string array elements are dropped.
//
// Any status other than 200 OK is returned as an error that always includes
// the HTTP status. If the error body is not JSON (e.g. a plain-text 401), the
// raw body is reported instead of a JSON syntax error.
func (c *SSOClient) GetUserInfo(ctx context.Context, accessToken string) (*models.UserInfo, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("%s/oauth2/userinfo", c.config.SSOServerURL), nil)
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+accessToken)

	resp, err := c.http.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to make request: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response body: %w", err)
	}

	if resp.StatusCode != http.StatusOK {
		var errorResp models.ErrorResponse
		if err := json.Unmarshal(body, &errorResp); err != nil {
			// Non-JSON error body: keep the status and raw body so the real
			// cause is not replaced by a JSON syntax error.
			return nil, fmt.Errorf("get user info failed: %s - %s", resp.Status, string(body))
		}
		return nil, fmt.Errorf("get user info failed: %s - %s (status %s)", errorResp.Error, errorResp.Message, resp.Status)
	}

	// Decode into SSOUserInfo first; its Roles field accepts any JSON shape.
	var ssoUserInfo models.SSOUserInfo
	if err := json.Unmarshal(body, &ssoUserInfo); err != nil {
		return nil, fmt.Errorf("failed to parse user info response: %w", err)
	}

	// Convert to the UserInfo type used inside the application.
	userInfo := &models.UserInfo{
		Sub:               ssoUserInfo.Sub,
		Name:              ssoUserInfo.Name,
		Email:             ssoUserInfo.Email,
		PreferredUsername: ssoUserInfo.PreferredUsername,
	}

	// Normalize roles: string, []string, []interface{} of strings, or any
	// value that round-trips through JSON as a string array.
	if ssoUserInfo.Roles != nil {
		switch v := ssoUserInfo.Roles.(type) {
		case string:
			if v != "" {
				userInfo.Roles = append(userInfo.Roles, models.Role{Name: v})
			}
		case []string:
			for _, name := range v {
				if name != "" {
					userInfo.Roles = append(userInfo.Roles, models.Role{Name: name})
				}
			}
		case []interface{}:
			for _, r := range v {
				if name, ok := r.(string); ok && name != "" {
					userInfo.Roles = append(userInfo.Roles, models.Role{Name: name})
				}
			}
		default:
			// Fallback: re-encode and try to decode as a JSON string array.
			// Values that are not string arrays are ignored.
			if b, err := json.Marshal(v); err == nil {
				var arr []string
				if err := json.Unmarshal(b, &arr); err == nil {
					for _, name := range arr {
						if name != "" {
							userInfo.Roles = append(userInfo.Roles, models.Role{Name: name})
						}
					}
				}
			}
		}
	}
	return userInfo, nil
}
// ValidateTokenWithSSO checks accessToken by asking the SSO server itself.
// The token is accepted exactly when GetUserInfo succeeds with it; the
// returned user info is discarded. A failure is wrapped with
// "token validation failed".
func (c *SSOClient) ValidateTokenWithSSO(ctx context.Context, accessToken string) error {
	if _, userInfoErr := c.GetUserInfo(ctx, accessToken); userInfoErr != nil {
		return fmt.Errorf("token validation failed: %w", userInfoErr)
	}
	return nil
}
// ValidateToken verifies accessToken locally as an RSA-signed JWT.
//
// Steps:
//   - download the provider's key set via GetJWKS;
//   - select the verification key by the token's "kid" header;
//   - check the signature with jwt.Parse;
//   - apply validateTokenClaims.
//
// Only RSA signing methods are accepted. It is kept for compatibility; the
// original author marked it as not recommended, with ValidateTokenWithSSO as
// the preferred check.
func (c *SSOClient) ValidateToken(ctx context.Context, accessToken string) (*jwt.Token, error) {
	keySet, err := c.GetJWKS(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to get JWKS: %w", err)
	}

	// Resolve the verification key for a parsed-but-unverified token.
	lookupKey := func(unverified *jwt.Token) (interface{}, error) {
		if _, isRSA := unverified.Method.(*jwt.SigningMethodRSA); !isRSA {
			return nil, fmt.Errorf("unexpected signing method: %v", unverified.Header["alg"])
		}
		keyID, hasKeyID := unverified.Header["kid"].(string)
		if !hasKeyID {
			return nil, fmt.Errorf("kid not found in token header")
		}
		verificationKey, lookupErr := c.findPublicKey(keySet, keyID)
		if lookupErr != nil {
			return nil, fmt.Errorf("failed to find public key: %w", lookupErr)
		}
		return verificationKey, nil
	}

	parsed, err := jwt.Parse(accessToken, lookupKey)
	switch {
	case err != nil:
		return nil, fmt.Errorf("failed to parse token: %w", err)
	case !parsed.Valid:
		return nil, fmt.Errorf("invalid token")
	}

	if claimsErr := c.validateTokenClaims(parsed); claimsErr != nil {
		return nil, fmt.Errorf("invalid token claims: %w", claimsErr)
	}
	return parsed, nil
}
// findPublicKey scans jwks.Keys for the first entry whose "kid" equals kid
// and that carries string "n" and "e" members. It returns the RSA public key
// parseRSAPublicKey builds from them. Entries with a matching kid but missing
// n/e are skipped. If nothing matches, an error naming kid is returned.
func (c *SSOClient) findPublicKey(jwks *models.JWKS, kid string) (interface{}, error) {
	for _, jwk := range jwks.Keys {
		if id, isString := jwk["kid"].(string); !isString || id != kid {
			continue
		}
		modulus, hasModulus := jwk["n"].(string)
		exponent, hasExponent := jwk["e"].(string)
		if hasModulus && hasExponent {
			return c.parseRSAPublicKey(modulus, exponent)
		}
	}
	return nil, fmt.Errorf("public key not found for kid: %s", kid)
}
// parseRSAPublicKey builds an *rsa.PublicKey from the unpadded base64url
// encoded modulus ("n") and public exponent ("e") members of a JWK.
//
// It returns an error when:
//   - either value fails to decode;
//   - the modulus is empty or zero;
//   - the exponent falls outside 2..2^31-1, the range crypto/rsa accepts for
//     public keys.
//
// Before this check, an oversized exponent was silently truncated by the
// int64 conversion, producing a wrong key instead of an error.
func (c *SSOClient) parseRSAPublicKey(n, e string) (interface{}, error) {
	nBytes, err := base64.RawURLEncoding.DecodeString(n)
	if err != nil {
		return nil, fmt.Errorf("failed to decode modulus: %w", err)
	}
	eBytes, err := base64.RawURLEncoding.DecodeString(e)
	if err != nil {
		return nil, fmt.Errorf("failed to decode exponent: %w", err)
	}

	modulus := new(big.Int).SetBytes(nBytes)
	if modulus.Sign() == 0 {
		return nil, fmt.Errorf("invalid modulus: empty or zero")
	}

	exponent := new(big.Int).SetBytes(eBytes)
	// Reject rather than truncate: int(exponent.Int64()) would wrap silently.
	if !exponent.IsInt64() || exponent.Int64() < 2 || exponent.Int64() > 1<<31-1 {
		return nil, fmt.Errorf("invalid exponent: %s", exponent.String())
	}

	return &rsa.PublicKey{
		N: modulus,
		E: int(exponent.Int64()),
	}, nil
}
// validateTokenClaims applies application-level checks to the claims of a
// token already parsed by jwt.Parse. Each check runs only when its claim is
// present:
//   - iss must be a string equal to config.SSOServerURL;
//   - aud must equal config.ResourceAud when it is a string, or contain it
//     when it is an array. RFC 7519 allows both forms; previously an array
//     aud bypassed the check entirely. An aud of any other type is rejected;
//   - exp (numeric) must not be in the past;
//   - nbf (numeric) must not be in the future.
func (c *SSOClient) validateTokenClaims(token *jwt.Token) error {
	claims, ok := token.Claims.(jwt.MapClaims)
	if !ok {
		return fmt.Errorf("invalid token claims")
	}

	// Issuer: a present but non-string iss is malformed and rejected.
	if rawIss, present := claims["iss"]; present {
		iss, isString := rawIss.(string)
		if !isString || iss != c.config.SSOServerURL {
			return fmt.Errorf("invalid issuer: %v", rawIss)
		}
	}

	// Audience: string or array of strings.
	switch aud := claims["aud"].(type) {
	case nil:
		// Absent (or JSON null): tolerated, as before.
	case string:
		if aud != c.config.ResourceAud {
			return fmt.Errorf("invalid audience: %s", aud)
		}
	case []string:
		found := false
		for _, a := range aud {
			if a == c.config.ResourceAud {
				found = true
				break
			}
		}
		if !found {
			return fmt.Errorf("invalid audience: %v", aud)
		}
	case []interface{}:
		found := false
		for _, a := range aud {
			if s, isString := a.(string); isString && s == c.config.ResourceAud {
				found = true
				break
			}
		}
		if !found {
			return fmt.Errorf("invalid audience: %v", aud)
		}
	default:
		return fmt.Errorf("invalid audience type: %T", aud)
	}

	// Expiration time.
	if exp, ok := claims["exp"].(float64); ok {
		if time.Unix(int64(exp), 0).Before(time.Now()) {
			return fmt.Errorf("token expired")
		}
	}

	// Not-before time.
	if nbf, ok := claims["nbf"].(float64); ok {
		if time.Unix(int64(nbf), 0).After(time.Now()) {
			return fmt.Errorf("token not yet valid")
		}
	}
	return nil
}
// GetOpenIDConfiguration downloads and decodes the provider's discovery
// document from <SSOServerURL>/.well-known/openid-configuration.
//
// The body is read before the status is checked. Any status other than
// 200 OK is reported as an error carrying resp.Status.
func (c *SSOClient) GetOpenIDConfiguration(ctx context.Context) (*models.OpenIDConfiguration, error) {
	discoveryURL := c.config.SSOServerURL + "/.well-known/openid-configuration"
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, discoveryURL, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}

	resp, err := c.http.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to make request: %w", err)
	}
	defer resp.Body.Close()

	payload, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response body: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("failed to get OpenID configuration: %s", resp.Status)
	}

	discovery := new(models.OpenIDConfiguration)
	if decodeErr := json.Unmarshal(payload, discovery); decodeErr != nil {
		return nil, fmt.Errorf("failed to parse OpenID configuration: %w", decodeErr)
	}
	return discovery, nil
}
// GetJWKS downloads and decodes the provider's JSON Web Key Set from the
// fixed path <SSOServerURL>/oauth2/jwks.json.
//
// The body is read before the status is checked. Any status other than
// 200 OK is reported as an error carrying resp.Status.
func (c *SSOClient) GetJWKS(ctx context.Context) (*models.JWKS, error) {
	jwksURL := c.config.SSOServerURL + "/oauth2/jwks.json"
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, jwksURL, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}

	resp, err := c.http.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to make request: %w", err)
	}
	defer resp.Body.Close()

	payload, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response body: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("failed to get JWKS: %s", resp.Status)
	}

	keySet := new(models.JWKS)
	if decodeErr := json.Unmarshal(payload, keySet); decodeErr != nil {
		return nil, fmt.Errorf("failed to parse JWKS: %w", decodeErr)
	}
	return keySet, nil
}
// Logout asks the SSO server to end the session behind accessToken. It sends
// POST <SSOServerURL>/oauth2/logout-pre with the token as a Bearer credential.
//
// 200 OK and 204 No Content count as success. A transport failure or any other
// status is logged through c.logger and also returned as an error; for a bad
// status the error includes the response body.
func (c *SSOClient) Logout(ctx context.Context, accessToken string) error {
	logoutURL := c.config.SSOServerURL + "/oauth2/logout-pre"
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, logoutURL, nil)
	if err != nil {
		return fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+accessToken)

	resp, err := c.http.Do(req)
	if err != nil {
		c.logger.Error("failed to make request", zap.Error(err))
		return fmt.Errorf("failed to make request: %w", err)
	}
	defer resp.Body.Close()

	switch resp.StatusCode {
	case http.StatusOK, http.StatusNoContent:
		return nil
	}

	// Best-effort read: the body only enriches the diagnostic, so a read
	// error is deliberately ignored.
	detail, _ := io.ReadAll(resp.Body)
	c.logger.Error("failed to logout", zap.String("status", resp.Status), zap.String("body", string(detail)))
	return fmt.Errorf("logout failed: %s - %s", resp.Status, string(detail))
}
// generateCodeVerifier returns a fresh PKCE code verifier. It reads 32 bytes
// from crypto/rand and encodes them as unpadded base64url, giving a
// 43-character string. The error from crypto/rand is returned unwrapped.
func generateCodeVerifier() (string, error) {
	var entropy [32]byte
	if _, err := rand.Read(entropy[:]); err != nil {
		return "", err
	}
	verifier := base64.RawURLEncoding.EncodeToString(entropy[:])
	return verifier, nil
}