Files
goalfylearning-admin/internal/api/handlers/sso_handler.go

268 lines
6.7 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package handlers
import (
"goalfymax-admin/internal/models"
"goalfymax-admin/internal/services"
"goalfymax-admin/pkg/utils"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// SSOHandler groups the HTTP handlers for the SSO endpoints.
type SSOHandler struct {
	// ssoService is the service layer the handlers delegate to.
	ssoService services.SSOService
	// response writes the success and error replies.
	response *utils.Response
	// logger records failures reported by the service layer.
	logger *utils.Logger
}
// NewSSOHandler builds an SSOHandler around the given service and logger,
// pairing them with a response helper obtained from utils.NewResponse.
func NewSSOHandler(ssoService services.SSOService, logger *utils.Logger) *SSOHandler {
	handler := new(SSOHandler)
	handler.ssoService = ssoService
	handler.response = utils.NewResponse()
	handler.logger = logger
	return handler
}
// HandleSSOLogin serves the combined login/callback endpoint.
//
// Only POST is accepted. A body that cannot be bound as JSON, or that carries
// an empty code, starts the login flow; any other request is treated as a
// callback.
func (h *SSOHandler) HandleSSOLogin(c *gin.Context) {
	if c.Request.Method != http.MethodPost {
		h.response.BadRequest(c, "Method not allowed")
		return
	}

	var callback models.SSOCallbackRequest
	bindErr := c.ShouldBindJSON(&callback)

	// Both an unparsable body and a missing code fall back to the login flow.
	if bindErr != nil || callback.Code == "" {
		h.handleLoginLogic(c, models.SSOLoginRequest{})
		return
	}

	// A code is present, so continue with the callback flow.
	h.handleCallbackLogic(c, callback)
}
// HandleSSOCallback serves the dedicated SSO callback endpoint.
//
// Only POST is accepted, and the JSON body must carry both a code and a
// state. The request is then forwarded to the service layer and its result is
// returned to the client.
func (h *SSOHandler) HandleSSOCallback(c *gin.Context) {
	if c.Request.Method != http.MethodPost {
		h.response.BadRequest(c, "Method not allowed")
		return
	}

	var callback models.SSOCallbackRequest
	if bindErr := c.ShouldBindJSON(&callback); bindErr != nil {
		h.response.ValidateError(c, bindErr)
		return
	}

	// Both fields are mandatory for a callback.
	if callback.Code == "" || callback.State == "" {
		h.response.BadRequest(c, "Code and state are required")
		return
	}

	// Delegate the callback to the service layer.
	result, err := h.ssoService.HandleCallback(c.Request.Context(), &callback)
	if err != nil {
		h.logger.Error("failed to handle SSO callback", zap.Error(err))

		// Errors whose text mentions "password" get a more specific reply so
		// that the frontend does not keep retrying.
		if strings.Contains(err.Error(), "password") {
			h.response.BadRequest(c, "数据库表结构错误,请联系管理员")
			return
		}
		h.response.InternalServerError(c, "SSO登录处理失败请稍后重试")
		return
	}

	h.response.Success(c, result)
}
// HandleRefreshToken serves the token refresh endpoint.
//
// Only POST is accepted, and the JSON body must contain a non-empty refresh
// token. The request is passed to the service layer; on failure the service
// error text is sent back in the error reply.
func (h *SSOHandler) HandleRefreshToken(c *gin.Context) {
	if c.Request.Method != http.MethodPost {
		h.response.BadRequest(c, "Method not allowed")
		return
	}

	var refresh models.RefreshTokenRequest
	if bindErr := c.ShouldBindJSON(&refresh); bindErr != nil {
		h.response.ValidateError(c, bindErr)
		return
	}

	if refresh.RefreshToken == "" {
		h.response.BadRequest(c, "Refresh token is required")
		return
	}

	// Ask the service layer to refresh the token.
	result, err := h.ssoService.RefreshToken(c.Request.Context(), &refresh)
	if err != nil {
		h.logger.Error("failed to refresh token", zap.Error(err))
		h.response.InternalServerError(c, err.Error())
		return
	}

	h.response.Success(c, result)
}
// HandleLogout serves the logout endpoint.
//
// Only POST is accepted. The access token is read from an
// "Authorization: Bearer <token>" header and handed to the service layer; on
// failure the service error text is sent back in the error reply.
func (h *SSOHandler) HandleLogout(c *gin.Context) {
	if c.Request.Method != http.MethodPost {
		h.response.BadRequest(c, "Method not allowed")
		return
	}

	// Read the access token from the Authorization header.
	header := c.GetHeader("Authorization")
	if header == "" {
		h.response.Unauthorized(c, "Authorization header is required")
		return
	}

	// Keep only what follows the "Bearer " prefix; a header without that
	// prefix, or with nothing after it, leaves the token empty.
	const bearerPrefix = "Bearer "
	accessToken := ""
	if strings.HasPrefix(header, bearerPrefix) {
		accessToken = strings.TrimPrefix(header, bearerPrefix)
	}
	if accessToken == "" {
		h.response.Unauthorized(c, "Invalid authorization header")
		return
	}

	// Ask the service layer to log the token out.
	result, err := h.ssoService.Logout(c.Request.Context(), accessToken)
	if err != nil {
		h.logger.Error("failed to logout", zap.Error(err))
		h.response.InternalServerError(c, err.Error())
		return
	}

	h.response.Success(c, result)
}
// HandleUserInfo serves the user info endpoint.
//
// Only GET is accepted. The access token is read from an
// "Authorization: Bearer <token>" header and handed to the service layer; on
// failure the service error text is sent back in the error reply.
func (h *SSOHandler) HandleUserInfo(c *gin.Context) {
	if c.Request.Method != http.MethodGet {
		h.response.BadRequest(c, "Method not allowed")
		return
	}

	// Read the access token from the Authorization header.
	header := c.GetHeader("Authorization")
	if header == "" {
		h.response.Unauthorized(c, "Authorization header is required")
		return
	}

	// Keep only what follows the "Bearer " prefix; a header without that
	// prefix, or with nothing after it, leaves the token empty.
	const bearerPrefix = "Bearer "
	accessToken := ""
	if strings.HasPrefix(header, bearerPrefix) {
		accessToken = strings.TrimPrefix(header, bearerPrefix)
	}
	if accessToken == "" {
		h.response.Unauthorized(c, "Invalid authorization header")
		return
	}

	// Ask the service layer for the user information.
	info, err := h.ssoService.GetUserInfo(c.Request.Context(), accessToken)
	if err != nil {
		h.logger.Error("failed to get user info", zap.Error(err))
		h.response.InternalServerError(c, err.Error())
		return
	}

	h.response.Success(c, info)
}
// handleLoginLogic starts the login flow through the service layer and writes
// its result to the client. The req argument is not read by this function.
func (h *SSOHandler) handleLoginLogic(c *gin.Context, req models.SSOLoginRequest) {
	ctx := c.Request.Context()

	// Ask the service layer to initiate the login.
	loginInfo, err := h.ssoService.InitiateLogin(ctx)
	if err != nil {
		h.logger.Error("failed to initiate SSO login", zap.Error(err))
		h.response.InternalServerError(c, err.Error())
		return
	}

	h.response.Success(c, loginInfo)
}
// handleCallbackLogic runs the callback flow for the combined login endpoint.
//
// When the state is missing, or when the service layer rejects the callback,
// the request falls back to the login flow instead of returning an error.
func (h *SSOHandler) handleCallbackLogic(c *gin.Context, req models.SSOCallbackRequest) {
	if req.State != "" {
		// Delegate the callback to the service layer.
		result, err := h.ssoService.HandleCallback(c.Request.Context(), &req)
		if err == nil {
			h.response.Success(c, result)
			return
		}
		h.logger.Error("failed to handle SSO callback", zap.Error(err))
	}

	// Missing state or a failed callback: restart with the login flow.
	h.handleLoginLogic(c, models.SSOLoginRequest{})
}
// GetOnlineUsers returns the online user list supplied by the service layer.
// On failure the service error text is sent back in the error reply.
func (h *SSOHandler) GetOnlineUsers(c *gin.Context) {
	ctx := c.Request.Context()

	onlineUsers, err := h.ssoService.GetOnlineUsers(ctx)
	if err != nil {
		h.logger.Error("failed to get online users", zap.Error(err))
		h.response.InternalServerError(c, err.Error())
		return
	}

	h.response.Success(c, onlineUsers)
}
// GetOnlineUserCount returns the online user count supplied by the service
// layer, wrapped in an object under the "count" key. On failure the service
// error text is sent back in the error reply.
func (h *SSOHandler) GetOnlineUserCount(c *gin.Context) {
	ctx := c.Request.Context()

	total, err := h.ssoService.GetOnlineUserCount(ctx)
	if err != nil {
		h.logger.Error("failed to get online user count", zap.Error(err))
		h.response.InternalServerError(c, err.Error())
		return
	}

	payload := gin.H{"count": total}
	h.response.Success(c, payload)
}
// BatchLogout logs out several users in one request.
//
// The JSON body must provide "user_ids"; the list is forwarded to the service
// layer. On failure the service error text is sent back in the error reply.
func (h *SSOHandler) BatchLogout(c *gin.Context) {
	var payload struct {
		UserIDs []int `json:"user_ids" binding:"required"`
	}

	if bindErr := c.ShouldBindJSON(&payload); bindErr != nil {
		h.response.ValidateError(c, bindErr)
		return
	}

	if err := h.ssoService.BatchUserLogout(c.Request.Context(), payload.UserIDs); err != nil {
		h.logger.Error("failed to batch logout users", zap.Error(err))
		h.response.InternalServerError(c, err.Error())
		return
	}

	h.response.Success(c, "批量登出成功")
}