package middleware import ( "goalfymax-admin/internal/models" "goalfymax-admin/internal/storage" "goalfymax-admin/pkg/utils" "net/http" "github.com/gin-gonic/gin" "go.uber.org/zap" "gorm.io/gorm" ) // RBACMiddleware 简化的RBAC权限中间件 type RBACMiddleware struct { rbacStorage storage.RBACStorage db *gorm.DB logger *utils.Logger } // NewRBACMiddleware 创建RBAC中间件 func NewRBACMiddleware(rbacStorage storage.RBACStorage, db *gorm.DB, logger *utils.Logger) *RBACMiddleware { return &RBACMiddleware{ rbacStorage: rbacStorage, db: db, logger: logger, } } // RequirePagePermission 检查页面权限 func (m *RBACMiddleware) RequirePagePermission(pagePath string) gin.HandlerFunc { return func(c *gin.Context) { // 从上下文获取当前用户ID userID, exists := c.Get("user_id") if !exists { m.logger.Error("未找到用户ID", zap.String("pagePath", pagePath)) c.JSON(http.StatusUnauthorized, gin.H{"error": "未认证"}) c.Abort() return } // 类型转换 userIDUint, ok := userID.(uint) if !ok { m.logger.Error("无效的用户ID类型", zap.Any("userID", userID)) c.JSON(http.StatusBadRequest, gin.H{"error": "无效的用户ID"}) c.Abort() return } // 检查用户是否有页面访问权限(基于角色) hasPermission, err := m.rbacStorage.CheckUserRolePagePermission(userIDUint, pagePath) if err != nil { m.logger.Error("页面权限检查失败", zap.Uint("userID", userIDUint), zap.String("pagePath", pagePath), zap.Error(err)) c.JSON(http.StatusInternalServerError, gin.H{"error": "权限检查失败"}) c.Abort() return } if !hasPermission { m.logger.Warn("页面权限不足", zap.Uint("userID", userIDUint), zap.String("pagePath", pagePath)) c.JSON(http.StatusForbidden, gin.H{"error": "页面权限不足"}) c.Abort() return } m.logger.Info("页面权限检查通过", zap.Uint("userID", userIDUint), zap.String("pagePath", pagePath)) c.Next() } } // RequireRole 检查角色 func (m *RBACMiddleware) RequireRole(roleName string) gin.HandlerFunc { return func(c *gin.Context) { // 从上下文获取当前用户ID userID, exists := c.Get("user_id") if !exists { m.logger.Error("未找到用户ID", zap.String("role", roleName)) c.JSON(http.StatusUnauthorized, gin.H{"error": "未认证"}) c.Abort() return } // 类型转换:userID 是 uint userIDUint, ok := userID.(uint) if !ok { m.logger.Error("无效的用户ID类型", zap.Any("userID", userID)) c.JSON(http.StatusBadRequest, gin.H{"error": "无效的用户ID"}) c.Abort() return } // 直接查询用户角色 var user models.User err := m.db.Where("id = ?", userIDUint).First(&user).Error if err != nil { m.logger.Error("获取用户信息失败", zap.Uint("userID", userIDUint), zap.Error(err)) c.JSON(http.StatusInternalServerError, gin.H{"error": "获取用户信息失败"}) c.Abort() return } // 查询角色信息 var role models.Role err = m.db.Where("id = ?", user.RoleID).First(&role).Error if err != nil { m.logger.Error("获取角色信息失败", zap.Uint("roleID", user.RoleID), zap.Error(err)) c.JSON(http.StatusInternalServerError, gin.H{"error": "获取角色信息失败"}) c.Abort() return } // 检查用户是否有指定角色 hasRole := role.Name == roleName if !hasRole { m.logger.Warn("角色不足", zap.Uint("userID", userIDUint), zap.String("role", roleName)) c.JSON(http.StatusForbidden, gin.H{"error": "角色不足"}) c.Abort() return } m.logger.Info("角色检查通过", zap.Uint("userID", userIDUint), zap.String("role", roleName)) c.Next() } } // RequireAnyRole 检查任意角色 func (m *RBACMiddleware) RequireAnyRole(roleNames []string) gin.HandlerFunc { return func(c *gin.Context) { // 从上下文获取当前用户ID userID, exists := c.Get("user_id") if !exists { m.logger.Error("未找到用户ID", zap.Strings("roles", roleNames)) c.JSON(http.StatusUnauthorized, gin.H{"error": "未认证"}) c.Abort() return } // 类型转换:userID 是 uint userIDUint, ok := userID.(uint) if !ok { m.logger.Error("无效的用户ID类型", zap.Any("userID", userID)) c.JSON(http.StatusBadRequest, gin.H{"error": "无效的用户ID"}) c.Abort() return } // 直接查询用户角色 var user models.User err := m.db.Where("id = ?", userIDUint).First(&user).Error if err != nil { m.logger.Error("获取用户信息失败", zap.Uint("userID", userIDUint), zap.Error(err)) c.JSON(http.StatusInternalServerError, gin.H{"error": "获取用户信息失败"}) c.Abort() return } // 查询角色信息 var role models.Role err = m.db.Where("id = ?", user.RoleID).First(&role).Error if err != nil { m.logger.Error("获取角色信息失败", zap.Uint("roleID", user.RoleID), zap.Error(err)) c.JSON(http.StatusInternalServerError, gin.H{"error": "获取角色信息失败"}) c.Abort() return } // 检查用户是否有任意一个角色 hasAnyRole := false for _, roleName := range roleNames { if role.Name == roleName { hasAnyRole = true break } } if !hasAnyRole { m.logger.Warn("角色不足", zap.Uint("userID", userIDUint), zap.Strings("roles", roleNames)) c.JSON(http.StatusForbidden, gin.H{"error": "角色不足"}) c.Abort() return } m.logger.Info("角色检查通过", zap.Uint("userID", userIDUint), zap.Strings("roles", roleNames)) c.Next() } } // GetUserAccessiblePages 获取用户可访问的页面列表 func (m *RBACMiddleware) GetUserAccessiblePages(userID uint) ([]string, error) { return m.rbacStorage.GetUserRoleAccessiblePages(userID) }