194 lines
5.6 KiB
Go
194 lines
5.6 KiB
Go
package middleware
|
||
|
||
import (
|
||
"goalfymax-admin/internal/models"
|
||
"goalfymax-admin/internal/storage"
|
||
"goalfymax-admin/pkg/utils"
|
||
"net/http"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
"go.uber.org/zap"
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
// RBACMiddleware 简化的RBAC权限中间件
|
||
type RBACMiddleware struct {
|
||
rbacStorage storage.RBACStorage
|
||
db *gorm.DB
|
||
logger *utils.Logger
|
||
}
|
||
|
||
// NewRBACMiddleware 创建RBAC中间件
|
||
func NewRBACMiddleware(rbacStorage storage.RBACStorage, db *gorm.DB, logger *utils.Logger) *RBACMiddleware {
|
||
return &RBACMiddleware{
|
||
rbacStorage: rbacStorage,
|
||
db: db,
|
||
logger: logger,
|
||
}
|
||
}
|
||
|
||
// RequirePagePermission 检查页面权限
|
||
func (m *RBACMiddleware) RequirePagePermission(pagePath string) gin.HandlerFunc {
|
||
return func(c *gin.Context) {
|
||
// 从上下文获取当前用户ID
|
||
userID, exists := c.Get("user_id")
|
||
if !exists {
|
||
m.logger.Error("未找到用户ID", zap.String("pagePath", pagePath))
|
||
c.JSON(http.StatusUnauthorized, gin.H{"error": "未认证"})
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
// 类型转换
|
||
userIDUint, ok := userID.(uint)
|
||
if !ok {
|
||
m.logger.Error("无效的用户ID类型", zap.Any("userID", userID))
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的用户ID"})
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
// 检查用户是否有页面访问权限(基于角色)
|
||
hasPermission, err := m.rbacStorage.CheckUserRolePagePermission(userIDUint, pagePath)
|
||
if err != nil {
|
||
m.logger.Error("页面权限检查失败", zap.Uint("userID", userIDUint), zap.String("pagePath", pagePath), zap.Error(err))
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "权限检查失败"})
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
if !hasPermission {
|
||
m.logger.Warn("页面权限不足", zap.Uint("userID", userIDUint), zap.String("pagePath", pagePath))
|
||
c.JSON(http.StatusForbidden, gin.H{"error": "页面权限不足"})
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
m.logger.Info("页面权限检查通过", zap.Uint("userID", userIDUint), zap.String("pagePath", pagePath))
|
||
c.Next()
|
||
}
|
||
}
|
||
|
||
// RequireRole 检查角色
|
||
func (m *RBACMiddleware) RequireRole(roleName string) gin.HandlerFunc {
|
||
return func(c *gin.Context) {
|
||
// 从上下文获取当前用户ID
|
||
userID, exists := c.Get("user_id")
|
||
if !exists {
|
||
m.logger.Error("未找到用户ID", zap.String("role", roleName))
|
||
c.JSON(http.StatusUnauthorized, gin.H{"error": "未认证"})
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
// 类型转换:userID 是 uint
|
||
userIDUint, ok := userID.(uint)
|
||
if !ok {
|
||
m.logger.Error("无效的用户ID类型", zap.Any("userID", userID))
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的用户ID"})
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
// 直接查询用户角色
|
||
var user models.User
|
||
err := m.db.Where("id = ?", userIDUint).First(&user).Error
|
||
if err != nil {
|
||
m.logger.Error("获取用户信息失败", zap.Uint("userID", userIDUint), zap.Error(err))
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取用户信息失败"})
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
// 查询角色信息
|
||
var role models.Role
|
||
err = m.db.Where("id = ?", user.RoleID).First(&role).Error
|
||
if err != nil {
|
||
m.logger.Error("获取角色信息失败", zap.Uint("roleID", user.RoleID), zap.Error(err))
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取角色信息失败"})
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
// 检查用户是否有指定角色
|
||
hasRole := role.Name == roleName
|
||
|
||
if !hasRole {
|
||
m.logger.Warn("角色不足", zap.Uint("userID", userIDUint), zap.String("role", roleName))
|
||
c.JSON(http.StatusForbidden, gin.H{"error": "角色不足"})
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
m.logger.Info("角色检查通过", zap.Uint("userID", userIDUint), zap.String("role", roleName))
|
||
c.Next()
|
||
}
|
||
}
|
||
|
||
// RequireAnyRole 检查任意角色
|
||
func (m *RBACMiddleware) RequireAnyRole(roleNames []string) gin.HandlerFunc {
|
||
return func(c *gin.Context) {
|
||
// 从上下文获取当前用户ID
|
||
userID, exists := c.Get("user_id")
|
||
if !exists {
|
||
m.logger.Error("未找到用户ID", zap.Strings("roles", roleNames))
|
||
c.JSON(http.StatusUnauthorized, gin.H{"error": "未认证"})
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
// 类型转换:userID 是 uint
|
||
userIDUint, ok := userID.(uint)
|
||
if !ok {
|
||
m.logger.Error("无效的用户ID类型", zap.Any("userID", userID))
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的用户ID"})
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
// 直接查询用户角色
|
||
var user models.User
|
||
err := m.db.Where("id = ?", userIDUint).First(&user).Error
|
||
if err != nil {
|
||
m.logger.Error("获取用户信息失败", zap.Uint("userID", userIDUint), zap.Error(err))
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取用户信息失败"})
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
// 查询角色信息
|
||
var role models.Role
|
||
err = m.db.Where("id = ?", user.RoleID).First(&role).Error
|
||
if err != nil {
|
||
m.logger.Error("获取角色信息失败", zap.Uint("roleID", user.RoleID), zap.Error(err))
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取角色信息失败"})
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
// 检查用户是否有任意一个角色
|
||
hasAnyRole := false
|
||
for _, roleName := range roleNames {
|
||
if role.Name == roleName {
|
||
hasAnyRole = true
|
||
break
|
||
}
|
||
}
|
||
|
||
if !hasAnyRole {
|
||
m.logger.Warn("角色不足", zap.Uint("userID", userIDUint), zap.Strings("roles", roleNames))
|
||
c.JSON(http.StatusForbidden, gin.H{"error": "角色不足"})
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
m.logger.Info("角色检查通过", zap.Uint("userID", userIDUint), zap.Strings("roles", roleNames))
|
||
c.Next()
|
||
}
|
||
}
|
||
|
||
// GetUserAccessiblePages 获取用户可访问的页面列表
|
||
func (m *RBACMiddleware) GetUserAccessiblePages(userID uint) ([]string, error) {
|
||
return m.rbacStorage.GetUserRoleAccessiblePages(userID)
|
||
}
|