268 lines
6.7 KiB
Go
268 lines
6.7 KiB
Go
package handlers
|
||
|
||
import (
|
||
"goalfymax-admin/internal/models"
|
||
"goalfymax-admin/internal/services"
|
||
"goalfymax-admin/pkg/utils"
|
||
"net/http"
|
||
"strings"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
"go.uber.org/zap"
|
||
)
|
||
|
||
// SSOHandler SSO处理器
|
||
type SSOHandler struct {
|
||
ssoService services.SSOService
|
||
response *utils.Response
|
||
logger *utils.Logger
|
||
}
|
||
|
||
// NewSSOHandler 创建SSO处理器
|
||
func NewSSOHandler(ssoService services.SSOService, logger *utils.Logger) *SSOHandler {
|
||
return &SSOHandler{
|
||
ssoService: ssoService,
|
||
response: utils.NewResponse(),
|
||
logger: logger,
|
||
}
|
||
}
|
||
|
||
// HandleSSOLogin 处理SSO登录请求(合并登录和回调逻辑)
|
||
func (h *SSOHandler) HandleSSOLogin(c *gin.Context) {
|
||
if c.Request.Method != http.MethodPost {
|
||
h.response.BadRequest(c, "Method not allowed")
|
||
return
|
||
}
|
||
|
||
var req models.SSOCallbackRequest
|
||
if err := c.ShouldBindJSON(&req); err != nil {
|
||
// 解析失败时,走登录逻辑
|
||
h.handleLoginLogic(c, models.SSOLoginRequest{})
|
||
return
|
||
}
|
||
|
||
// 如果code为空,走登录逻辑
|
||
if req.Code == "" {
|
||
h.handleLoginLogic(c, models.SSOLoginRequest{})
|
||
return
|
||
}
|
||
|
||
// 如果code不为空,走回调逻辑
|
||
h.handleCallbackLogic(c, req)
|
||
}
|
||
|
||
// HandleSSOCallback 处理SSO回调
|
||
func (h *SSOHandler) HandleSSOCallback(c *gin.Context) {
|
||
if c.Request.Method != http.MethodPost {
|
||
h.response.BadRequest(c, "Method not allowed")
|
||
return
|
||
}
|
||
|
||
var req models.SSOCallbackRequest
|
||
if err := c.ShouldBindJSON(&req); err != nil {
|
||
h.response.ValidateError(c, err)
|
||
return
|
||
}
|
||
|
||
// 验证参数
|
||
if req.Code == "" || req.State == "" {
|
||
h.response.BadRequest(c, "Code and state are required")
|
||
return
|
||
}
|
||
|
||
// 调用服务层处理回调
|
||
response, err := h.ssoService.HandleCallback(c.Request.Context(), &req)
|
||
if err != nil {
|
||
h.logger.Error("failed to handle SSO callback", zap.Error(err))
|
||
// 返回更具体的错误信息,避免前端重复尝试
|
||
if strings.Contains(err.Error(), "password") {
|
||
h.response.BadRequest(c, "数据库表结构错误,请联系管理员")
|
||
} else {
|
||
h.response.InternalServerError(c, "SSO登录处理失败,请稍后重试")
|
||
}
|
||
return
|
||
}
|
||
|
||
h.response.Success(c, response)
|
||
}
|
||
|
||
// HandleRefreshToken 处理令牌刷新
|
||
func (h *SSOHandler) HandleRefreshToken(c *gin.Context) {
|
||
if c.Request.Method != http.MethodPost {
|
||
h.response.BadRequest(c, "Method not allowed")
|
||
return
|
||
}
|
||
|
||
var req models.RefreshTokenRequest
|
||
if err := c.ShouldBindJSON(&req); err != nil {
|
||
h.response.ValidateError(c, err)
|
||
return
|
||
}
|
||
|
||
if req.RefreshToken == "" {
|
||
h.response.BadRequest(c, "Refresh token is required")
|
||
return
|
||
}
|
||
|
||
// 调用服务层刷新令牌
|
||
response, err := h.ssoService.RefreshToken(c.Request.Context(), &req)
|
||
if err != nil {
|
||
h.logger.Error("failed to refresh token", zap.Error(err))
|
||
h.response.InternalServerError(c, err.Error())
|
||
return
|
||
}
|
||
|
||
h.response.Success(c, response)
|
||
}
|
||
|
||
// HandleLogout 处理登出请求
|
||
func (h *SSOHandler) HandleLogout(c *gin.Context) {
|
||
if c.Request.Method != http.MethodPost {
|
||
h.response.BadRequest(c, "Method not allowed")
|
||
return
|
||
}
|
||
|
||
// 从Authorization头获取访问令牌
|
||
authHeader := c.GetHeader("Authorization")
|
||
if authHeader == "" {
|
||
h.response.Unauthorized(c, "Authorization header is required")
|
||
return
|
||
}
|
||
|
||
// 提取Bearer令牌
|
||
token := ""
|
||
if len(authHeader) > 7 && authHeader[:7] == "Bearer " {
|
||
token = authHeader[7:]
|
||
}
|
||
|
||
if token == "" {
|
||
h.response.Unauthorized(c, "Invalid authorization header")
|
||
return
|
||
}
|
||
|
||
// 调用服务层登出
|
||
response, err := h.ssoService.Logout(c.Request.Context(), token)
|
||
if err != nil {
|
||
h.logger.Error("failed to logout", zap.Error(err))
|
||
h.response.InternalServerError(c, err.Error())
|
||
return
|
||
}
|
||
|
||
h.response.Success(c, response)
|
||
}
|
||
|
||
// HandleUserInfo 处理用户信息请求
|
||
func (h *SSOHandler) HandleUserInfo(c *gin.Context) {
|
||
if c.Request.Method != http.MethodGet {
|
||
h.response.BadRequest(c, "Method not allowed")
|
||
return
|
||
}
|
||
|
||
// 从Authorization头获取访问令牌
|
||
authHeader := c.GetHeader("Authorization")
|
||
if authHeader == "" {
|
||
h.response.Unauthorized(c, "Authorization header is required")
|
||
return
|
||
}
|
||
|
||
// 提取Bearer令牌
|
||
token := ""
|
||
if len(authHeader) > 7 && authHeader[:7] == "Bearer " {
|
||
token = authHeader[7:]
|
||
}
|
||
|
||
if token == "" {
|
||
h.response.Unauthorized(c, "Invalid authorization header")
|
||
return
|
||
}
|
||
|
||
// 调用服务层获取用户信息
|
||
response, err := h.ssoService.GetUserInfo(c.Request.Context(), token)
|
||
if err != nil {
|
||
h.logger.Error("failed to get user info", zap.Error(err))
|
||
h.response.InternalServerError(c, err.Error())
|
||
return
|
||
}
|
||
|
||
h.response.Success(c, response)
|
||
}
|
||
|
||
// handleLoginLogic 处理登录逻辑
|
||
func (h *SSOHandler) handleLoginLogic(c *gin.Context, req models.SSOLoginRequest) {
|
||
// 调用服务层初始化登录
|
||
response, err := h.ssoService.InitiateLogin(c.Request.Context())
|
||
if err != nil {
|
||
h.logger.Error("failed to initiate SSO login", zap.Error(err))
|
||
h.response.InternalServerError(c, err.Error())
|
||
return
|
||
}
|
||
|
||
h.response.Success(c, response)
|
||
}
|
||
|
||
// handleCallbackLogic 处理回调逻辑
|
||
func (h *SSOHandler) handleCallbackLogic(c *gin.Context, req models.SSOCallbackRequest) {
|
||
// 验证参数
|
||
if req.State == "" {
|
||
// 参数缺失时,走登录逻辑
|
||
h.handleLoginLogic(c, models.SSOLoginRequest{})
|
||
return
|
||
}
|
||
|
||
// 调用服务层处理回调
|
||
response, err := h.ssoService.HandleCallback(c.Request.Context(), &req)
|
||
if err != nil {
|
||
h.logger.Error("failed to handle SSO callback", zap.Error(err))
|
||
// 回调失败时,走登录逻辑
|
||
h.handleLoginLogic(c, models.SSOLoginRequest{})
|
||
return
|
||
}
|
||
|
||
h.response.Success(c, response)
|
||
}
|
||
|
||
// GetOnlineUsers 获取在线用户列表
|
||
func (h *SSOHandler) GetOnlineUsers(c *gin.Context) {
|
||
users, err := h.ssoService.GetOnlineUsers(c.Request.Context())
|
||
if err != nil {
|
||
h.logger.Error("failed to get online users", zap.Error(err))
|
||
h.response.InternalServerError(c, err.Error())
|
||
return
|
||
}
|
||
|
||
h.response.Success(c, users)
|
||
}
|
||
|
||
// GetOnlineUserCount 获取在线用户数量
|
||
func (h *SSOHandler) GetOnlineUserCount(c *gin.Context) {
|
||
count, err := h.ssoService.GetOnlineUserCount(c.Request.Context())
|
||
if err != nil {
|
||
h.logger.Error("failed to get online user count", zap.Error(err))
|
||
h.response.InternalServerError(c, err.Error())
|
||
return
|
||
}
|
||
|
||
h.response.Success(c, gin.H{"count": count})
|
||
}
|
||
|
||
// BatchLogout 批量登出用户
|
||
func (h *SSOHandler) BatchLogout(c *gin.Context) {
|
||
var req struct {
|
||
UserIDs []int `json:"user_ids" binding:"required"`
|
||
}
|
||
|
||
if err := c.ShouldBindJSON(&req); err != nil {
|
||
h.response.ValidateError(c, err)
|
||
return
|
||
}
|
||
|
||
err := h.ssoService.BatchUserLogout(c.Request.Context(), req.UserIDs)
|
||
if err != nil {
|
||
h.logger.Error("failed to batch logout users", zap.Error(err))
|
||
h.response.InternalServerError(c, err.Error())
|
||
return
|
||
}
|
||
|
||
h.response.Success(c, "批量登出成功")
|
||
}
|