package handlers

import (
	"net/http"
	"strings"

	"goalfymax-admin/internal/models"
	"goalfymax-admin/internal/services"
	"goalfymax-admin/pkg/utils"

	"github.com/gin-gonic/gin"
	"go.uber.org/zap"
)

// SSOHandler exposes the SSO-related HTTP endpoints: login, callback, token
// refresh, logout, user info and online-session administration. It validates
// incoming requests and delegates the actual work to services.SSOService.
type SSOHandler struct {
	ssoService services.SSOService
	response   *utils.Response
	logger     *utils.Logger
}

// NewSSOHandler builds an SSOHandler that delegates to ssoService, logs through
// logger, and writes responses with a new utils.Response.
func NewSSOHandler(ssoService services.SSOService, logger *utils.Logger) *SSOHandler {
	return &SSOHandler{
		ssoService: ssoService,
		response:   utils.NewResponse(),
		logger:     logger,
	}
}

// requireMethod reports whether the request uses the given HTTP method. When it
// does not, it writes a BadRequest "Method not allowed" response and returns
// false so the caller can stop processing.
func (h *SSOHandler) requireMethod(c *gin.Context, method string) bool {
	if c.Request.Method == method {
		return true
	}
	// NOTE(review): 405 Method Not Allowed would be more precise; the existing
	// BadRequest response is kept to preserve the API contract.
	h.response.BadRequest(c, "Method not allowed")
	return false
}

// requireBearerToken extracts the access token from an
// "Authorization: Bearer <token>" header.
//
// The auth scheme is matched case-insensitively (RFC 7235 §2.1), and the token
// is trimmed so that a header consisting only of the scheme plus whitespace is
// rejected instead of yielding a blank token. On failure it writes an
// Unauthorized response and returns ok == false.
func (h *SSOHandler) requireBearerToken(c *gin.Context) (string, bool) {
	const prefix = "Bearer "

	authHeader := c.GetHeader("Authorization")
	if authHeader == "" {
		h.response.Unauthorized(c, "Authorization header is required")
		return "", false
	}

	token := ""
	if len(authHeader) > len(prefix) && strings.EqualFold(authHeader[:len(prefix)], prefix) {
		token = strings.TrimSpace(authHeader[len(prefix):])
	}
	if token == "" {
		h.response.Unauthorized(c, "Invalid authorization header")
		return "", false
	}
	return token, true
}

// HandleSSOLogin is a combined login/callback endpoint (POST only).
//
// If the JSON body cannot be bound or carries no code, it starts a new SSO
// login. Otherwise the body is treated as an SSO callback; that path itself
// falls back to starting a login when the state is missing or the callback
// fails.
func (h *SSOHandler) HandleSSOLogin(c *gin.Context) {
	if !h.requireMethod(c, http.MethodPost) {
		return
	}

	var req models.SSOCallbackRequest
	// A malformed/empty body or a missing code is not an error here: it means
	// the client is asking to start a login rather than complete one.
	if err := c.ShouldBindJSON(&req); err != nil || req.Code == "" {
		h.handleLoginLogic(c, models.SSOLoginRequest{})
		return
	}

	h.handleCallbackLogic(c, req)
}

// HandleSSOCallback completes an SSO login (POST only). Both code and state
// are required in the JSON body; the service's result is returned on success.
func (h *SSOHandler) HandleSSOCallback(c *gin.Context) {
	if !h.requireMethod(c, http.MethodPost) {
		return
	}

	var req models.SSOCallbackRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		h.response.ValidateError(c, err)
		return
	}
	if req.Code == "" || req.State == "" {
		h.response.BadRequest(c, "Code and state are required")
		return
	}

	response, err := h.ssoService.HandleCallback(c.Request.Context(), &req)
	if err != nil {
		h.logger.Error("failed to handle SSO callback", zap.Error(err))
		// Return a specific message so the frontend does not keep retrying.
		// NOTE(review): classifying by the substring "password" is fragile;
		// a sentinel/typed error from the service layer would be more reliable.
		if strings.Contains(err.Error(), "password") {
			h.response.BadRequest(c, "数据库表结构错误,请联系管理员")
			return
		}
		h.response.InternalServerError(c, "SSO登录处理失败,请稍后重试")
		return
	}

	h.response.Success(c, response)
}

// HandleRefreshToken exchanges a refresh token for new tokens (POST only).
// The JSON body must contain a non-empty refresh token.
func (h *SSOHandler) HandleRefreshToken(c *gin.Context) {
	if !h.requireMethod(c, http.MethodPost) {
		return
	}

	var req models.RefreshTokenRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		h.response.ValidateError(c, err)
		return
	}
	if req.RefreshToken == "" {
		h.response.BadRequest(c, "Refresh token is required")
		return
	}

	response, err := h.ssoService.RefreshToken(c.Request.Context(), &req)
	if err != nil {
		h.logger.Error("failed to refresh token", zap.Error(err))
		h.response.InternalServerError(c, err.Error())
		return
	}

	h.response.Success(c, response)
}

// HandleLogout logs out the session identified by the Bearer access token in
// the Authorization header (POST only).
func (h *SSOHandler) HandleLogout(c *gin.Context) {
	if !h.requireMethod(c, http.MethodPost) {
		return
	}

	token, ok := h.requireBearerToken(c)
	if !ok {
		return
	}

	response, err := h.ssoService.Logout(c.Request.Context(), token)
	if err != nil {
		h.logger.Error("failed to logout", zap.Error(err))
		h.response.InternalServerError(c, err.Error())
		return
	}

	h.response.Success(c, response)
}

// HandleUserInfo returns the user information associated with the Bearer
// access token in the Authorization header (GET only).
func (h *SSOHandler) HandleUserInfo(c *gin.Context) {
	if !h.requireMethod(c, http.MethodGet) {
		return
	}

	token, ok := h.requireBearerToken(c)
	if !ok {
		return
	}

	response, err := h.ssoService.GetUserInfo(c.Request.Context(), token)
	if err != nil {
		h.logger.Error("failed to get user info", zap.Error(err))
		h.response.InternalServerError(c, err.Error())
		return
	}

	h.response.Success(c, response)
}

// handleLoginLogic starts a new SSO login through the service layer and writes
// the result. The request argument is currently unused; it is kept so that
// existing callers continue to compile.
func (h *SSOHandler) handleLoginLogic(c *gin.Context, _ models.SSOLoginRequest) {
	response, err := h.ssoService.InitiateLogin(c.Request.Context())
	if err != nil {
		h.logger.Error("failed to initiate SSO login", zap.Error(err))
		h.response.InternalServerError(c, err.Error())
		return
	}

	h.response.Success(c, response)
}

// handleCallbackLogic completes an SSO callback for req. A missing state or a
// failed callback is not reported to the client; instead a fresh login is
// started so the user is sent through the SSO flow again.
func (h *SSOHandler) handleCallbackLogic(c *gin.Context, req models.SSOCallbackRequest) {
	if req.State == "" {
		// Incomplete callback parameters: restart the login flow.
		h.handleLoginLogic(c, models.SSOLoginRequest{})
		return
	}

	response, err := h.ssoService.HandleCallback(c.Request.Context(), &req)
	if err != nil {
		// The failure is logged, then deliberately converted into a new login.
		h.logger.Error("failed to handle SSO callback", zap.Error(err))
		h.handleLoginLogic(c, models.SSOLoginRequest{})
		return
	}

	h.response.Success(c, response)
}

// GetOnlineUsers returns the list of online users reported by the service.
func (h *SSOHandler) GetOnlineUsers(c *gin.Context) {
	users, err := h.ssoService.GetOnlineUsers(c.Request.Context())
	if err != nil {
		h.logger.Error("failed to get online users", zap.Error(err))
		h.response.InternalServerError(c, err.Error())
		return
	}

	h.response.Success(c, users)
}

// GetOnlineUserCount returns the number of online users as {"count": n}.
func (h *SSOHandler) GetOnlineUserCount(c *gin.Context) {
	count, err := h.ssoService.GetOnlineUserCount(c.Request.Context())
	if err != nil {
		h.logger.Error("failed to get online user count", zap.Error(err))
		h.response.InternalServerError(c, err.Error())
		return
	}

	h.response.Success(c, gin.H{"count": count})
}

// BatchLogout logs out every user whose ID appears in the JSON body's
// "user_ids" array. The array must be present and non-empty.
func (h *SSOHandler) BatchLogout(c *gin.Context) {
	var req struct {
		UserIDs []int `json:"user_ids" binding:"required"`
	}
	if err := c.ShouldBindJSON(&req); err != nil {
		h.response.ValidateError(c, err)
		return
	}
	// `binding:"required"` only rejects a nil slice, so an explicit empty
	// array ("user_ids": []) would otherwise reach the service layer.
	if len(req.UserIDs) == 0 {
		h.response.BadRequest(c, "User IDs are required")
		return
	}

	if err := h.ssoService.BatchUserLogout(c.Request.Context(), req.UserIDs); err != nil {
		h.logger.Error("failed to batch logout users", zap.Error(err))
		h.response.InternalServerError(c, err.Error())
		return
	}

	h.response.Success(c, "批量登出成功")
}