feat():learning后台管理项目初始化

This commit is contained in:
yuj
2025-12-04 16:23:46 +08:00
parent 39886d50d2
commit 88e048f4d1
154 changed files with 28966 additions and 6 deletions

290
pkg/middleware/auth.go Normal file
View File

@@ -0,0 +1,290 @@
package middleware
import (
"context"
"fmt"
"goalfymax-admin/internal/models"
"goalfymax-admin/internal/storage"
"net/http"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
// SessionManager 会话管理器接口
type SessionManager interface {
GetSession(ctx context.Context, sessionID string) (*models.Session, error)
SetSession(ctx context.Context, sessionID string, session *models.Session) error
DeleteSession(ctx context.Context, sessionID string) error
}
// MemorySessionManager 内存会话管理器
type MemorySessionManager struct {
sessions map[string]*models.Session
}
// NewMemorySessionManager 创建内存会话管理器
func NewMemorySessionManager() *MemorySessionManager {
return &MemorySessionManager{
sessions: make(map[string]*models.Session),
}
}
func (m *MemorySessionManager) GetSession(ctx context.Context, sessionID string) (*models.Session, error) {
session, exists := m.sessions[sessionID]
if !exists {
return nil, fmt.Errorf("session not found")
}
if time.Now().After(session.ExpiresAt) {
delete(m.sessions, sessionID)
return nil, fmt.Errorf("session expired")
}
return session, nil
}
func (m *MemorySessionManager) SetSession(ctx context.Context, sessionID string, session *models.Session) error {
m.sessions[sessionID] = session
return nil
}
func (m *MemorySessionManager) DeleteSession(ctx context.Context, sessionID string) error {
delete(m.sessions, sessionID)
return nil
}
// AuthMiddleware 认证中间件
type AuthMiddleware struct {
client *SSOClient
sessionManager SessionManager
loginURL string
validationMode string // "sso" 或 "jwt"
}
// NewAuthMiddleware 创建认证中间件
func NewAuthMiddleware(client *SSOClient, sessionManager SessionManager, loginURL string) *AuthMiddleware {
return &AuthMiddleware{
client: client,
sessionManager: sessionManager,
loginURL: loginURL,
validationMode: "sso", // 默认使用SSO验证模式
}
}
// SetValidationMode 设置验证模式
func (m *AuthMiddleware) SetValidationMode(mode string) {
m.validationMode = mode
}
// RequireAuth 要求认证的中间件
func (m *AuthMiddleware) RequireAuth() gin.HandlerFunc {
return func(c *gin.Context) {
// 从Authorization头获取访问令牌
var token string = ""
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
authHeader = c.Query("token")
if authHeader == "" {
c.JSON(http.StatusUnauthorized, gin.H{
"error": "unauthorized",
"message": "Authorization header is required",
})
c.Abort()
return
}
token = strings.Trim(authHeader, " ")
} else {
// 提取Bearer令牌
if len(authHeader) > 7 && authHeader[:7] == "Bearer " {
token = authHeader[7:]
}
}
if token == "" {
c.JSON(http.StatusUnauthorized, gin.H{
"error": "unauthorized",
"message": "Invalid authorization header",
})
c.Abort()
return
}
// 直接调用第三方SSO服务验证token并获取用户信息
// 不再进行本地JWT验证而是通过调用第三方API来验证token的有效性
userInfo, err := m.client.GetUserInfo(c.Request.Context(), token)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{
"error": "unauthorized",
"message": fmt.Sprintf("Invalid token or failed to get user info: %v", err),
})
c.Abort()
return
}
// 解析SSO用户ID为数值
userID, err := strconv.ParseUint(userInfo.Sub, 10, 32)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": "invalid_user_id",
"message": "Invalid user ID in token",
})
c.Abort()
return
}
userIDUint := uint(userID)
// 查找或创建本地用户
user, err := m.findOrCreateUser(userIDUint, userInfo)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "user_creation_failed",
"message": "Failed to create or find user",
})
c.Abort()
return
}
// 将用户信息添加到上下文
c.Set("user", userInfo)
c.Set("user_id", userIDUint) // 使用转换后的用户ID
c.Set("local_user", user) // 本地用户对象
c.Set("token", token)
c.Next()
}
}
// OptionalAuth 可选认证的中间件
func (m *AuthMiddleware) OptionalAuth() gin.HandlerFunc {
return func(c *gin.Context) {
// 从Authorization头获取访问令牌
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
c.Next()
return
}
// 提取Bearer令牌
token := ""
if len(authHeader) > 7 && authHeader[:7] == "Bearer " {
token = authHeader[7:]
}
if token == "" {
c.Next()
return
}
// 直接调用第三方SSO服务验证token并获取用户信息
userInfo, err := m.client.GetUserInfo(c.Request.Context(), token)
if err != nil {
c.Next()
return
}
// 将用户信息添加到上下文
c.Set("user", userInfo)
c.Set("token", token)
c.Next()
}
}
// GetUserFromContext 从上下文中获取用户信息
func GetUserFromContext(c *gin.Context) (*models.UserInfo, bool) {
user, exists := c.Get("user")
if !exists {
return nil, false
}
userInfo, ok := user.(*models.UserInfo)
return userInfo, ok
}
// GetUserIDFromContext 从上下文中获取用户ID
func GetUserIDFromContext(c *gin.Context) (int, bool) {
userID, exists := c.Get("user_id")
if !exists {
return 0, false
}
userIDInt, ok := userID.(int)
return userIDInt, ok
}
// GetTokenFromContext 从上下文中获取令牌
func GetTokenFromContext(c *gin.Context) (string, bool) {
token, exists := c.Get("token")
if !exists {
return "", false
}
tokenStr, ok := token.(string)
return tokenStr, ok
}
// findOrCreateUser 查找或创建用户
func (m *AuthMiddleware) findOrCreateUser(userID uint, userInfo *models.UserInfo) (*models.User, error) {
// 尝试查找现有用户
var user models.User
err := storage.DB.Where("id = ?", userID).First(&user).Error
if err == nil {
// 用户存在,更新登录信息
now := time.Now()
user.LastLoginAt = &now
user.LoginCount++
// 更新用户信息如果SSO信息有变化
if userInfo.Name != "" && user.Nickname != userInfo.Name {
user.Nickname = userInfo.Name
}
if userInfo.Email != "" && user.Email != userInfo.Email {
user.Email = userInfo.Email
}
err = storage.DB.Save(&user).Error
if err != nil {
return nil, fmt.Errorf("更新用户信息失败: %w", err)
}
return &user, nil
}
if err != gorm.ErrRecordNotFound {
return nil, fmt.Errorf("查询用户失败: %w", err)
}
// 用户不存在,创建新用户
now := time.Now()
user = models.User{
BaseModel: models.BaseModel{
ID: userID,
CreatedAt: now,
UpdatedAt: now,
},
Username: userInfo.PreferredUsername,
Email: userInfo.Email,
Nickname: userInfo.Name,
Status: 1, // 默认启用
SSOProvider: "default", // 可以根据实际情况设置
LastLoginAt: &now,
LoginCount: 1,
}
// 如果PreferredUsername为空使用Email作为用户名
if user.Username == "" {
user.Username = userInfo.Email
}
err = storage.DB.Create(&user).Error
if err != nil {
return nil, fmt.Errorf("创建用户失败: %w", err)
}
return &user, nil
}

193
pkg/middleware/rbac.go Normal file
View File

@@ -0,0 +1,193 @@
package middleware
import (
"goalfymax-admin/internal/models"
"goalfymax-admin/internal/storage"
"goalfymax-admin/pkg/utils"
"net/http"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"gorm.io/gorm"
)
// RBACMiddleware 简化的RBAC权限中间件
type RBACMiddleware struct {
rbacStorage storage.RBACStorage
db *gorm.DB
logger *utils.Logger
}
// NewRBACMiddleware 创建RBAC中间件
func NewRBACMiddleware(rbacStorage storage.RBACStorage, db *gorm.DB, logger *utils.Logger) *RBACMiddleware {
return &RBACMiddleware{
rbacStorage: rbacStorage,
db: db,
logger: logger,
}
}
// RequirePagePermission 检查页面权限
func (m *RBACMiddleware) RequirePagePermission(pagePath string) gin.HandlerFunc {
return func(c *gin.Context) {
// 从上下文获取当前用户ID
userID, exists := c.Get("user_id")
if !exists {
m.logger.Error("未找到用户ID", zap.String("pagePath", pagePath))
c.JSON(http.StatusUnauthorized, gin.H{"error": "未认证"})
c.Abort()
return
}
// 类型转换
userIDUint, ok := userID.(uint)
if !ok {
m.logger.Error("无效的用户ID类型", zap.Any("userID", userID))
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的用户ID"})
c.Abort()
return
}
// 检查用户是否有页面访问权限(基于角色)
hasPermission, err := m.rbacStorage.CheckUserRolePagePermission(userIDUint, pagePath)
if err != nil {
m.logger.Error("页面权限检查失败", zap.Uint("userID", userIDUint), zap.String("pagePath", pagePath), zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "权限检查失败"})
c.Abort()
return
}
if !hasPermission {
m.logger.Warn("页面权限不足", zap.Uint("userID", userIDUint), zap.String("pagePath", pagePath))
c.JSON(http.StatusForbidden, gin.H{"error": "页面权限不足"})
c.Abort()
return
}
m.logger.Info("页面权限检查通过", zap.Uint("userID", userIDUint), zap.String("pagePath", pagePath))
c.Next()
}
}
// RequireRole 检查角色
func (m *RBACMiddleware) RequireRole(roleName string) gin.HandlerFunc {
return func(c *gin.Context) {
// 从上下文获取当前用户ID
userID, exists := c.Get("user_id")
if !exists {
m.logger.Error("未找到用户ID", zap.String("role", roleName))
c.JSON(http.StatusUnauthorized, gin.H{"error": "未认证"})
c.Abort()
return
}
// 类型转换userID 是 uint
userIDUint, ok := userID.(uint)
if !ok {
m.logger.Error("无效的用户ID类型", zap.Any("userID", userID))
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的用户ID"})
c.Abort()
return
}
// 直接查询用户角色
var user models.User
err := m.db.Where("id = ?", userIDUint).First(&user).Error
if err != nil {
m.logger.Error("获取用户信息失败", zap.Uint("userID", userIDUint), zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取用户信息失败"})
c.Abort()
return
}
// 查询角色信息
var role models.Role
err = m.db.Where("id = ?", user.RoleID).First(&role).Error
if err != nil {
m.logger.Error("获取角色信息失败", zap.Uint("roleID", user.RoleID), zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取角色信息失败"})
c.Abort()
return
}
// 检查用户是否有指定角色
hasRole := role.Name == roleName
if !hasRole {
m.logger.Warn("角色不足", zap.Uint("userID", userIDUint), zap.String("role", roleName))
c.JSON(http.StatusForbidden, gin.H{"error": "角色不足"})
c.Abort()
return
}
m.logger.Info("角色检查通过", zap.Uint("userID", userIDUint), zap.String("role", roleName))
c.Next()
}
}
// RequireAnyRole 检查任意角色
func (m *RBACMiddleware) RequireAnyRole(roleNames []string) gin.HandlerFunc {
return func(c *gin.Context) {
// 从上下文获取当前用户ID
userID, exists := c.Get("user_id")
if !exists {
m.logger.Error("未找到用户ID", zap.Strings("roles", roleNames))
c.JSON(http.StatusUnauthorized, gin.H{"error": "未认证"})
c.Abort()
return
}
// 类型转换userID 是 uint
userIDUint, ok := userID.(uint)
if !ok {
m.logger.Error("无效的用户ID类型", zap.Any("userID", userID))
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的用户ID"})
c.Abort()
return
}
// 直接查询用户角色
var user models.User
err := m.db.Where("id = ?", userIDUint).First(&user).Error
if err != nil {
m.logger.Error("获取用户信息失败", zap.Uint("userID", userIDUint), zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取用户信息失败"})
c.Abort()
return
}
// 查询角色信息
var role models.Role
err = m.db.Where("id = ?", user.RoleID).First(&role).Error
if err != nil {
m.logger.Error("获取角色信息失败", zap.Uint("roleID", user.RoleID), zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取角色信息失败"})
c.Abort()
return
}
// 检查用户是否有任意一个角色
hasAnyRole := false
for _, roleName := range roleNames {
if role.Name == roleName {
hasAnyRole = true
break
}
}
if !hasAnyRole {
m.logger.Warn("角色不足", zap.Uint("userID", userIDUint), zap.Strings("roles", roleNames))
c.JSON(http.StatusForbidden, gin.H{"error": "角色不足"})
c.Abort()
return
}
m.logger.Info("角色检查通过", zap.Uint("userID", userIDUint), zap.Strings("roles", roleNames))
c.Next()
}
}
// GetUserAccessiblePages 获取用户可访问的页面列表
func (m *RBACMiddleware) GetUserAccessiblePages(userID uint) ([]string, error) {
return m.rbacStorage.GetUserRoleAccessiblePages(userID)
}

View File

@@ -0,0 +1,476 @@
package middleware
import (
"context"
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"fmt"
"goalfymax-admin/internal/models"
"goalfymax-admin/pkg/utils"
"io"
"math/big"
"net/http"
"net/url"
"strings"
"time"
"github.com/golang-jwt/jwt/v5"
"go.uber.org/zap"
"golang.org/x/oauth2"
)
// SSOClient SSO客户端
type SSOClient struct {
config *models.SSOConfig
http *http.Client
oauth2 *oauth2.Config
logger *utils.Logger
}
// NewSSOClient 创建新的SSO客户端
func NewSSOClient(config *models.SSOConfig, logger *utils.Logger) *SSOClient {
if config.Timeout == 0 {
config.Timeout = 30 * time.Second
}
httpClient := &http.Client{
Timeout: config.Timeout,
}
oauth2Config := &oauth2.Config{
ClientID: config.ClientID,
ClientSecret: config.ClientSecret,
RedirectURL: config.RedirectURI,
Scopes: strings.Split(config.Scope, " "),
Endpoint: oauth2.Endpoint{
AuthURL: fmt.Sprintf("%s/oauth2/authorize", config.SSOServerURL),
TokenURL: fmt.Sprintf("%s/oauth2/token", config.SSOServerURL),
},
}
return &SSOClient{
config: config,
http: httpClient,
oauth2: oauth2Config,
logger: logger,
}
}
func (c *SSOClient) GetServerUrl() string {
return c.config.SSOServerURL
}
func (c *SSOClient) GetRedirectUrl() string {
return c.config.RedirectURI
}
// GetAuthorizationURL 获取授权URL
func (c *SSOClient) GetAuthorizationURL(state string) (string, string, error) {
// 生成PKCE挑战
codeVerifier, err := generateCodeVerifier()
if err != nil {
return "", "", fmt.Errorf("failed to generate code verifier: %w", err)
}
// 构建授权URL
authURL := c.oauth2.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(codeVerifier))
return authURL, codeVerifier, nil
}
// ExchangeCodeForToken 使用授权码交换令牌
func (c *SSOClient) ExchangeCodeForToken(ctx context.Context, code, codeVerifier string) (*models.TokenResponse, error) {
data := url.Values{}
data.Set("grant_type", "authorization_code")
data.Set("client_id", c.config.ClientID)
data.Set("client_secret", c.config.ClientSecret)
data.Set("code", code)
data.Set("redirect_uri", c.config.RedirectURI)
data.Set("code_verifier", codeVerifier)
req, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("%s/oauth2/token", c.config.SSOServerURL), strings.NewReader(data.Encode()))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := c.http.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to make request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
if resp.StatusCode != http.StatusOK {
var errorResp models.ErrorResponse
if err := json.Unmarshal(body, &errorResp); err != nil {
return nil, fmt.Errorf("failed to parse error response: %w", err)
}
return nil, fmt.Errorf("token exchange failed: %s - %s", errorResp.Error, errorResp.Message)
}
var tokenResp models.TokenResponse
if err := json.Unmarshal(body, &tokenResp); err != nil {
return nil, fmt.Errorf("failed to parse token response: %w", err)
}
return &tokenResp, nil
}
// RefreshToken 刷新访问令牌
func (c *SSOClient) RefreshToken(ctx context.Context, refreshToken string) (*models.TokenResponse, error) {
data := url.Values{}
data.Set("grant_type", "refresh_token")
data.Set("client_id", c.config.ClientID)
data.Set("client_secret", c.config.ClientSecret)
data.Set("refresh_token", refreshToken)
req, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("%s/oauth2/token", c.config.SSOServerURL), strings.NewReader(data.Encode()))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := c.http.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to make request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
if resp.StatusCode != http.StatusOK {
var errorResp models.ErrorResponse
if err := json.Unmarshal(body, &errorResp); err != nil {
return nil, fmt.Errorf("failed to parse error response: %w", err)
}
return nil, fmt.Errorf("token refresh failed: %s - %s", errorResp.Error, errorResp.Message)
}
var tokenResp models.TokenResponse
if err := json.Unmarshal(body, &tokenResp); err != nil {
return nil, fmt.Errorf("failed to parse token response: %w", err)
}
return &tokenResp, nil
}
// GetUserInfo 获取用户信息
func (c *SSOClient) GetUserInfo(ctx context.Context, accessToken string) (*models.UserInfo, error) {
req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("%s/oauth2/userinfo", c.config.SSOServerURL), nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
resp, err := c.http.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to make request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
if resp.StatusCode != http.StatusOK {
var errorResp models.ErrorResponse
if err := json.Unmarshal(body, &errorResp); err != nil {
return nil, fmt.Errorf("failed to parse error response: %w", err)
}
return nil, fmt.Errorf("get user info failed: %s - %s", errorResp.Error, errorResp.Message)
}
// 首先解析为 SSOUserInfo它可以接受任何类型的 roles 字段
var ssoUserInfo models.SSOUserInfo
if err := json.Unmarshal(body, &ssoUserInfo); err != nil {
return nil, fmt.Errorf("failed to parse user info response: %w", err)
}
// 转换为应用程序内部使用的 UserInfo
userInfo := &models.UserInfo{
Sub: ssoUserInfo.Sub,
Name: ssoUserInfo.Name,
Email: ssoUserInfo.Email,
PreferredUsername: ssoUserInfo.PreferredUsername,
}
// 规范化 roles 字段:支持 string、[]string、[]interface{}
if ssoUserInfo.Roles != nil {
switch v := ssoUserInfo.Roles.(type) {
case string:
if v != "" {
userInfo.Roles = append(userInfo.Roles, models.Role{Name: v})
}
case []string:
for _, name := range v {
if name != "" {
userInfo.Roles = append(userInfo.Roles, models.Role{Name: name})
}
}
case []interface{}:
for _, r := range v {
if name, ok := r.(string); ok && name != "" {
userInfo.Roles = append(userInfo.Roles, models.Role{Name: name})
}
}
default:
// 尝试通用 JSON 数组字符串
if b, err := json.Marshal(v); err == nil {
var arr []string
if err := json.Unmarshal(b, &arr); err == nil {
for _, name := range arr {
if name != "" {
userInfo.Roles = append(userInfo.Roles, models.Role{Name: name})
}
}
}
}
}
}
return userInfo, nil
}
// ValidateTokenWithSSO 通过调用第三方SSO服务验证token
func (c *SSOClient) ValidateTokenWithSSO(ctx context.Context, accessToken string) error {
// 尝试调用用户信息接口来验证token
_, err := c.GetUserInfo(ctx, accessToken)
if err != nil {
return fmt.Errorf("token validation failed: %w", err)
}
return nil
}
// ValidateToken 验证令牌(保留原有方法,但标记为不推荐使用)
func (c *SSOClient) ValidateToken(ctx context.Context, accessToken string) (*jwt.Token, error) {
// 获取JWKS
jwks, err := c.GetJWKS(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get JWKS: %w", err)
}
// 解析JWT令牌
token, err := jwt.Parse(accessToken, func(token *jwt.Token) (interface{}, error) {
// 验证签名算法
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
// 从JWKS中获取公钥
kid, ok := token.Header["kid"].(string)
if !ok {
return nil, fmt.Errorf("kid not found in token header")
}
// 在JWKS中查找对应的公钥
publicKey, err := c.findPublicKey(jwks, kid)
if err != nil {
return nil, fmt.Errorf("failed to find public key: %w", err)
}
return publicKey, nil
})
if err != nil {
return nil, fmt.Errorf("failed to parse token: %w", err)
}
// 验证令牌是否有效
if !token.Valid {
return nil, fmt.Errorf("invalid token")
}
// 验证令牌的声明
if err := c.validateTokenClaims(token); err != nil {
return nil, fmt.Errorf("invalid token claims: %w", err)
}
return token, nil
}
// findPublicKey 在JWKS中查找对应的公钥
func (c *SSOClient) findPublicKey(jwks *models.JWKS, kid string) (interface{}, error) {
for _, key := range jwks.Keys {
if keyKid, ok := key["kid"].(string); ok && keyKid == kid {
// 解析RSA公钥
if n, ok := key["n"].(string); ok {
if e, ok := key["e"].(string); ok {
return c.parseRSAPublicKey(n, e)
}
}
}
}
return nil, fmt.Errorf("public key not found for kid: %s", kid)
}
// parseRSAPublicKey 解析RSA公钥
func (c *SSOClient) parseRSAPublicKey(n, e string) (interface{}, error) {
// 解码Base64URL编码的模数和指数
nBytes, err := base64.RawURLEncoding.DecodeString(n)
if err != nil {
return nil, fmt.Errorf("failed to decode modulus: %w", err)
}
eBytes, err := base64.RawURLEncoding.DecodeString(e)
if err != nil {
return nil, fmt.Errorf("failed to decode exponent: %w", err)
}
// 创建RSA公钥
modulus := new(big.Int).SetBytes(nBytes)
exponent := new(big.Int).SetBytes(eBytes)
publicKey := &rsa.PublicKey{
N: modulus,
E: int(exponent.Int64()),
}
return publicKey, nil
}
// validateTokenClaims 验证令牌声明
func (c *SSOClient) validateTokenClaims(token *jwt.Token) error {
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
return fmt.Errorf("invalid token claims")
}
// 验证发行者
if iss, ok := claims["iss"].(string); ok {
if iss != c.config.SSOServerURL {
return fmt.Errorf("invalid issuer: %s", iss)
}
}
// 验证受众
if aud, ok := claims["aud"].(string); ok {
if aud != c.config.ResourceAud {
return fmt.Errorf("invalid audience: %s", aud)
}
}
// 验证过期时间
if exp, ok := claims["exp"].(float64); ok {
if time.Unix(int64(exp), 0).Before(time.Now()) {
return fmt.Errorf("token expired")
}
}
// 验证生效时间
if nbf, ok := claims["nbf"].(float64); ok {
if time.Unix(int64(nbf), 0).After(time.Now()) {
return fmt.Errorf("token not yet valid")
}
}
return nil
}
// GetOpenIDConfiguration 获取OpenID配置
func (c *SSOClient) GetOpenIDConfiguration(ctx context.Context) (*models.OpenIDConfiguration, error) {
req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("%s/.well-known/openid-configuration", c.config.SSOServerURL), nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
resp, err := c.http.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to make request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("failed to get OpenID configuration: %s", resp.Status)
}
var config models.OpenIDConfiguration
if err := json.Unmarshal(body, &config); err != nil {
return nil, fmt.Errorf("failed to parse OpenID configuration: %w", err)
}
return &config, nil
}
// GetJWKS 获取JWKS
func (c *SSOClient) GetJWKS(ctx context.Context) (*models.JWKS, error) {
req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("%s/oauth2/jwks.json", c.config.SSOServerURL), nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
resp, err := c.http.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to make request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("failed to get JWKS: %s", resp.Status)
}
var jwks models.JWKS
if err := json.Unmarshal(body, &jwks); err != nil {
return nil, fmt.Errorf("failed to parse JWKS: %w", err)
}
return &jwks, nil
}
// Logout 登出
func (c *SSOClient) Logout(ctx context.Context, accessToken string) error {
req, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("%s/oauth2/logout-pre", c.config.SSOServerURL), nil)
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
resp, err := c.http.Do(req)
if err != nil {
c.logger.Error("failed to make request", zap.Error(err))
return fmt.Errorf("failed to make request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
c.logger.Error("failed to logout", zap.String("status", resp.Status), zap.String("body", string(body)))
return fmt.Errorf("logout failed: %s - %s", resp.Status, string(body))
}
return nil
}
// 生成PKCE代码验证器
func generateCodeVerifier() (string, error) {
bytes := make([]byte, 32)
if _, err := rand.Read(bytes); err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(bytes), nil
}