feat():learning后台管理项目初始化
This commit is contained in:
476
pkg/middleware/sso_client.go
Normal file
476
pkg/middleware/sso_client.go
Normal file
@@ -0,0 +1,476 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"goalfymax-admin/internal/models"
|
||||
"goalfymax-admin/pkg/utils"
|
||||
"io"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// SSOClient SSO客户端
|
||||
type SSOClient struct {
|
||||
config *models.SSOConfig
|
||||
http *http.Client
|
||||
oauth2 *oauth2.Config
|
||||
logger *utils.Logger
|
||||
}
|
||||
|
||||
// NewSSOClient 创建新的SSO客户端
|
||||
func NewSSOClient(config *models.SSOConfig, logger *utils.Logger) *SSOClient {
|
||||
if config.Timeout == 0 {
|
||||
config.Timeout = 30 * time.Second
|
||||
}
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: config.Timeout,
|
||||
}
|
||||
|
||||
oauth2Config := &oauth2.Config{
|
||||
ClientID: config.ClientID,
|
||||
ClientSecret: config.ClientSecret,
|
||||
RedirectURL: config.RedirectURI,
|
||||
Scopes: strings.Split(config.Scope, " "),
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: fmt.Sprintf("%s/oauth2/authorize", config.SSOServerURL),
|
||||
TokenURL: fmt.Sprintf("%s/oauth2/token", config.SSOServerURL),
|
||||
},
|
||||
}
|
||||
|
||||
return &SSOClient{
|
||||
config: config,
|
||||
http: httpClient,
|
||||
oauth2: oauth2Config,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *SSOClient) GetServerUrl() string {
|
||||
return c.config.SSOServerURL
|
||||
}
|
||||
|
||||
func (c *SSOClient) GetRedirectUrl() string {
|
||||
return c.config.RedirectURI
|
||||
}
|
||||
|
||||
// GetAuthorizationURL 获取授权URL
|
||||
func (c *SSOClient) GetAuthorizationURL(state string) (string, string, error) {
|
||||
// 生成PKCE挑战
|
||||
codeVerifier, err := generateCodeVerifier()
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to generate code verifier: %w", err)
|
||||
}
|
||||
|
||||
// 构建授权URL
|
||||
authURL := c.oauth2.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(codeVerifier))
|
||||
|
||||
return authURL, codeVerifier, nil
|
||||
}
|
||||
|
||||
// ExchangeCodeForToken 使用授权码交换令牌
|
||||
func (c *SSOClient) ExchangeCodeForToken(ctx context.Context, code, codeVerifier string) (*models.TokenResponse, error) {
|
||||
data := url.Values{}
|
||||
data.Set("grant_type", "authorization_code")
|
||||
data.Set("client_id", c.config.ClientID)
|
||||
data.Set("client_secret", c.config.ClientSecret)
|
||||
data.Set("code", code)
|
||||
data.Set("redirect_uri", c.config.RedirectURI)
|
||||
data.Set("code_verifier", codeVerifier)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("%s/oauth2/token", c.config.SSOServerURL), strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
resp, err := c.http.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to make request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
var errorResp models.ErrorResponse
|
||||
if err := json.Unmarshal(body, &errorResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse error response: %w", err)
|
||||
}
|
||||
return nil, fmt.Errorf("token exchange failed: %s - %s", errorResp.Error, errorResp.Message)
|
||||
}
|
||||
|
||||
var tokenResp models.TokenResponse
|
||||
if err := json.Unmarshal(body, &tokenResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token response: %w", err)
|
||||
}
|
||||
|
||||
return &tokenResp, nil
|
||||
}
|
||||
|
||||
// RefreshToken 刷新访问令牌
|
||||
func (c *SSOClient) RefreshToken(ctx context.Context, refreshToken string) (*models.TokenResponse, error) {
|
||||
data := url.Values{}
|
||||
data.Set("grant_type", "refresh_token")
|
||||
data.Set("client_id", c.config.ClientID)
|
||||
data.Set("client_secret", c.config.ClientSecret)
|
||||
data.Set("refresh_token", refreshToken)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("%s/oauth2/token", c.config.SSOServerURL), strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
resp, err := c.http.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to make request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
var errorResp models.ErrorResponse
|
||||
if err := json.Unmarshal(body, &errorResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse error response: %w", err)
|
||||
}
|
||||
return nil, fmt.Errorf("token refresh failed: %s - %s", errorResp.Error, errorResp.Message)
|
||||
}
|
||||
|
||||
var tokenResp models.TokenResponse
|
||||
if err := json.Unmarshal(body, &tokenResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token response: %w", err)
|
||||
}
|
||||
|
||||
return &tokenResp, nil
|
||||
}
|
||||
|
||||
// GetUserInfo 获取用户信息
|
||||
func (c *SSOClient) GetUserInfo(ctx context.Context, accessToken string) (*models.UserInfo, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("%s/oauth2/userinfo", c.config.SSOServerURL), nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
|
||||
resp, err := c.http.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to make request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
var errorResp models.ErrorResponse
|
||||
if err := json.Unmarshal(body, &errorResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse error response: %w", err)
|
||||
}
|
||||
return nil, fmt.Errorf("get user info failed: %s - %s", errorResp.Error, errorResp.Message)
|
||||
}
|
||||
|
||||
// 首先解析为 SSOUserInfo,它可以接受任何类型的 roles 字段
|
||||
var ssoUserInfo models.SSOUserInfo
|
||||
if err := json.Unmarshal(body, &ssoUserInfo); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse user info response: %w", err)
|
||||
}
|
||||
|
||||
// 转换为应用程序内部使用的 UserInfo
|
||||
userInfo := &models.UserInfo{
|
||||
Sub: ssoUserInfo.Sub,
|
||||
Name: ssoUserInfo.Name,
|
||||
Email: ssoUserInfo.Email,
|
||||
PreferredUsername: ssoUserInfo.PreferredUsername,
|
||||
}
|
||||
|
||||
// 规范化 roles 字段:支持 string、[]string、[]interface{}
|
||||
if ssoUserInfo.Roles != nil {
|
||||
switch v := ssoUserInfo.Roles.(type) {
|
||||
case string:
|
||||
if v != "" {
|
||||
userInfo.Roles = append(userInfo.Roles, models.Role{Name: v})
|
||||
}
|
||||
case []string:
|
||||
for _, name := range v {
|
||||
if name != "" {
|
||||
userInfo.Roles = append(userInfo.Roles, models.Role{Name: name})
|
||||
}
|
||||
}
|
||||
case []interface{}:
|
||||
for _, r := range v {
|
||||
if name, ok := r.(string); ok && name != "" {
|
||||
userInfo.Roles = append(userInfo.Roles, models.Role{Name: name})
|
||||
}
|
||||
}
|
||||
default:
|
||||
// 尝试通用 JSON 数组字符串
|
||||
if b, err := json.Marshal(v); err == nil {
|
||||
var arr []string
|
||||
if err := json.Unmarshal(b, &arr); err == nil {
|
||||
for _, name := range arr {
|
||||
if name != "" {
|
||||
userInfo.Roles = append(userInfo.Roles, models.Role{Name: name})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return userInfo, nil
|
||||
}
|
||||
|
||||
// ValidateTokenWithSSO 通过调用第三方SSO服务验证token
|
||||
func (c *SSOClient) ValidateTokenWithSSO(ctx context.Context, accessToken string) error {
|
||||
// 尝试调用用户信息接口来验证token
|
||||
_, err := c.GetUserInfo(ctx, accessToken)
|
||||
if err != nil {
|
||||
return fmt.Errorf("token validation failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateToken 验证令牌(保留原有方法,但标记为不推荐使用)
|
||||
func (c *SSOClient) ValidateToken(ctx context.Context, accessToken string) (*jwt.Token, error) {
|
||||
// 获取JWKS
|
||||
jwks, err := c.GetJWKS(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get JWKS: %w", err)
|
||||
}
|
||||
|
||||
// 解析JWT令牌
|
||||
token, err := jwt.Parse(accessToken, func(token *jwt.Token) (interface{}, error) {
|
||||
// 验证签名算法
|
||||
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
|
||||
// 从JWKS中获取公钥
|
||||
kid, ok := token.Header["kid"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("kid not found in token header")
|
||||
}
|
||||
|
||||
// 在JWKS中查找对应的公钥
|
||||
publicKey, err := c.findPublicKey(jwks, kid)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to find public key: %w", err)
|
||||
}
|
||||
|
||||
return publicKey, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token: %w", err)
|
||||
}
|
||||
|
||||
// 验证令牌是否有效
|
||||
if !token.Valid {
|
||||
return nil, fmt.Errorf("invalid token")
|
||||
}
|
||||
|
||||
// 验证令牌的声明
|
||||
if err := c.validateTokenClaims(token); err != nil {
|
||||
return nil, fmt.Errorf("invalid token claims: %w", err)
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// findPublicKey 在JWKS中查找对应的公钥
|
||||
func (c *SSOClient) findPublicKey(jwks *models.JWKS, kid string) (interface{}, error) {
|
||||
for _, key := range jwks.Keys {
|
||||
if keyKid, ok := key["kid"].(string); ok && keyKid == kid {
|
||||
// 解析RSA公钥
|
||||
if n, ok := key["n"].(string); ok {
|
||||
if e, ok := key["e"].(string); ok {
|
||||
return c.parseRSAPublicKey(n, e)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("public key not found for kid: %s", kid)
|
||||
}
|
||||
|
||||
// parseRSAPublicKey 解析RSA公钥
|
||||
func (c *SSOClient) parseRSAPublicKey(n, e string) (interface{}, error) {
|
||||
// 解码Base64URL编码的模数和指数
|
||||
nBytes, err := base64.RawURLEncoding.DecodeString(n)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode modulus: %w", err)
|
||||
}
|
||||
|
||||
eBytes, err := base64.RawURLEncoding.DecodeString(e)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode exponent: %w", err)
|
||||
}
|
||||
|
||||
// 创建RSA公钥
|
||||
modulus := new(big.Int).SetBytes(nBytes)
|
||||
exponent := new(big.Int).SetBytes(eBytes)
|
||||
|
||||
publicKey := &rsa.PublicKey{
|
||||
N: modulus,
|
||||
E: int(exponent.Int64()),
|
||||
}
|
||||
|
||||
return publicKey, nil
|
||||
}
|
||||
|
||||
// validateTokenClaims 验证令牌声明
|
||||
func (c *SSOClient) validateTokenClaims(token *jwt.Token) error {
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid token claims")
|
||||
}
|
||||
|
||||
// 验证发行者
|
||||
if iss, ok := claims["iss"].(string); ok {
|
||||
if iss != c.config.SSOServerURL {
|
||||
return fmt.Errorf("invalid issuer: %s", iss)
|
||||
}
|
||||
}
|
||||
|
||||
// 验证受众
|
||||
if aud, ok := claims["aud"].(string); ok {
|
||||
if aud != c.config.ResourceAud {
|
||||
return fmt.Errorf("invalid audience: %s", aud)
|
||||
}
|
||||
}
|
||||
|
||||
// 验证过期时间
|
||||
if exp, ok := claims["exp"].(float64); ok {
|
||||
if time.Unix(int64(exp), 0).Before(time.Now()) {
|
||||
return fmt.Errorf("token expired")
|
||||
}
|
||||
}
|
||||
|
||||
// 验证生效时间
|
||||
if nbf, ok := claims["nbf"].(float64); ok {
|
||||
if time.Unix(int64(nbf), 0).After(time.Now()) {
|
||||
return fmt.Errorf("token not yet valid")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetOpenIDConfiguration 获取OpenID配置
|
||||
func (c *SSOClient) GetOpenIDConfiguration(ctx context.Context) (*models.OpenIDConfiguration, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("%s/.well-known/openid-configuration", c.config.SSOServerURL), nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := c.http.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to make request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("failed to get OpenID configuration: %s", resp.Status)
|
||||
}
|
||||
|
||||
var config models.OpenIDConfiguration
|
||||
if err := json.Unmarshal(body, &config); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse OpenID configuration: %w", err)
|
||||
}
|
||||
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
// GetJWKS 获取JWKS
|
||||
func (c *SSOClient) GetJWKS(ctx context.Context) (*models.JWKS, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("%s/oauth2/jwks.json", c.config.SSOServerURL), nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := c.http.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to make request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("failed to get JWKS: %s", resp.Status)
|
||||
}
|
||||
|
||||
var jwks models.JWKS
|
||||
if err := json.Unmarshal(body, &jwks); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse JWKS: %w", err)
|
||||
}
|
||||
|
||||
return &jwks, nil
|
||||
}
|
||||
|
||||
// Logout 登出
|
||||
func (c *SSOClient) Logout(ctx context.Context, accessToken string) error {
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("%s/oauth2/logout-pre", c.config.SSOServerURL), nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
|
||||
resp, err := c.http.Do(req)
|
||||
if err != nil {
|
||||
c.logger.Error("failed to make request", zap.Error(err))
|
||||
return fmt.Errorf("failed to make request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
c.logger.Error("failed to logout", zap.String("status", resp.Status), zap.String("body", string(body)))
|
||||
return fmt.Errorf("logout failed: %s - %s", resp.Status, string(body))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 生成PKCE代码验证器
|
||||
func generateCodeVerifier() (string, error) {
|
||||
bytes := make([]byte, 32)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(bytes), nil
|
||||
}
|
||||
Reference in New Issue
Block a user