package middleware import ( "context" "fmt" "goalfymax-admin/internal/models" "goalfymax-admin/internal/storage" "net/http" "strconv" "strings" "time" "github.com/gin-gonic/gin" "gorm.io/gorm" ) // SessionManager 会话管理器接口 type SessionManager interface { GetSession(ctx context.Context, sessionID string) (*models.Session, error) SetSession(ctx context.Context, sessionID string, session *models.Session) error DeleteSession(ctx context.Context, sessionID string) error } // MemorySessionManager 内存会话管理器 type MemorySessionManager struct { sessions map[string]*models.Session } // NewMemorySessionManager 创建内存会话管理器 func NewMemorySessionManager() *MemorySessionManager { return &MemorySessionManager{ sessions: make(map[string]*models.Session), } } func (m *MemorySessionManager) GetSession(ctx context.Context, sessionID string) (*models.Session, error) { session, exists := m.sessions[sessionID] if !exists { return nil, fmt.Errorf("session not found") } if time.Now().After(session.ExpiresAt) { delete(m.sessions, sessionID) return nil, fmt.Errorf("session expired") } return session, nil } func (m *MemorySessionManager) SetSession(ctx context.Context, sessionID string, session *models.Session) error { m.sessions[sessionID] = session return nil } func (m *MemorySessionManager) DeleteSession(ctx context.Context, sessionID string) error { delete(m.sessions, sessionID) return nil } // AuthMiddleware 认证中间件 type AuthMiddleware struct { client *SSOClient sessionManager SessionManager loginURL string validationMode string // "sso" 或 "jwt" } // NewAuthMiddleware 创建认证中间件 func NewAuthMiddleware(client *SSOClient, sessionManager SessionManager, loginURL string) *AuthMiddleware { return &AuthMiddleware{ client: client, sessionManager: sessionManager, loginURL: loginURL, validationMode: "sso", // 默认使用SSO验证模式 } } // SetValidationMode 设置验证模式 func (m *AuthMiddleware) SetValidationMode(mode string) { m.validationMode = mode } // RequireAuth 要求认证的中间件 func (m *AuthMiddleware) RequireAuth() gin.HandlerFunc { return func(c *gin.Context) { // 从Authorization头获取访问令牌 var token string = "" authHeader := c.GetHeader("Authorization") if authHeader == "" { authHeader = c.Query("token") if authHeader == "" { c.JSON(http.StatusUnauthorized, gin.H{ "error": "unauthorized", "message": "Authorization header is required", }) c.Abort() return } token = strings.Trim(authHeader, " ") } else { // 提取Bearer令牌 if len(authHeader) > 7 && authHeader[:7] == "Bearer " { token = authHeader[7:] } } if token == "" { c.JSON(http.StatusUnauthorized, gin.H{ "error": "unauthorized", "message": "Invalid authorization header", }) c.Abort() return } // 直接调用第三方SSO服务验证token并获取用户信息 // 不再进行本地JWT验证,而是通过调用第三方API来验证token的有效性 userInfo, err := m.client.GetUserInfo(c.Request.Context(), token) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{ "error": "unauthorized", "message": fmt.Sprintf("Invalid token or failed to get user info: %v", err), }) c.Abort() return } // 解析SSO用户ID为数值 userID, err := strconv.ParseUint(userInfo.Sub, 10, 32) if err != nil { c.JSON(http.StatusBadRequest, gin.H{ "error": "invalid_user_id", "message": "Invalid user ID in token", }) c.Abort() return } userIDUint := uint(userID) // 查找或创建本地用户 user, err := m.findOrCreateUser(userIDUint, userInfo) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{ "error": "user_creation_failed", "message": "Failed to create or find user", }) c.Abort() return } // 将用户信息添加到上下文 c.Set("user", userInfo) c.Set("user_id", userIDUint) // 使用转换后的用户ID c.Set("local_user", user) // 本地用户对象 c.Set("token", token) c.Next() } } // OptionalAuth 可选认证的中间件 func (m *AuthMiddleware) OptionalAuth() gin.HandlerFunc { return func(c *gin.Context) { // 从Authorization头获取访问令牌 authHeader := c.GetHeader("Authorization") if authHeader == "" { c.Next() return } // 提取Bearer令牌 token := "" if len(authHeader) > 7 && authHeader[:7] == "Bearer " { token = authHeader[7:] } if token == "" { c.Next() return } // 直接调用第三方SSO服务验证token并获取用户信息 userInfo, err := m.client.GetUserInfo(c.Request.Context(), token) if err != nil { c.Next() return } // 将用户信息添加到上下文 c.Set("user", userInfo) c.Set("token", token) c.Next() } } // GetUserFromContext 从上下文中获取用户信息 func GetUserFromContext(c *gin.Context) (*models.UserInfo, bool) { user, exists := c.Get("user") if !exists { return nil, false } userInfo, ok := user.(*models.UserInfo) return userInfo, ok } // GetUserIDFromContext 从上下文中获取用户ID func GetUserIDFromContext(c *gin.Context) (int, bool) { userID, exists := c.Get("user_id") if !exists { return 0, false } userIDInt, ok := userID.(int) return userIDInt, ok } // GetTokenFromContext 从上下文中获取令牌 func GetTokenFromContext(c *gin.Context) (string, bool) { token, exists := c.Get("token") if !exists { return "", false } tokenStr, ok := token.(string) return tokenStr, ok } // findOrCreateUser 查找或创建用户 func (m *AuthMiddleware) findOrCreateUser(userID uint, userInfo *models.UserInfo) (*models.User, error) { // 尝试查找现有用户 var user models.User err := storage.DB.Where("id = ?", userID).First(&user).Error if err == nil { // 用户存在,更新登录信息 now := time.Now() user.LastLoginAt = &now user.LoginCount++ // 更新用户信息(如果SSO信息有变化) if userInfo.Name != "" && user.Nickname != userInfo.Name { user.Nickname = userInfo.Name } if userInfo.Email != "" && user.Email != userInfo.Email { user.Email = userInfo.Email } err = storage.DB.Save(&user).Error if err != nil { return nil, fmt.Errorf("更新用户信息失败: %w", err) } return &user, nil } if err != gorm.ErrRecordNotFound { return nil, fmt.Errorf("查询用户失败: %w", err) } // 用户不存在,创建新用户 now := time.Now() user = models.User{ BaseModel: models.BaseModel{ ID: userID, CreatedAt: now, UpdatedAt: now, }, Username: userInfo.PreferredUsername, Email: userInfo.Email, Nickname: userInfo.Name, Status: 1, // 默认启用 SSOProvider: "default", // 可以根据实际情况设置 LastLoginAt: &now, LoginCount: 1, } // 如果PreferredUsername为空,使用Email作为用户名 if user.Username == "" { user.Username = userInfo.Email } err = storage.DB.Create(&user).Error if err != nil { return nil, fmt.Errorf("创建用户失败: %w", err) } return &user, nil }