Files
goalfylearning-admin/internal/services/sso_service.go

619 lines
19 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package services
import (
	"context"
	"errors"
	"fmt"
	"net/url"
	"strconv"
	"time"

	"github.com/google/uuid"
	"go.uber.org/zap"
	"gorm.io/gorm"

	"goalfymax-admin/internal/models"
	"goalfymax-admin/internal/storage"
	"goalfymax-admin/pkg/middleware"
	"goalfymax-admin/pkg/utils"
)
// SSOService is the SSO service interface. It covers the OAuth2/OIDC login
// flow (initiate, callback, token refresh, logout, user-info lookup) plus
// bookkeeping of each user's login / online state.
type SSOService interface {
	// Login flow.
	InitiateLogin(ctx context.Context) (*models.SSOLoginResponse, error)
	HandleCallback(ctx context.Context, req *models.SSOCallbackRequest) (*models.SSOCallbackResponse, error)
	RefreshToken(ctx context.Context, req *models.RefreshTokenRequest) (*models.RefreshTokenResponse, error)
	Logout(ctx context.Context, token string) (*models.LogoutResponse, error)
	GetUserInfo(ctx context.Context, token string) (*models.UserInfoResponse, error)

	// Login / online-state management.
	RecordUserLogin(ctx context.Context, req *UserLoginRequest) error
	UserLogout(ctx context.Context, userID int, uuid string) error
	GetUserLoginInfo(ctx context.Context, userID int) (*models.LoginInfo, error)
	IsUserOnline(ctx context.Context, userID int) (bool, error)
	GetOnlineUsers(ctx context.Context) ([]*models.LoginInfo, error)
	GetOnlineUserCount(ctx context.Context) (int64, error)
	BatchUserLogout(ctx context.Context, userIDs []int) error
}
// UserLoginRequest is the data recorded for a single user login by
// RecordUserLogin.
//
// The `binding` tags declare required fields and length/format limits.
// NOTE(review): they look like gin validator tags; confirm where (and whether)
// they are enforced, since RecordUserLogin itself does not validate them.
type UserLoginRequest struct {
	UserID   int    `json:"user_id" binding:"required"`
	UserName string `json:"user_name" binding:"required,max=100"`
	Email    string `json:"email" binding:"required,email,max=255"`
	// UUID identifies this login session; HandleCallback generates a fresh one
	// per login.
	UUID string `json:"uuid" binding:"required"`
}
// ssoService is the default SSOService implementation.
type ssoService struct {
	// client talks to the SSO provider: authorization URL, code exchange,
	// token refresh, user info and logout.
	client *middleware.SSOClient
	// pkceStateStorage keeps state -> PKCE code verifier between
	// InitiateLogin and HandleCallback.
	pkceStateStorage storage.PKCEStateStorage
	// loginInfoStorage holds the per-user login / online records.
	loginInfoStorage storage.LoginInfoStorage
	// rbacService provides accessible-page and role lookups.
	rbacService RBACService
	logger      *utils.Logger
}
// NewSSOService builds the default SSOService implementation from its
// collaborators. The dependencies are stored as given; none are validated.
func NewSSOService(client *middleware.SSOClient, pkceStateStorage storage.PKCEStateStorage, loginInfoStorage storage.LoginInfoStorage, rbacService RBACService, logger *utils.Logger) SSOService {
	svc := new(ssoService)
	svc.client = client
	svc.pkceStateStorage = pkceStateStorage
	svc.loginInfoStorage = loginInfoStorage
	svc.rbacService = rbacService
	svc.logger = logger
	return svc
}
// InitiateLogin starts the SSO authorization-code flow with PKCE.
//
// It generates an unguessable state value and asks the SSO client for the
// authorization URL, which also yields the PKCE code verifier. It then
// persists the state/verifier pair so HandleCallback can finish the code
// exchange. The code verifier is deliberately not returned to the caller.
func (s *ssoService) InitiateLogin(ctx context.Context) (*models.SSOLoginResponse, error) {
	// The OAuth2 state doubles as CSRF protection (RFC 6749 §10.12), so it must
	// be unpredictable. A nanosecond timestamp is guessable. A random (v4) UUID
	// is drawn from crypto/rand and returns an error instead of panicking on
	// entropy failure.
	id, err := uuid.NewRandom()
	if err != nil {
		s.logger.Error("failed to generate state", zap.Error(err))
		return nil, fmt.Errorf("生成状态参数失败: %w", err)
	}
	state := "state_" + id.String()

	authURL, codeVerifier, err := s.client.GetAuthorizationURL(state)
	if err != nil {
		s.logger.Error("failed to generate authorization URL", zap.Error(err))
		return nil, fmt.Errorf("生成授权URL失败: %w", err)
	}

	// Persist state -> code verifier; HandleCallback looks it up by state.
	pkceState := &models.PKCEState{
		State:        state,
		CodeVerifier: codeVerifier,
	}
	if err := s.pkceStateStorage.Create(pkceState); err != nil {
		s.logger.Error("failed to store PKCE state", zap.Error(err))
		return nil, fmt.Errorf("存储PKCE状态失败: %w", err)
	}

	return &models.SSOLoginResponse{
		Success:      true,
		Message:      "SSO login initiated",
		AuthURL:      authURL,
		State:        state,
		CodeVerifier: "", // never expose the PKCE verifier to the frontend
	}, nil
}
// HandleCallback completes an SSO login started by InitiateLogin.
//
// Steps:
//  1. Look up the PKCE state stored for req.State. The state is consumed
//     (deleted) on every exit path, so each state is single-use.
//  2. Exchange req.Code plus the stored code verifier for tokens, then fetch
//     the SSO user info.
//  3. Map the SSO roles to a local role ID (sys_admin -> 5, otherwise 0) and
//     find or create the local user with that role.
//  4. Record the login (best-effort) and load the user's accessible pages and
//     local role name for the frontend.
//
// An error is returned in these cases:
//   - the state is unknown or invalid;
//   - the token exchange or the user-info fetch fails;
//   - the subject is not a positive integer;
//   - the local user cannot be found or created.
//
// Failures in login recording, page lookup or role lookup are only logged.
func (s *ssoService) HandleCallback(ctx context.Context, req *models.SSOCallbackRequest) (*models.SSOCallbackResponse, error) {
	pkceState, err := s.pkceStateStorage.GetByState(req.State)
	if err != nil {
		s.logger.Error("failed to get PKCE state", zap.String("state", req.State), zap.Error(err))
		return nil, fmt.Errorf("无效或过期的状态参数")
	}
	// Consume the state whatever happens next. Previously it was removed only
	// on success, which left it replayable and leaked rows whenever a later
	// step failed. The authorization code is single-use anyway, so a failed
	// callback must restart from InitiateLogin.
	defer func() {
		if delErr := s.pkceStateStorage.DeleteByState(req.State); delErr != nil {
			s.logger.Warn("failed to delete PKCE state", zap.String("state", req.State), zap.Error(delErr))
		}
	}()

	tokenResp, err := s.client.ExchangeCodeForToken(ctx, req.Code, pkceState.CodeVerifier)
	if err != nil {
		s.logger.Error("failed to exchange token", zap.Error(err))
		return nil, fmt.Errorf("令牌交换失败: %w", err)
	}

	userInfo, err := s.client.GetUserInfo(ctx, tokenResp.AccessToken)
	if err != nil {
		s.logger.Error("failed to get user info", zap.Error(err))
		return nil, fmt.Errorf("获取用户信息失败: %w", err)
	}

	// The SSO subject becomes the local user primary key. It must be a
	// positive integer, because uint(uid) of a negative value would wrap to a
	// huge ID.
	uid, err := strconv.Atoi(userInfo.Sub)
	if err != nil || uid <= 0 {
		s.logger.Error("failed to convert user id", zap.String("sub", userInfo.Sub), zap.Error(err))
		return nil, fmt.Errorf("无效的用户ID")
	}
	userID := uint(uid)

	// Role granted for this login: sys_admin -> 5, anything else -> 0.
	assignedRoleID := ssoRoleIDFromRoles(userInfo.Roles)
	if assignedRoleID == 0 {
		s.logger.Info("本次登录未检测到 sys_admin 角色赋予无权限角色ID=0", zap.Int("userID", uid))
	} else {
		s.logger.Info("本次登录检测到 sys_admin 角色赋予角色ID=5", zap.Int("userID", uid))
	}

	user, err := s.findOrCreateUserWithRole(userID, userInfo, assignedRoleID)
	if err != nil {
		s.logger.Error("failed to find or create user", zap.Error(err))
		return nil, fmt.Errorf("创建或查找用户失败: %w", err)
	}

	// Fresh per-login session identifier. It is returned to the client and
	// stored with the login record.
	sessionUUID := uuid.NewString()
	if err := s.RecordUserLogin(ctx, &UserLoginRequest{
		UserID:   uid,
		UserName: userInfo.Name,
		Email:    userInfo.Email,
		UUID:     sessionUUID,
	}); err != nil {
		// Best-effort: a bookkeeping failure must not block the login.
		s.logger.Error("failed to record user login", zap.Int("user_id", uid), zap.Error(err))
	}

	userPages, err := s.rbacService.GetUserAccessiblePages(userID)
	if err != nil {
		s.logger.Warn("获取用户页面权限失败", zap.Error(err))
		userPages = []string{}
	}
	s.logger.Info("获取用户可访问页面", zap.Uint("userID", userID), zap.Strings("pages", userPages))

	// The role name only drives frontend UX; the backend still authorizes each
	// API call. The user record saved by findOrCreateUserWithRole is reused
	// instead of being re-read from the database.
	var userRoleName string
	if role, rerr := s.rbacService.GetRoleByID(user.RoleID); rerr != nil {
		s.logger.Warn("获取角色信息失败", zap.Error(rerr))
	} else {
		userRoleName = role.Name
		s.logger.Info("获取用户角色", zap.Uint("userID", userID), zap.String("roleName", userRoleName))
	}

	userInfoWithPages := &models.UserInfo{
		Sub:               userInfo.Sub,
		Name:              userInfo.Name,
		Email:             userInfo.Email,
		PreferredUsername: userInfo.PreferredUsername,
		Pages:             convertPagesFromPaths(userPages),
		Roles:             []models.Role{},
	}
	if userRoleName != "" {
		userInfoWithPages.Roles = append(userInfoWithPages.Roles, models.Role{Name: userRoleName})
	}

	return &models.SSOCallbackResponse{
		Success:      true,
		Message:      "SSO login successful",
		AccessToken:  tokenResp.AccessToken,
		IDToken:      tokenResp.IDToken,
		RefreshToken: tokenResp.RefreshToken,
		ExpiresIn:    tokenResp.ExpiresIn,
		UserInfo:     userInfoWithPages,
		UUID:         sessionUUID,
	}, nil
}

// ssoRoleIDFromRoles maps SSO role claims to the local role ID granted at
// login. It returns 5 if any role is named "sys_admin", otherwise 0.
func ssoRoleIDFromRoles(roles []models.Role) uint {
	for _, r := range roles {
		if r.Name == "sys_admin" {
			return 5
		}
	}
	return 0
}
// convertPagesFromPaths resolves page paths into full models.Page records.
//
// It loads the active, non-deleted pages whose path is in paths. If the query
// fails, it falls back to bare Page values that carry only the path and are
// marked active, so callers still get something usable. Empty input returns
// an empty, non-nil slice without querying the database.
func convertPagesFromPaths(paths []string) []models.Page {
	if len(paths) == 0 {
		return []models.Page{}
	}

	// Diagnostics go through zap's global logger instead of fmt.Printf to
	// stdout. zap.L() is a no-op unless the application installs a logger with
	// zap.ReplaceGlobals.
	logger := zap.L()

	var pages []models.Page
	err := storage.DB.Where("path IN ? AND is_active = TRUE AND deleted_at IS NULL", paths).Find(&pages).Error
	if err != nil {
		logger.Warn("convertPagesFromPaths: page query failed, falling back to bare paths",
			zap.Strings("paths", paths), zap.Error(err))
		// Start from a fresh slice: a failed Find may have left partial rows in
		// pages, which would otherwise be duplicated by the fallback entries.
		fallback := make([]models.Page, 0, len(paths))
		for _, path := range paths {
			fallback = append(fallback, models.Page{
				Path:     path,
				IsActive: true, // assume active when the real state is unknown
			})
		}
		return fallback
	}

	logger.Debug("convertPagesFromPaths: pages resolved",
		zap.Strings("paths", paths), zap.Int("found", len(pages)))
	return pages
}
// RefreshToken trades req.RefreshToken for a new token set through the SSO
// client. Client errors are logged and returned wrapped.
func (s *ssoService) RefreshToken(ctx context.Context, req *models.RefreshTokenRequest) (*models.RefreshTokenResponse, error) {
	refreshed, refreshErr := s.client.RefreshToken(ctx, req.RefreshToken)
	if refreshErr != nil {
		s.logger.Error("failed to refresh token", zap.Error(refreshErr))
		return nil, fmt.Errorf("令牌刷新失败: %w", refreshErr)
	}

	resp := &models.RefreshTokenResponse{Success: true, Message: "Token refreshed successfully"}
	resp.AccessToken = refreshed.AccessToken
	resp.IDToken = refreshed.IDToken
	resp.RefreshToken = refreshed.RefreshToken
	resp.ExpiresIn = refreshed.ExpiresIn
	return resp, nil
}
// Logout signs the user out at the SSO provider and marks them offline locally.
//
// Behavior:
//   - The user info is fetched before calling the provider's logout, because
//     the token is presumably rejected afterwards. A failed fetch is logged and
//     does not stop the logout.
//   - If the provider logout fails, an error is returned.
//   - Local offline bookkeeping is best-effort; its failures are only logged.
//   - On success, Message carries the provider's logout URL, with
//     redirect_uri set to the client's redirect URL.
func (s *ssoService) Logout(ctx context.Context, token string) (*models.LogoutResponse, error) {
	userInfo, err := s.client.GetUserInfo(ctx, token)
	if err != nil {
		s.logger.Error("failed to get user info during logout", zap.Error(err))
		// Don't trust a result returned alongside an error.
		userInfo = nil
	}

	if err := s.client.Logout(ctx, token); err != nil {
		s.logger.Error("failed to logout", zap.Error(err))
		return nil, fmt.Errorf("登出失败: %w", err)
	}

	if userInfo != nil {
		s.markLoggedOut(ctx, userInfo.Sub)
	}

	redirectURL := s.client.GetServerUrl() + "/oauth2/logout?redirect_uri=" + url.QueryEscape(s.client.GetRedirectUrl())
	return &models.LogoutResponse{
		Success: true,
		Message: redirectURL,
	}, nil
}

// markLoggedOut is the best-effort local half of Logout. It turns the SSO
// subject into a user ID and marks that user offline, logging (never
// returning) any failure.
func (s *ssoService) markLoggedOut(ctx context.Context, sub string) {
	uid, err := strconv.Atoi(sub)
	if err != nil {
		s.logger.Error("failed to convert user id during logout", zap.String("sub", sub), zap.Error(err))
		return
	}

	loginInfo, err := s.loginInfoStorage.GetByUserID(uid)
	if err != nil {
		s.logger.Error("failed to get user login info during logout", zap.Int("user_id", uid), zap.Error(err))
		return
	}
	if loginInfo == nil {
		// The storage contract for a missing record isn't visible here; guard
		// against (nil, nil) rather than dereferencing a nil record.
		s.logger.Warn("no login info found during logout", zap.Int("user_id", uid))
		return
	}

	if loginInfo.UUID == "" {
		// No session UUID to match on: fall back to a logout keyed by user ID.
		s.logger.Warn("UUID is empty, logging out by user_id only", zap.Int("user_id", uid))
		if err := s.loginInfoStorage.SetUserOffline(uid); err != nil {
			s.logger.Error("failed to set user offline by user_id", zap.Int("user_id", uid), zap.Error(err))
		}
		return
	}

	if err := s.UserLogout(ctx, uid, loginInfo.UUID); err != nil {
		s.logger.Error("failed to record user logout", zap.Int("user_id", uid), zap.Error(err))
	}
}
// GetUserInfo looks up the SSO user profile that belongs to token and wraps
// it in a UserInfoResponse. SSO client failures are logged and returned
// wrapped.
func (s *ssoService) GetUserInfo(ctx context.Context, token string) (*models.UserInfoResponse, error) {
	info, fetchErr := s.client.GetUserInfo(ctx, token)
	if fetchErr != nil {
		s.logger.Error("failed to get user info", zap.Error(fetchErr))
		return nil, fmt.Errorf("获取用户信息失败: %w", fetchErr)
	}

	resp := &models.UserInfoResponse{
		UserInfo: info,
		Message:  "User info retrieved successfully",
		Success:  true,
	}
	return resp, nil
}
// RecordUserLogin marks req.UserID online and stores the session details.
//
// If a login record exists, it is updated in place: IsOnline is set and UUID,
// UserName and Email are refreshed. Otherwise a new record is created,
// already online. A not-found lookup is the "create" case; any other storage
// error is logged and returned wrapped.
func (s *ssoService) RecordUserLogin(ctx context.Context, req *UserLoginRequest) error {
	existingUser, err := s.loginInfoStorage.GetByUserID(req.UserID)
	// errors.Is (not ==) also recognises a not-found error wrapped with %w by
	// the storage layer.
	if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
		s.logger.Error("failed to get existing user login info",
			zap.Int("user_id", req.UserID),
			zap.Error(err))
		return fmt.Errorf("获取用户登录信息失败: %w", err)
	}

	if existingUser == nil {
		// First login for this user: create a record that is already online.
		newLoginInfo := &models.LoginInfo{
			UserID:   req.UserID,
			UserName: req.UserName,
			Email:    req.Email,
			UUID:     req.UUID,
			IsOnline: true,
		}
		if err := s.loginInfoStorage.Create(newLoginInfo); err != nil {
			s.logger.Error("failed to create new user login info",
				zap.Int("user_id", req.UserID),
				zap.String("user_name", req.UserName),
				zap.Error(err))
			return fmt.Errorf("创建用户登录信息失败: %w", err)
		}
		s.logger.Info("new user login info created successfully",
			zap.Int("user_id", req.UserID),
			zap.String("user_name", req.UserName))
		return nil
	}

	// Existing record: mark online and refresh the per-login UUID and profile.
	existingUser.IsOnline = true
	existingUser.UUID = req.UUID
	existingUser.UserName = req.UserName
	existingUser.Email = req.Email
	if err := s.loginInfoStorage.Update(existingUser); err != nil {
		s.logger.Error("failed to update user online status",
			zap.Int("user_id", req.UserID),
			zap.Error(err))
		return fmt.Errorf("更新用户在线状态失败: %w", err)
	}
	s.logger.Info("user login status updated successfully",
		zap.Int("user_id", req.UserID),
		zap.String("user_name", req.UserName))
	return nil
}
// UserLogout marks userID offline when the login record matching
// (userID, sessionUUID) exists and is currently online.
//
// A missing record or an already-offline user is a no-op and returns nil.
// Lookup and update failures are logged and returned wrapped. Note that the
// offline update itself (SetUserOffline) is keyed by userID only.
func (s *ssoService) UserLogout(ctx context.Context, userID int, sessionUUID string) error {
	existingUser, err := s.loginInfoStorage.GetByUserIDAndUUID(userID, sessionUUID)
	// errors.Is also matches a not-found error wrapped by the storage layer.
	if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
		s.logger.Error("failed to get user login info for logout",
			zap.Int("user_id", userID),
			zap.String("uuid", sessionUUID),
			zap.Error(err))
		return fmt.Errorf("获取用户登录信息失败: %w", err)
	}

	if existingUser == nil {
		s.logger.Info("user not found, no logout action needed",
			zap.Int("user_id", userID),
			zap.String("uuid", sessionUUID))
		return nil
	}

	if !existingUser.IsOnline {
		s.logger.Info("user is already offline, no action needed",
			zap.Int("user_id", userID),
			zap.String("uuid", sessionUUID),
			zap.String("user_name", existingUser.UserName))
		return nil
	}

	if err := s.loginInfoStorage.SetUserOffline(userID); err != nil {
		s.logger.Error("failed to set user offline",
			zap.Int("user_id", userID),
			zap.String("user_name", existingUser.UserName),
			zap.Error(err))
		return fmt.Errorf("设置用户离线状态失败: %w", err)
	}

	s.logger.Info("user logout successfully",
		zap.Int("user_id", userID),
		zap.String("uuid", sessionUUID),
		zap.String("user_name", existingUser.UserName))
	return nil
}
// GetUserLoginInfo returns the stored login record for userID. Every storage
// error, including not-found, is logged and returned wrapped.
func (s *ssoService) GetUserLoginInfo(ctx context.Context, userID int) (*models.LoginInfo, error) {
	info, lookupErr := s.loginInfoStorage.GetByUserID(userID)
	if lookupErr == nil {
		return info, nil
	}

	s.logger.Error("failed to get user login info",
		zap.Int("user_id", userID),
		zap.Error(lookupErr))
	return nil, fmt.Errorf("获取用户登录信息失败: %w", lookupErr)
}
// IsUserOnline reports whether userID is currently marked online.
//
// A user with no login record (never logged in) is reported as offline, i.e.
// (false, nil), instead of as an error. Other storage failures are logged and
// returned wrapped.
func (s *ssoService) IsUserOnline(ctx context.Context, userID int) (bool, error) {
	loginInfo, err := s.loginInfoStorage.GetByUserID(userID)
	if err != nil {
		// RecordUserLogin treats gorm.ErrRecordNotFound from GetByUserID as
		// "no record yet"; the same reading here means "offline". Previously
		// this surfaced as an error (and an Error-level log) for every
		// never-logged-in user.
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return false, nil
		}
		s.logger.Error("failed to get user login info",
			zap.Int("user_id", userID),
			zap.Error(err))
		return false, fmt.Errorf("获取用户登录信息失败: %w", err)
	}
	if loginInfo == nil {
		return false, nil
	}
	return loginInfo.IsOnline, nil
}
// GetOnlineUsers lists the login records that are currently flagged online.
// Storage failures are logged and returned wrapped.
func (s *ssoService) GetOnlineUsers(ctx context.Context) ([]*models.LoginInfo, error) {
	users, listErr := s.loginInfoStorage.ListOnlineUsers()
	if listErr == nil {
		return users, nil
	}
	s.logger.Error("failed to get online users", zap.Error(listErr))
	return nil, fmt.Errorf("获取在线用户列表失败: %w", listErr)
}
// GetOnlineUserCount returns how many users are currently flagged online.
// On a storage failure it logs and returns 0 with a wrapped error.
func (s *ssoService) GetOnlineUserCount(ctx context.Context) (int64, error) {
	total, countErr := s.loginInfoStorage.CountOnlineUsers()
	if countErr == nil {
		return total, nil
	}
	s.logger.Error("failed to count online users", zap.Error(countErr))
	return 0, fmt.Errorf("统计在线用户数量失败: %w", countErr)
}
// BatchUserLogout 批量用户登出(可用于系统维护等场景)
func (s *ssoService) BatchUserLogout(ctx context.Context, userIDs []int) error {
if len(userIDs) == 0 {
return nil
}
for _, userID := range userIDs {
err := s.loginInfoStorage.SetUserOffline(userID)
if err != nil {
s.logger.Error("failed to set user offline in batch",
zap.Int("user_id", userID),
zap.Error(err))
// 继续处理其他用户,不中断整个批量操作
continue
}
}
s.logger.Info("batch user logout completed",
zap.Ints("user_ids", userIDs))
return nil
}
// findOrCreateUserWithRole loads the local user with the given ID and applies
// the role decided at login. If the user does not exist, it is created from
// the SSO profile.
//
// An existing user is updated by updateSSOUser; a missing user is created by
// createSSOUser. Any other lookup error is returned wrapped.
func (s *ssoService) findOrCreateUserWithRole(userID uint, userInfo *models.UserInfo, assignedRoleID uint) (*models.User, error) {
	var user models.User
	err := storage.DB.Where("id = ?", userID).First(&user).Error
	if err == nil {
		return s.updateSSOUser(userID, &user, userInfo, assignedRoleID)
	}
	// errors.Is is the comparison GORM documents for ErrRecordNotFound.
	if !errors.Is(err, gorm.ErrRecordNotFound) {
		return nil, fmt.Errorf("查询用户失败: %w", err)
	}
	return s.createSSOUser(userID, userInfo, assignedRoleID)
}

// updateSSOUser refreshes an existing user's login stats and SSO profile,
// syncs its role, and saves it.
//
// Login stats: LastLoginAt is set to now and LoginCount is incremented.
// Profile: nickname and email are overwritten only by non-empty SSO values.
// Role sync:
//   - assignedRoleID 0 (not sys_admin): always demoted to 0.
//   - assignedRoleID 5 (sys_admin): promoted to 5 only from 0, so a non-zero
//     role assigned locally is not overwritten.
func (s *ssoService) updateSSOUser(userID uint, user *models.User, userInfo *models.UserInfo, assignedRoleID uint) (*models.User, error) {
	const (
		noRoleID       uint = 0
		sysAdminRoleID uint = 5
	)

	now := time.Now()
	user.LastLoginAt = &now
	user.LoginCount++

	if userInfo.Name != "" && user.Nickname != userInfo.Name {
		user.Nickname = userInfo.Name
	}
	if userInfo.Email != "" && user.Email != userInfo.Email {
		user.Email = userInfo.Email
	}

	switch assignedRoleID {
	case noRoleID:
		if user.RoleID != noRoleID {
			s.logger.Info("降级用户角色(→0)", zap.Uint("userID", userID), zap.Uint("oldRoleID", user.RoleID))
			user.RoleID = noRoleID
		} else {
			s.logger.Debug("保持0角色不变", zap.Uint("userID", userID))
		}
	case sysAdminRoleID:
		if user.RoleID == noRoleID {
			s.logger.Info("升级用户角色(0→5)", zap.Uint("userID", userID))
			user.RoleID = sysAdminRoleID
		} else {
			s.logger.Debug("保持非0角色不变", zap.Uint("userID", userID), zap.Uint("currentRoleID", user.RoleID))
		}
	}

	if err := storage.DB.Save(user).Error; err != nil {
		return nil, fmt.Errorf("更新用户信息失败: %w", err)
	}
	s.logger.Info("用户登录信息已更新",
		zap.Uint("userID", userID),
		zap.String("username", user.Username))
	return user, nil
}

// createSSOUser inserts a new local user built from the SSO profile.
//
// The user is created enabled (Status 1), with SSOProvider "default", the
// given role, and LoginCount 1. The username is PreferredUsername, falling
// back to the email when that is empty.
func (s *ssoService) createSSOUser(userID uint, userInfo *models.UserInfo, assignedRoleID uint) (*models.User, error) {
	now := time.Now()
	user := models.User{
		BaseModel: models.BaseModel{
			ID:        userID,
			CreatedAt: now,
			UpdatedAt: now,
		},
		Username:    userInfo.PreferredUsername,
		Email:       userInfo.Email,
		Nickname:    userInfo.Name,
		Status:      1,         // enabled by default
		SSOProvider: "default", // NOTE(review): hard-coded provider name
		LastLoginAt: &now,
		LoginCount:  1,
		RoleID:      assignedRoleID,
	}
	if user.Username == "" {
		user.Username = userInfo.Email
	}

	if err := storage.DB.Create(&user).Error; err != nil {
		return nil, fmt.Errorf("创建用户失败: %w", err)
	}
	s.logger.Info("新用户创建成功",
		zap.Uint("userID", userID),
		zap.String("username", user.Username),
		zap.String("email", user.Email))
	return &user, nil
}
// assignDefaultRole sets the user's role_id to the role flagged is_default.
// The original comment describes that role as "L5, all staff". Errors from
// either query are returned wrapped.
// NOTE(review): no caller of this method appears in this file view.
func (s *ssoService) assignDefaultRole(userID uint) error {
	var defaultRole models.Role
	if err := storage.DB.Where("is_default = ?", true).First(&defaultRole).Error; err != nil {
		return fmt.Errorf("获取默认角色失败: %w", err)
	}

	update := storage.DB.Model(&models.User{}).Where("id = ?", userID).Update("role_id", defaultRole.ID)
	if update.Error != nil {
		return fmt.Errorf("分配默认角色失败: %w", update.Error)
	}

	s.logger.Info("用户已分配默认角色",
		zap.Uint("userID", userID),
		zap.Uint("roleID", defaultRole.ID),
		zap.String("roleName", defaultRole.Name))
	return nil
}