291 lines
7.0 KiB
Go
291 lines
7.0 KiB
Go
package middleware
|
||
|
||
import (
	"context"
	"errors"
	"fmt"
	"net/http"
	"strconv"
	"strings"
	"sync"
	"time"

	"goalfymax-admin/internal/models"
	"goalfymax-admin/internal/storage"

	"github.com/gin-gonic/gin"
	"gorm.io/gorm"
)
|
||
|
||
// SessionManager 会话管理器接口
|
||
type SessionManager interface {
|
||
GetSession(ctx context.Context, sessionID string) (*models.Session, error)
|
||
SetSession(ctx context.Context, sessionID string, session *models.Session) error
|
||
DeleteSession(ctx context.Context, sessionID string) error
|
||
}
|
||
|
||
// MemorySessionManager 内存会话管理器
|
||
type MemorySessionManager struct {
|
||
sessions map[string]*models.Session
|
||
}
|
||
|
||
// NewMemorySessionManager 创建内存会话管理器
|
||
func NewMemorySessionManager() *MemorySessionManager {
|
||
return &MemorySessionManager{
|
||
sessions: make(map[string]*models.Session),
|
||
}
|
||
}
|
||
|
||
func (m *MemorySessionManager) GetSession(ctx context.Context, sessionID string) (*models.Session, error) {
|
||
session, exists := m.sessions[sessionID]
|
||
if !exists {
|
||
return nil, fmt.Errorf("session not found")
|
||
}
|
||
|
||
if time.Now().After(session.ExpiresAt) {
|
||
delete(m.sessions, sessionID)
|
||
return nil, fmt.Errorf("session expired")
|
||
}
|
||
|
||
return session, nil
|
||
}
|
||
|
||
func (m *MemorySessionManager) SetSession(ctx context.Context, sessionID string, session *models.Session) error {
|
||
m.sessions[sessionID] = session
|
||
return nil
|
||
}
|
||
|
||
func (m *MemorySessionManager) DeleteSession(ctx context.Context, sessionID string) error {
|
||
delete(m.sessions, sessionID)
|
||
return nil
|
||
}
|
||
|
||
// AuthMiddleware 认证中间件
|
||
type AuthMiddleware struct {
|
||
client *SSOClient
|
||
sessionManager SessionManager
|
||
loginURL string
|
||
validationMode string // "sso" 或 "jwt"
|
||
}
|
||
|
||
// NewAuthMiddleware 创建认证中间件
|
||
func NewAuthMiddleware(client *SSOClient, sessionManager SessionManager, loginURL string) *AuthMiddleware {
|
||
return &AuthMiddleware{
|
||
client: client,
|
||
sessionManager: sessionManager,
|
||
loginURL: loginURL,
|
||
validationMode: "sso", // 默认使用SSO验证模式
|
||
}
|
||
}
|
||
|
||
// SetValidationMode 设置验证模式
|
||
func (m *AuthMiddleware) SetValidationMode(mode string) {
|
||
m.validationMode = mode
|
||
}
|
||
|
||
// RequireAuth 要求认证的中间件
|
||
func (m *AuthMiddleware) RequireAuth() gin.HandlerFunc {
|
||
return func(c *gin.Context) {
|
||
// 从Authorization头获取访问令牌
|
||
var token string = ""
|
||
authHeader := c.GetHeader("Authorization")
|
||
if authHeader == "" {
|
||
authHeader = c.Query("token")
|
||
if authHeader == "" {
|
||
c.JSON(http.StatusUnauthorized, gin.H{
|
||
"error": "unauthorized",
|
||
"message": "Authorization header is required",
|
||
})
|
||
c.Abort()
|
||
return
|
||
}
|
||
token = strings.Trim(authHeader, " ")
|
||
} else {
|
||
// 提取Bearer令牌
|
||
if len(authHeader) > 7 && authHeader[:7] == "Bearer " {
|
||
token = authHeader[7:]
|
||
}
|
||
}
|
||
|
||
if token == "" {
|
||
c.JSON(http.StatusUnauthorized, gin.H{
|
||
"error": "unauthorized",
|
||
"message": "Invalid authorization header",
|
||
})
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
// 直接调用第三方SSO服务验证token并获取用户信息
|
||
// 不再进行本地JWT验证,而是通过调用第三方API来验证token的有效性
|
||
userInfo, err := m.client.GetUserInfo(c.Request.Context(), token)
|
||
if err != nil {
|
||
c.JSON(http.StatusUnauthorized, gin.H{
|
||
"error": "unauthorized",
|
||
"message": fmt.Sprintf("Invalid token or failed to get user info: %v", err),
|
||
})
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
// 解析SSO用户ID为数值
|
||
userID, err := strconv.ParseUint(userInfo.Sub, 10, 32)
|
||
if err != nil {
|
||
c.JSON(http.StatusBadRequest, gin.H{
|
||
"error": "invalid_user_id",
|
||
"message": "Invalid user ID in token",
|
||
})
|
||
c.Abort()
|
||
return
|
||
}
|
||
userIDUint := uint(userID)
|
||
|
||
// 查找或创建本地用户
|
||
user, err := m.findOrCreateUser(userIDUint, userInfo)
|
||
if err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{
|
||
"error": "user_creation_failed",
|
||
"message": "Failed to create or find user",
|
||
})
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
// 将用户信息添加到上下文
|
||
c.Set("user", userInfo)
|
||
c.Set("user_id", userIDUint) // 使用转换后的用户ID
|
||
c.Set("local_user", user) // 本地用户对象
|
||
c.Set("token", token)
|
||
|
||
c.Next()
|
||
}
|
||
}
|
||
|
||
// OptionalAuth 可选认证的中间件
|
||
func (m *AuthMiddleware) OptionalAuth() gin.HandlerFunc {
|
||
return func(c *gin.Context) {
|
||
// 从Authorization头获取访问令牌
|
||
authHeader := c.GetHeader("Authorization")
|
||
if authHeader == "" {
|
||
c.Next()
|
||
return
|
||
}
|
||
|
||
// 提取Bearer令牌
|
||
token := ""
|
||
if len(authHeader) > 7 && authHeader[:7] == "Bearer " {
|
||
token = authHeader[7:]
|
||
}
|
||
|
||
if token == "" {
|
||
c.Next()
|
||
return
|
||
}
|
||
|
||
// 直接调用第三方SSO服务验证token并获取用户信息
|
||
userInfo, err := m.client.GetUserInfo(c.Request.Context(), token)
|
||
if err != nil {
|
||
c.Next()
|
||
return
|
||
}
|
||
|
||
// 将用户信息添加到上下文
|
||
c.Set("user", userInfo)
|
||
c.Set("token", token)
|
||
|
||
c.Next()
|
||
}
|
||
}
|
||
|
||
// GetUserFromContext 从上下文中获取用户信息
|
||
func GetUserFromContext(c *gin.Context) (*models.UserInfo, bool) {
|
||
user, exists := c.Get("user")
|
||
if !exists {
|
||
return nil, false
|
||
}
|
||
|
||
userInfo, ok := user.(*models.UserInfo)
|
||
return userInfo, ok
|
||
}
|
||
|
||
// GetUserIDFromContext 从上下文中获取用户ID
|
||
func GetUserIDFromContext(c *gin.Context) (int, bool) {
|
||
userID, exists := c.Get("user_id")
|
||
if !exists {
|
||
return 0, false
|
||
}
|
||
|
||
userIDInt, ok := userID.(int)
|
||
return userIDInt, ok
|
||
}
|
||
|
||
// GetTokenFromContext 从上下文中获取令牌
|
||
func GetTokenFromContext(c *gin.Context) (string, bool) {
|
||
token, exists := c.Get("token")
|
||
if !exists {
|
||
return "", false
|
||
}
|
||
|
||
tokenStr, ok := token.(string)
|
||
return tokenStr, ok
|
||
}
|
||
|
||
// findOrCreateUser 查找或创建用户
|
||
func (m *AuthMiddleware) findOrCreateUser(userID uint, userInfo *models.UserInfo) (*models.User, error) {
|
||
// 尝试查找现有用户
|
||
var user models.User
|
||
err := storage.DB.Where("id = ?", userID).First(&user).Error
|
||
|
||
if err == nil {
|
||
// 用户存在,更新登录信息
|
||
now := time.Now()
|
||
user.LastLoginAt = &now
|
||
user.LoginCount++
|
||
|
||
// 更新用户信息(如果SSO信息有变化)
|
||
if userInfo.Name != "" && user.Nickname != userInfo.Name {
|
||
user.Nickname = userInfo.Name
|
||
}
|
||
if userInfo.Email != "" && user.Email != userInfo.Email {
|
||
user.Email = userInfo.Email
|
||
}
|
||
|
||
err = storage.DB.Save(&user).Error
|
||
if err != nil {
|
||
return nil, fmt.Errorf("更新用户信息失败: %w", err)
|
||
}
|
||
|
||
return &user, nil
|
||
}
|
||
|
||
if err != gorm.ErrRecordNotFound {
|
||
return nil, fmt.Errorf("查询用户失败: %w", err)
|
||
}
|
||
|
||
// 用户不存在,创建新用户
|
||
now := time.Now()
|
||
user = models.User{
|
||
BaseModel: models.BaseModel{
|
||
ID: userID,
|
||
CreatedAt: now,
|
||
UpdatedAt: now,
|
||
},
|
||
Username: userInfo.PreferredUsername,
|
||
Email: userInfo.Email,
|
||
Nickname: userInfo.Name,
|
||
Status: 1, // 默认启用
|
||
SSOProvider: "default", // 可以根据实际情况设置
|
||
LastLoginAt: &now,
|
||
LoginCount: 1,
|
||
}
|
||
|
||
// 如果PreferredUsername为空,使用Email作为用户名
|
||
if user.Username == "" {
|
||
user.Username = userInfo.Email
|
||
}
|
||
|
||
err = storage.DB.Create(&user).Error
|
||
if err != nil {
|
||
return nil, fmt.Errorf("创建用户失败: %w", err)
|
||
}
|
||
|
||
return &user, nil
|
||
}
|