Files

291 lines
7.0 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package middleware
import (
	"context"
	"errors"
	"fmt"
	"net/http"
	"strconv"
	"strings"
	"sync"
	"time"

	"goalfymax-admin/internal/models"
	"goalfymax-admin/internal/storage"

	"github.com/gin-gonic/gin"
	"gorm.io/gorm"
)
// SessionManager abstracts storage of sessions keyed by a session ID.
type SessionManager interface {
	// GetSession returns the session stored under id, or an error when no
	// usable session can be returned.
	GetSession(ctx context.Context, id string) (*models.Session, error)
	// SetSession stores s under id.
	SetSession(ctx context.Context, id string, s *models.Session) error
	// DeleteSession removes whatever is stored under id.
	DeleteSession(ctx context.Context, id string) error
}
// MemorySessionManager is a SessionManager that keeps sessions in a map in
// process memory. Sessions are lost when the process exits and are not
// shared between process instances.
//
// A mutex guards the map. gin serves requests on concurrent goroutines, and
// unsynchronized concurrent map access is a data race the Go runtime may
// abort on. The zero value is ready to use.
type MemorySessionManager struct {
	mu       sync.Mutex
	sessions map[string]*models.Session
}

// NewMemorySessionManager returns an empty MemorySessionManager.
func NewMemorySessionManager() *MemorySessionManager {
	return &MemorySessionManager{
		sessions: make(map[string]*models.Session),
	}
}

// GetSession returns the session stored under sessionID.
//
// It returns an error when nothing (or a nil session) is stored under that ID.
// It also returns an error when the session's ExpiresAt is in the past; the
// expired entry is removed so later lookups report it as not found.
// ctx is not used.
func (m *MemorySessionManager) GetSession(ctx context.Context, sessionID string) (*models.Session, error) {
	// Exclusive lock rather than a read lock: the expiry path deletes.
	m.mu.Lock()
	defer m.mu.Unlock()

	session, exists := m.sessions[sessionID]
	// Treat a stored nil like a missing entry; otherwise the ExpiresAt
	// access below would panic.
	if !exists || session == nil {
		return nil, fmt.Errorf("session not found")
	}
	if time.Now().After(session.ExpiresAt) {
		delete(m.sessions, sessionID)
		return nil, fmt.Errorf("session expired")
	}
	return session, nil
}

// SetSession stores session under sessionID, replacing any existing entry.
// ctx is not used and the returned error is always nil.
func (m *MemorySessionManager) SetSession(ctx context.Context, sessionID string, session *models.Session) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	// Lazily create the map so a zero-value MemorySessionManager works.
	if m.sessions == nil {
		m.sessions = make(map[string]*models.Session)
	}
	m.sessions[sessionID] = session
	return nil
}

// DeleteSession removes the session stored under sessionID, if any.
// ctx is not used and the returned error is always nil.
func (m *MemorySessionManager) DeleteSession(ctx context.Context, sessionID string) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	delete(m.sessions, sessionID)
	return nil
}
// AuthMiddleware builds gin middleware that authenticates requests through an
// SSO client.
type AuthMiddleware struct {
	client         *SSOClient
	sessionManager SessionManager
	loginURL       string
	// validationMode is "sso" or "jwt".
	// NOTE(review): RequireAuth and OptionalAuth never read this field; they
	// always validate through the SSO client. Confirm whether "jwt" is wired
	// up elsewhere.
	validationMode string
}

// NewAuthMiddleware returns an AuthMiddleware that uses client for token
// validation, keeps sessionManager and loginURL, and starts in "sso" mode.
func NewAuthMiddleware(client *SSOClient, sessionManager SessionManager, loginURL string) *AuthMiddleware {
	m := &AuthMiddleware{validationMode: "sso"} // SSO validation is the default
	m.client = client
	m.sessionManager = sessionManager
	m.loginURL = loginURL
	return m
}

// SetValidationMode overwrites the validation mode. The value is stored as
// given; it is not checked against "sso"/"jwt".
func (m *AuthMiddleware) SetValidationMode(mode string) { m.validationMode = mode }
// RequireAuth returns a gin middleware that rejects unauthenticated requests.
//
// The access token comes from the "Authorization: Bearer <token>" header. The
// scheme is matched case-insensitively, as RFC 7235 requires, and surrounding
// whitespace is trimmed from the token. If the header is absent, the "token"
// query parameter is used instead.
//
// The token is validated by fetching the user's profile from the SSO service.
// The SSO subject must be a decimal number that fits in 32 bits; it becomes
// the local user ID, and the local user is then looked up or created.
//
// On success it sets these context keys and calls c.Next():
//   - "user":       the *models.UserInfo returned by the SSO client
//   - "user_id":    the local user ID, stored as a uint
//   - "local_user": the local *models.User
//   - "token":      the access token
//
// On failure it aborts the chain with a JSON body and one of these statuses:
//   - 401 when the token is missing or invalid
//   - 400 when the subject is not numeric
//   - 500 when the local user lookup or creation fails
func (m *AuthMiddleware) RequireAuth() gin.HandlerFunc {
	const bearerScheme = "Bearer "

	return func(c *gin.Context) {
		reject := func(status int, code, message string) {
			c.JSON(status, gin.H{
				"error":   code,
				"message": message,
			})
			c.Abort()
		}

		var token string
		authHeader := c.GetHeader("Authorization")
		if authHeader != "" {
			// Only Bearer credentials are accepted. Any other scheme leaves
			// token empty and is rejected below.
			if len(authHeader) > len(bearerScheme) &&
				strings.EqualFold(authHeader[:len(bearerScheme)], bearerScheme) {
				token = strings.TrimSpace(authHeader[len(bearerScheme):])
			}
		} else {
			queryToken := c.Query("token")
			if queryToken == "" {
				reject(http.StatusUnauthorized, "unauthorized", "Authorization header is required")
				return
			}
			token = strings.TrimSpace(queryToken)
		}

		if token == "" {
			reject(http.StatusUnauthorized, "unauthorized", "Invalid authorization header")
			return
		}

		// Validation is delegated to the SSO service; no local JWT check is done.
		userInfo, err := m.client.GetUserInfo(c.Request.Context(), token)
		if err != nil {
			// NOTE(review): the upstream error text is echoed to the client;
			// consider logging it server-side and returning a generic message.
			reject(http.StatusUnauthorized, "unauthorized",
				fmt.Sprintf("Invalid token or failed to get user info: %v", err))
			return
		}

		// The SSO subject doubles as the local primary key, so it must be numeric.
		userID, err := strconv.ParseUint(userInfo.Sub, 10, 32)
		if err != nil {
			reject(http.StatusBadRequest, "invalid_user_id", "Invalid user ID in token")
			return
		}
		userIDUint := uint(userID)

		user, err := m.findOrCreateUser(userIDUint, userInfo)
		if err != nil {
			// Attach the cause so gin's error-collecting middleware can log it;
			// the client only gets the generic message.
			_ = c.Error(err)
			reject(http.StatusInternalServerError, "user_creation_failed", "Failed to create or find user")
			return
		}

		c.Set("user", userInfo)
		c.Set("user_id", userIDUint)
		c.Set("local_user", user)
		c.Set("token", token)
		c.Next()
	}
}
// OptionalAuth returns a gin middleware that authenticates the request when
// it can, but never aborts.
//
// If the request carries an "Authorization: Bearer <token>" header (scheme
// matched case-insensitively, token whitespace-trimmed) and the SSO client
// returns a profile for that token, the middleware sets "user" and "token" in
// the context. In every case it then calls c.Next().
//
// NOTE(review): unlike RequireAuth, this sets neither "user_id" nor
// "local_user", and it ignores the "token" query parameter. Confirm that
// handlers on optional routes only rely on "user"/"token".
func (m *AuthMiddleware) OptionalAuth() gin.HandlerFunc {
	const bearerScheme = "Bearer "

	return func(c *gin.Context) {
		authHeader := c.GetHeader("Authorization")
		if len(authHeader) <= len(bearerScheme) ||
			!strings.EqualFold(authHeader[:len(bearerScheme)], bearerScheme) {
			// No header, or not a Bearer credential: continue anonymously.
			c.Next()
			return
		}

		token := strings.TrimSpace(authHeader[len(bearerScheme):])
		if token == "" {
			c.Next()
			return
		}

		// A token the SSO service rejects is treated as anonymous, not as an error.
		if userInfo, err := m.client.GetUserInfo(c.Request.Context(), token); err == nil {
			c.Set("user", userInfo)
			c.Set("token", token)
		}
		c.Next()
	}
}
// GetUserFromContext returns the SSO profile stored under the "user" key.
// ok is false when the key is absent or holds something other than a
// *models.UserInfo.
func GetUserFromContext(c *gin.Context) (*models.UserInfo, bool) {
	v, found := c.Get("user")
	if found {
		info, ok := v.(*models.UserInfo)
		return info, ok
	}
	return nil, false
}
// GetUserIDFromContext returns the authenticated user's ID stored under the
// "user_id" key.
//
// RequireAuth in this package stores the ID as a uint, so both uint and int
// values are accepted. ok is false in three cases:
//   - the key is absent;
//   - it holds any other type;
//   - the uint value does not fit in an int.
func GetUserIDFromContext(c *gin.Context) (int, bool) {
	raw, exists := c.Get("user_id")
	if !exists {
		return 0, false
	}

	const maxInt = int(^uint(0) >> 1)
	switch id := raw.(type) {
	case int:
		return id, true
	case uint:
		// Guard the conversion; it could overflow where int is 32 bits wide.
		if id > uint(maxInt) {
			return 0, false
		}
		return int(id), true
	default:
		return 0, false
	}
}
// GetTokenFromContext returns the access token stored under the "token" key.
// ok is false when the key is absent or the stored value is not a string.
func GetTokenFromContext(c *gin.Context) (string, bool) {
	if v, found := c.Get("token"); found {
		s, ok := v.(string)
		return s, ok
	}
	return "", false
}
// findOrCreateUser loads the local user whose ID equals the SSO subject ID, or
// creates that user on first sight.
//
// For an existing user it:
//   - sets LastLoginAt to now and increments LoginCount;
//   - copies Name/Email from the SSO profile when they are non-empty and differ;
//   - persists the row with Save.
//
// For a new user it inserts a row with Status 1 and SSOProvider "default".
// Username is PreferredUsername, falling back to Email when that is empty.
//
// NOTE(review): RequireAuth calls this on every authenticated request, so
// LoginCount/LastLoginAt track requests rather than logins. Confirm intent.
func (m *AuthMiddleware) findOrCreateUser(userID uint, userInfo *models.UserInfo) (*models.User, error) {
	var user models.User
	err := storage.DB.Where("id = ?", userID).First(&user).Error
	if err == nil {
		// Existing user: record this login and sync profile fields the SSO
		// provider reports as non-empty.
		now := time.Now()
		user.LastLoginAt = &now
		user.LoginCount++
		if userInfo.Name != "" && user.Nickname != userInfo.Name {
			user.Nickname = userInfo.Name
		}
		if userInfo.Email != "" && user.Email != userInfo.Email {
			user.Email = userInfo.Email
		}
		if err := storage.DB.Save(&user).Error; err != nil {
			return nil, fmt.Errorf("更新用户信息失败: %w", err)
		}
		return &user, nil
	}
	// GORM documents checking for a missing record with errors.Is, which also
	// matches when the sentinel is wrapped.
	if !errors.Is(err, gorm.ErrRecordNotFound) {
		return nil, fmt.Errorf("查询用户失败: %w", err)
	}

	// User does not exist yet: create it.
	now := time.Now()
	user = models.User{
		BaseModel: models.BaseModel{
			ID:        userID,
			CreatedAt: now,
			UpdatedAt: now,
		},
		Username:    userInfo.PreferredUsername,
		Email:       userInfo.Email,
		Nickname:    userInfo.Name,
		Status:      1,         // enabled by default
		SSOProvider: "default", // may need to reflect the actual provider
		LastLoginAt: &now,
		LoginCount:  1,
	}
	if user.Username == "" {
		user.Username = userInfo.Email
	}

	if err := storage.DB.Create(&user).Error; err != nil {
		// Concurrent first requests for the same SSO user can both reach this
		// INSERT. The loser presumably fails on the duplicate ID. If the row
		// exists now, return it instead of failing the request.
		var existing models.User
		if storage.DB.Where("id = ?", userID).First(&existing).Error == nil {
			return &existing, nil
		}
		return nil, fmt.Errorf("创建用户失败: %w", err)
	}
	return &user, nil
}