feat():learning后台管理项目初始化
This commit is contained in:
290
pkg/middleware/auth.go
Normal file
290
pkg/middleware/auth.go
Normal file
@@ -0,0 +1,290 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"goalfymax-admin/internal/models"
|
||||
"goalfymax-admin/internal/storage"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// SessionManager 会话管理器接口
|
||||
type SessionManager interface {
|
||||
GetSession(ctx context.Context, sessionID string) (*models.Session, error)
|
||||
SetSession(ctx context.Context, sessionID string, session *models.Session) error
|
||||
DeleteSession(ctx context.Context, sessionID string) error
|
||||
}
|
||||
|
||||
// MemorySessionManager 内存会话管理器
|
||||
type MemorySessionManager struct {
|
||||
sessions map[string]*models.Session
|
||||
}
|
||||
|
||||
// NewMemorySessionManager 创建内存会话管理器
|
||||
func NewMemorySessionManager() *MemorySessionManager {
|
||||
return &MemorySessionManager{
|
||||
sessions: make(map[string]*models.Session),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MemorySessionManager) GetSession(ctx context.Context, sessionID string) (*models.Session, error) {
|
||||
session, exists := m.sessions[sessionID]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("session not found")
|
||||
}
|
||||
|
||||
if time.Now().After(session.ExpiresAt) {
|
||||
delete(m.sessions, sessionID)
|
||||
return nil, fmt.Errorf("session expired")
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (m *MemorySessionManager) SetSession(ctx context.Context, sessionID string, session *models.Session) error {
|
||||
m.sessions[sessionID] = session
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MemorySessionManager) DeleteSession(ctx context.Context, sessionID string) error {
|
||||
delete(m.sessions, sessionID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// AuthMiddleware 认证中间件
|
||||
type AuthMiddleware struct {
|
||||
client *SSOClient
|
||||
sessionManager SessionManager
|
||||
loginURL string
|
||||
validationMode string // "sso" 或 "jwt"
|
||||
}
|
||||
|
||||
// NewAuthMiddleware 创建认证中间件
|
||||
func NewAuthMiddleware(client *SSOClient, sessionManager SessionManager, loginURL string) *AuthMiddleware {
|
||||
return &AuthMiddleware{
|
||||
client: client,
|
||||
sessionManager: sessionManager,
|
||||
loginURL: loginURL,
|
||||
validationMode: "sso", // 默认使用SSO验证模式
|
||||
}
|
||||
}
|
||||
|
||||
// SetValidationMode 设置验证模式
|
||||
func (m *AuthMiddleware) SetValidationMode(mode string) {
|
||||
m.validationMode = mode
|
||||
}
|
||||
|
||||
// RequireAuth 要求认证的中间件
|
||||
func (m *AuthMiddleware) RequireAuth() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 从Authorization头获取访问令牌
|
||||
var token string = ""
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
authHeader = c.Query("token")
|
||||
if authHeader == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "unauthorized",
|
||||
"message": "Authorization header is required",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
token = strings.Trim(authHeader, " ")
|
||||
} else {
|
||||
// 提取Bearer令牌
|
||||
if len(authHeader) > 7 && authHeader[:7] == "Bearer " {
|
||||
token = authHeader[7:]
|
||||
}
|
||||
}
|
||||
|
||||
if token == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "unauthorized",
|
||||
"message": "Invalid authorization header",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 直接调用第三方SSO服务验证token并获取用户信息
|
||||
// 不再进行本地JWT验证,而是通过调用第三方API来验证token的有效性
|
||||
userInfo, err := m.client.GetUserInfo(c.Request.Context(), token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "unauthorized",
|
||||
"message": fmt.Sprintf("Invalid token or failed to get user info: %v", err),
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 解析SSO用户ID为数值
|
||||
userID, err := strconv.ParseUint(userInfo.Sub, 10, 32)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "invalid_user_id",
|
||||
"message": "Invalid user ID in token",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
userIDUint := uint(userID)
|
||||
|
||||
// 查找或创建本地用户
|
||||
user, err := m.findOrCreateUser(userIDUint, userInfo)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "user_creation_failed",
|
||||
"message": "Failed to create or find user",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 将用户信息添加到上下文
|
||||
c.Set("user", userInfo)
|
||||
c.Set("user_id", userIDUint) // 使用转换后的用户ID
|
||||
c.Set("local_user", user) // 本地用户对象
|
||||
c.Set("token", token)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// OptionalAuth 可选认证的中间件
|
||||
func (m *AuthMiddleware) OptionalAuth() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 从Authorization头获取访问令牌
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 提取Bearer令牌
|
||||
token := ""
|
||||
if len(authHeader) > 7 && authHeader[:7] == "Bearer " {
|
||||
token = authHeader[7:]
|
||||
}
|
||||
|
||||
if token == "" {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 直接调用第三方SSO服务验证token并获取用户信息
|
||||
userInfo, err := m.client.GetUserInfo(c.Request.Context(), token)
|
||||
if err != nil {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 将用户信息添加到上下文
|
||||
c.Set("user", userInfo)
|
||||
c.Set("token", token)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// GetUserFromContext 从上下文中获取用户信息
|
||||
func GetUserFromContext(c *gin.Context) (*models.UserInfo, bool) {
|
||||
user, exists := c.Get("user")
|
||||
if !exists {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
userInfo, ok := user.(*models.UserInfo)
|
||||
return userInfo, ok
|
||||
}
|
||||
|
||||
// GetUserIDFromContext 从上下文中获取用户ID
|
||||
func GetUserIDFromContext(c *gin.Context) (int, bool) {
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
userIDInt, ok := userID.(int)
|
||||
return userIDInt, ok
|
||||
}
|
||||
|
||||
// GetTokenFromContext 从上下文中获取令牌
|
||||
func GetTokenFromContext(c *gin.Context) (string, bool) {
|
||||
token, exists := c.Get("token")
|
||||
if !exists {
|
||||
return "", false
|
||||
}
|
||||
|
||||
tokenStr, ok := token.(string)
|
||||
return tokenStr, ok
|
||||
}
|
||||
|
||||
// findOrCreateUser 查找或创建用户
|
||||
func (m *AuthMiddleware) findOrCreateUser(userID uint, userInfo *models.UserInfo) (*models.User, error) {
|
||||
// 尝试查找现有用户
|
||||
var user models.User
|
||||
err := storage.DB.Where("id = ?", userID).First(&user).Error
|
||||
|
||||
if err == nil {
|
||||
// 用户存在,更新登录信息
|
||||
now := time.Now()
|
||||
user.LastLoginAt = &now
|
||||
user.LoginCount++
|
||||
|
||||
// 更新用户信息(如果SSO信息有变化)
|
||||
if userInfo.Name != "" && user.Nickname != userInfo.Name {
|
||||
user.Nickname = userInfo.Name
|
||||
}
|
||||
if userInfo.Email != "" && user.Email != userInfo.Email {
|
||||
user.Email = userInfo.Email
|
||||
}
|
||||
|
||||
err = storage.DB.Save(&user).Error
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("更新用户信息失败: %w", err)
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
if err != gorm.ErrRecordNotFound {
|
||||
return nil, fmt.Errorf("查询用户失败: %w", err)
|
||||
}
|
||||
|
||||
// 用户不存在,创建新用户
|
||||
now := time.Now()
|
||||
user = models.User{
|
||||
BaseModel: models.BaseModel{
|
||||
ID: userID,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
},
|
||||
Username: userInfo.PreferredUsername,
|
||||
Email: userInfo.Email,
|
||||
Nickname: userInfo.Name,
|
||||
Status: 1, // 默认启用
|
||||
SSOProvider: "default", // 可以根据实际情况设置
|
||||
LastLoginAt: &now,
|
||||
LoginCount: 1,
|
||||
}
|
||||
|
||||
// 如果PreferredUsername为空,使用Email作为用户名
|
||||
if user.Username == "" {
|
||||
user.Username = userInfo.Email
|
||||
}
|
||||
|
||||
err = storage.DB.Create(&user).Error
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建用户失败: %w", err)
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
193
pkg/middleware/rbac.go
Normal file
193
pkg/middleware/rbac.go
Normal file
@@ -0,0 +1,193 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"goalfymax-admin/internal/models"
|
||||
"goalfymax-admin/internal/storage"
|
||||
"goalfymax-admin/pkg/utils"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// RBACMiddleware 简化的RBAC权限中间件
|
||||
type RBACMiddleware struct {
|
||||
rbacStorage storage.RBACStorage
|
||||
db *gorm.DB
|
||||
logger *utils.Logger
|
||||
}
|
||||
|
||||
// NewRBACMiddleware 创建RBAC中间件
|
||||
func NewRBACMiddleware(rbacStorage storage.RBACStorage, db *gorm.DB, logger *utils.Logger) *RBACMiddleware {
|
||||
return &RBACMiddleware{
|
||||
rbacStorage: rbacStorage,
|
||||
db: db,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// RequirePagePermission 检查页面权限
|
||||
func (m *RBACMiddleware) RequirePagePermission(pagePath string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 从上下文获取当前用户ID
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
m.logger.Error("未找到用户ID", zap.String("pagePath", pagePath))
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "未认证"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 类型转换
|
||||
userIDUint, ok := userID.(uint)
|
||||
if !ok {
|
||||
m.logger.Error("无效的用户ID类型", zap.Any("userID", userID))
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的用户ID"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 检查用户是否有页面访问权限(基于角色)
|
||||
hasPermission, err := m.rbacStorage.CheckUserRolePagePermission(userIDUint, pagePath)
|
||||
if err != nil {
|
||||
m.logger.Error("页面权限检查失败", zap.Uint("userID", userIDUint), zap.String("pagePath", pagePath), zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "权限检查失败"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
if !hasPermission {
|
||||
m.logger.Warn("页面权限不足", zap.Uint("userID", userIDUint), zap.String("pagePath", pagePath))
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "页面权限不足"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
m.logger.Info("页面权限检查通过", zap.Uint("userID", userIDUint), zap.String("pagePath", pagePath))
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// RequireRole 检查角色
|
||||
func (m *RBACMiddleware) RequireRole(roleName string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 从上下文获取当前用户ID
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
m.logger.Error("未找到用户ID", zap.String("role", roleName))
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "未认证"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 类型转换:userID 是 uint
|
||||
userIDUint, ok := userID.(uint)
|
||||
if !ok {
|
||||
m.logger.Error("无效的用户ID类型", zap.Any("userID", userID))
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的用户ID"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 直接查询用户角色
|
||||
var user models.User
|
||||
err := m.db.Where("id = ?", userIDUint).First(&user).Error
|
||||
if err != nil {
|
||||
m.logger.Error("获取用户信息失败", zap.Uint("userID", userIDUint), zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取用户信息失败"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 查询角色信息
|
||||
var role models.Role
|
||||
err = m.db.Where("id = ?", user.RoleID).First(&role).Error
|
||||
if err != nil {
|
||||
m.logger.Error("获取角色信息失败", zap.Uint("roleID", user.RoleID), zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取角色信息失败"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 检查用户是否有指定角色
|
||||
hasRole := role.Name == roleName
|
||||
|
||||
if !hasRole {
|
||||
m.logger.Warn("角色不足", zap.Uint("userID", userIDUint), zap.String("role", roleName))
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "角色不足"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
m.logger.Info("角色检查通过", zap.Uint("userID", userIDUint), zap.String("role", roleName))
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// RequireAnyRole 检查任意角色
|
||||
func (m *RBACMiddleware) RequireAnyRole(roleNames []string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 从上下文获取当前用户ID
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
m.logger.Error("未找到用户ID", zap.Strings("roles", roleNames))
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "未认证"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 类型转换:userID 是 uint
|
||||
userIDUint, ok := userID.(uint)
|
||||
if !ok {
|
||||
m.logger.Error("无效的用户ID类型", zap.Any("userID", userID))
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的用户ID"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 直接查询用户角色
|
||||
var user models.User
|
||||
err := m.db.Where("id = ?", userIDUint).First(&user).Error
|
||||
if err != nil {
|
||||
m.logger.Error("获取用户信息失败", zap.Uint("userID", userIDUint), zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取用户信息失败"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 查询角色信息
|
||||
var role models.Role
|
||||
err = m.db.Where("id = ?", user.RoleID).First(&role).Error
|
||||
if err != nil {
|
||||
m.logger.Error("获取角色信息失败", zap.Uint("roleID", user.RoleID), zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取角色信息失败"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 检查用户是否有任意一个角色
|
||||
hasAnyRole := false
|
||||
for _, roleName := range roleNames {
|
||||
if role.Name == roleName {
|
||||
hasAnyRole = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !hasAnyRole {
|
||||
m.logger.Warn("角色不足", zap.Uint("userID", userIDUint), zap.Strings("roles", roleNames))
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "角色不足"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
m.logger.Info("角色检查通过", zap.Uint("userID", userIDUint), zap.Strings("roles", roleNames))
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// GetUserAccessiblePages 获取用户可访问的页面列表
|
||||
func (m *RBACMiddleware) GetUserAccessiblePages(userID uint) ([]string, error) {
|
||||
return m.rbacStorage.GetUserRoleAccessiblePages(userID)
|
||||
}
|
||||
476
pkg/middleware/sso_client.go
Normal file
476
pkg/middleware/sso_client.go
Normal file
@@ -0,0 +1,476 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"goalfymax-admin/internal/models"
|
||||
"goalfymax-admin/pkg/utils"
|
||||
"io"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// SSOClient SSO客户端
|
||||
type SSOClient struct {
|
||||
config *models.SSOConfig
|
||||
http *http.Client
|
||||
oauth2 *oauth2.Config
|
||||
logger *utils.Logger
|
||||
}
|
||||
|
||||
// NewSSOClient 创建新的SSO客户端
|
||||
func NewSSOClient(config *models.SSOConfig, logger *utils.Logger) *SSOClient {
|
||||
if config.Timeout == 0 {
|
||||
config.Timeout = 30 * time.Second
|
||||
}
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: config.Timeout,
|
||||
}
|
||||
|
||||
oauth2Config := &oauth2.Config{
|
||||
ClientID: config.ClientID,
|
||||
ClientSecret: config.ClientSecret,
|
||||
RedirectURL: config.RedirectURI,
|
||||
Scopes: strings.Split(config.Scope, " "),
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: fmt.Sprintf("%s/oauth2/authorize", config.SSOServerURL),
|
||||
TokenURL: fmt.Sprintf("%s/oauth2/token", config.SSOServerURL),
|
||||
},
|
||||
}
|
||||
|
||||
return &SSOClient{
|
||||
config: config,
|
||||
http: httpClient,
|
||||
oauth2: oauth2Config,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *SSOClient) GetServerUrl() string {
|
||||
return c.config.SSOServerURL
|
||||
}
|
||||
|
||||
func (c *SSOClient) GetRedirectUrl() string {
|
||||
return c.config.RedirectURI
|
||||
}
|
||||
|
||||
// GetAuthorizationURL 获取授权URL
|
||||
func (c *SSOClient) GetAuthorizationURL(state string) (string, string, error) {
|
||||
// 生成PKCE挑战
|
||||
codeVerifier, err := generateCodeVerifier()
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to generate code verifier: %w", err)
|
||||
}
|
||||
|
||||
// 构建授权URL
|
||||
authURL := c.oauth2.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(codeVerifier))
|
||||
|
||||
return authURL, codeVerifier, nil
|
||||
}
|
||||
|
||||
// ExchangeCodeForToken 使用授权码交换令牌
|
||||
func (c *SSOClient) ExchangeCodeForToken(ctx context.Context, code, codeVerifier string) (*models.TokenResponse, error) {
|
||||
data := url.Values{}
|
||||
data.Set("grant_type", "authorization_code")
|
||||
data.Set("client_id", c.config.ClientID)
|
||||
data.Set("client_secret", c.config.ClientSecret)
|
||||
data.Set("code", code)
|
||||
data.Set("redirect_uri", c.config.RedirectURI)
|
||||
data.Set("code_verifier", codeVerifier)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("%s/oauth2/token", c.config.SSOServerURL), strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
resp, err := c.http.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to make request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
var errorResp models.ErrorResponse
|
||||
if err := json.Unmarshal(body, &errorResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse error response: %w", err)
|
||||
}
|
||||
return nil, fmt.Errorf("token exchange failed: %s - %s", errorResp.Error, errorResp.Message)
|
||||
}
|
||||
|
||||
var tokenResp models.TokenResponse
|
||||
if err := json.Unmarshal(body, &tokenResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token response: %w", err)
|
||||
}
|
||||
|
||||
return &tokenResp, nil
|
||||
}
|
||||
|
||||
// RefreshToken 刷新访问令牌
|
||||
func (c *SSOClient) RefreshToken(ctx context.Context, refreshToken string) (*models.TokenResponse, error) {
|
||||
data := url.Values{}
|
||||
data.Set("grant_type", "refresh_token")
|
||||
data.Set("client_id", c.config.ClientID)
|
||||
data.Set("client_secret", c.config.ClientSecret)
|
||||
data.Set("refresh_token", refreshToken)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("%s/oauth2/token", c.config.SSOServerURL), strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
resp, err := c.http.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to make request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
var errorResp models.ErrorResponse
|
||||
if err := json.Unmarshal(body, &errorResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse error response: %w", err)
|
||||
}
|
||||
return nil, fmt.Errorf("token refresh failed: %s - %s", errorResp.Error, errorResp.Message)
|
||||
}
|
||||
|
||||
var tokenResp models.TokenResponse
|
||||
if err := json.Unmarshal(body, &tokenResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token response: %w", err)
|
||||
}
|
||||
|
||||
return &tokenResp, nil
|
||||
}
|
||||
|
||||
// GetUserInfo 获取用户信息
|
||||
func (c *SSOClient) GetUserInfo(ctx context.Context, accessToken string) (*models.UserInfo, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("%s/oauth2/userinfo", c.config.SSOServerURL), nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
|
||||
resp, err := c.http.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to make request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
var errorResp models.ErrorResponse
|
||||
if err := json.Unmarshal(body, &errorResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse error response: %w", err)
|
||||
}
|
||||
return nil, fmt.Errorf("get user info failed: %s - %s", errorResp.Error, errorResp.Message)
|
||||
}
|
||||
|
||||
// 首先解析为 SSOUserInfo,它可以接受任何类型的 roles 字段
|
||||
var ssoUserInfo models.SSOUserInfo
|
||||
if err := json.Unmarshal(body, &ssoUserInfo); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse user info response: %w", err)
|
||||
}
|
||||
|
||||
// 转换为应用程序内部使用的 UserInfo
|
||||
userInfo := &models.UserInfo{
|
||||
Sub: ssoUserInfo.Sub,
|
||||
Name: ssoUserInfo.Name,
|
||||
Email: ssoUserInfo.Email,
|
||||
PreferredUsername: ssoUserInfo.PreferredUsername,
|
||||
}
|
||||
|
||||
// 规范化 roles 字段:支持 string、[]string、[]interface{}
|
||||
if ssoUserInfo.Roles != nil {
|
||||
switch v := ssoUserInfo.Roles.(type) {
|
||||
case string:
|
||||
if v != "" {
|
||||
userInfo.Roles = append(userInfo.Roles, models.Role{Name: v})
|
||||
}
|
||||
case []string:
|
||||
for _, name := range v {
|
||||
if name != "" {
|
||||
userInfo.Roles = append(userInfo.Roles, models.Role{Name: name})
|
||||
}
|
||||
}
|
||||
case []interface{}:
|
||||
for _, r := range v {
|
||||
if name, ok := r.(string); ok && name != "" {
|
||||
userInfo.Roles = append(userInfo.Roles, models.Role{Name: name})
|
||||
}
|
||||
}
|
||||
default:
|
||||
// 尝试通用 JSON 数组字符串
|
||||
if b, err := json.Marshal(v); err == nil {
|
||||
var arr []string
|
||||
if err := json.Unmarshal(b, &arr); err == nil {
|
||||
for _, name := range arr {
|
||||
if name != "" {
|
||||
userInfo.Roles = append(userInfo.Roles, models.Role{Name: name})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return userInfo, nil
|
||||
}
|
||||
|
||||
// ValidateTokenWithSSO 通过调用第三方SSO服务验证token
|
||||
func (c *SSOClient) ValidateTokenWithSSO(ctx context.Context, accessToken string) error {
|
||||
// 尝试调用用户信息接口来验证token
|
||||
_, err := c.GetUserInfo(ctx, accessToken)
|
||||
if err != nil {
|
||||
return fmt.Errorf("token validation failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateToken 验证令牌(保留原有方法,但标记为不推荐使用)
|
||||
func (c *SSOClient) ValidateToken(ctx context.Context, accessToken string) (*jwt.Token, error) {
|
||||
// 获取JWKS
|
||||
jwks, err := c.GetJWKS(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get JWKS: %w", err)
|
||||
}
|
||||
|
||||
// 解析JWT令牌
|
||||
token, err := jwt.Parse(accessToken, func(token *jwt.Token) (interface{}, error) {
|
||||
// 验证签名算法
|
||||
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
|
||||
// 从JWKS中获取公钥
|
||||
kid, ok := token.Header["kid"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("kid not found in token header")
|
||||
}
|
||||
|
||||
// 在JWKS中查找对应的公钥
|
||||
publicKey, err := c.findPublicKey(jwks, kid)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to find public key: %w", err)
|
||||
}
|
||||
|
||||
return publicKey, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token: %w", err)
|
||||
}
|
||||
|
||||
// 验证令牌是否有效
|
||||
if !token.Valid {
|
||||
return nil, fmt.Errorf("invalid token")
|
||||
}
|
||||
|
||||
// 验证令牌的声明
|
||||
if err := c.validateTokenClaims(token); err != nil {
|
||||
return nil, fmt.Errorf("invalid token claims: %w", err)
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// findPublicKey 在JWKS中查找对应的公钥
|
||||
func (c *SSOClient) findPublicKey(jwks *models.JWKS, kid string) (interface{}, error) {
|
||||
for _, key := range jwks.Keys {
|
||||
if keyKid, ok := key["kid"].(string); ok && keyKid == kid {
|
||||
// 解析RSA公钥
|
||||
if n, ok := key["n"].(string); ok {
|
||||
if e, ok := key["e"].(string); ok {
|
||||
return c.parseRSAPublicKey(n, e)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("public key not found for kid: %s", kid)
|
||||
}
|
||||
|
||||
// parseRSAPublicKey 解析RSA公钥
|
||||
func (c *SSOClient) parseRSAPublicKey(n, e string) (interface{}, error) {
|
||||
// 解码Base64URL编码的模数和指数
|
||||
nBytes, err := base64.RawURLEncoding.DecodeString(n)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode modulus: %w", err)
|
||||
}
|
||||
|
||||
eBytes, err := base64.RawURLEncoding.DecodeString(e)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode exponent: %w", err)
|
||||
}
|
||||
|
||||
// 创建RSA公钥
|
||||
modulus := new(big.Int).SetBytes(nBytes)
|
||||
exponent := new(big.Int).SetBytes(eBytes)
|
||||
|
||||
publicKey := &rsa.PublicKey{
|
||||
N: modulus,
|
||||
E: int(exponent.Int64()),
|
||||
}
|
||||
|
||||
return publicKey, nil
|
||||
}
|
||||
|
||||
// validateTokenClaims 验证令牌声明
|
||||
func (c *SSOClient) validateTokenClaims(token *jwt.Token) error {
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid token claims")
|
||||
}
|
||||
|
||||
// 验证发行者
|
||||
if iss, ok := claims["iss"].(string); ok {
|
||||
if iss != c.config.SSOServerURL {
|
||||
return fmt.Errorf("invalid issuer: %s", iss)
|
||||
}
|
||||
}
|
||||
|
||||
// 验证受众
|
||||
if aud, ok := claims["aud"].(string); ok {
|
||||
if aud != c.config.ResourceAud {
|
||||
return fmt.Errorf("invalid audience: %s", aud)
|
||||
}
|
||||
}
|
||||
|
||||
// 验证过期时间
|
||||
if exp, ok := claims["exp"].(float64); ok {
|
||||
if time.Unix(int64(exp), 0).Before(time.Now()) {
|
||||
return fmt.Errorf("token expired")
|
||||
}
|
||||
}
|
||||
|
||||
// 验证生效时间
|
||||
if nbf, ok := claims["nbf"].(float64); ok {
|
||||
if time.Unix(int64(nbf), 0).After(time.Now()) {
|
||||
return fmt.Errorf("token not yet valid")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetOpenIDConfiguration 获取OpenID配置
|
||||
func (c *SSOClient) GetOpenIDConfiguration(ctx context.Context) (*models.OpenIDConfiguration, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("%s/.well-known/openid-configuration", c.config.SSOServerURL), nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := c.http.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to make request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("failed to get OpenID configuration: %s", resp.Status)
|
||||
}
|
||||
|
||||
var config models.OpenIDConfiguration
|
||||
if err := json.Unmarshal(body, &config); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse OpenID configuration: %w", err)
|
||||
}
|
||||
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
// GetJWKS 获取JWKS
|
||||
func (c *SSOClient) GetJWKS(ctx context.Context) (*models.JWKS, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("%s/oauth2/jwks.json", c.config.SSOServerURL), nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := c.http.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to make request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("failed to get JWKS: %s", resp.Status)
|
||||
}
|
||||
|
||||
var jwks models.JWKS
|
||||
if err := json.Unmarshal(body, &jwks); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse JWKS: %w", err)
|
||||
}
|
||||
|
||||
return &jwks, nil
|
||||
}
|
||||
|
||||
// Logout 登出
|
||||
func (c *SSOClient) Logout(ctx context.Context, accessToken string) error {
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("%s/oauth2/logout-pre", c.config.SSOServerURL), nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
|
||||
resp, err := c.http.Do(req)
|
||||
if err != nil {
|
||||
c.logger.Error("failed to make request", zap.Error(err))
|
||||
return fmt.Errorf("failed to make request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
c.logger.Error("failed to logout", zap.String("status", resp.Status), zap.String("body", string(body)))
|
||||
return fmt.Errorf("logout failed: %s - %s", resp.Status, string(body))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 生成PKCE代码验证器
|
||||
func generateCodeVerifier() (string, error) {
|
||||
bytes := make([]byte, 32)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(bytes), nil
|
||||
}
|
||||
38
pkg/redis/redis.go
Normal file
38
pkg/redis/redis.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package redis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"goalfymax-admin/internal/config"
|
||||
"log"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// Client 是一个 Redis 客户端的封装
|
||||
type Client struct {
|
||||
Rdb *redis.Client
|
||||
}
|
||||
|
||||
// NewClient 创建一个新的 Redis 客户端实例
|
||||
func NewClient(cfg config.RedisConfig) (*Client, error) {
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: cfg.Addr,
|
||||
Password: cfg.Password,
|
||||
DB: cfg.DB,
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
if _, err := rdb.Ping(ctx).Result(); err != nil {
|
||||
rdb.Close()
|
||||
return nil, fmt.Errorf("failed to connect to Redis: %w", err)
|
||||
}
|
||||
|
||||
log.Println("Successfully connected to Redis")
|
||||
return &Client{Rdb: rdb}, nil
|
||||
}
|
||||
|
||||
// Close 关闭 Redis 客户端连接
|
||||
func (c *Client) Close() error {
|
||||
return c.Rdb.Close()
|
||||
}
|
||||
133
pkg/utils/README.md
Normal file
133
pkg/utils/README.md
Normal file
@@ -0,0 +1,133 @@
|
||||
# 工具包
|
||||
|
||||
本模块提供各种实用工具函数和类。
|
||||
|
||||
## 功能特性
|
||||
|
||||
- 加密工具(MD5、SHA256、密码哈希)
|
||||
- JWT token管理
|
||||
- 统一响应处理
|
||||
- 数据验证
|
||||
- 日志管理
|
||||
|
||||
## 模块结构
|
||||
|
||||
```
|
||||
utils/
|
||||
├── crypto.go # 加密工具
|
||||
├── jwt.go # JWT管理
|
||||
├── response.go # 响应处理
|
||||
├── validator.go # 数据验证
|
||||
├── logger.go # 日志管理
|
||||
└── README.md # 说明文档
|
||||
```
|
||||
|
||||
## 使用方法
|
||||
|
||||
### 加密工具
|
||||
|
||||
```go
|
||||
import "goalfymax-admin/pkg/utils"
|
||||
|
||||
// MD5哈希
|
||||
hash := utils.MD5Hash("password")
|
||||
|
||||
// SHA256哈希
|
||||
hash := utils.SHA256Hash("password")
|
||||
|
||||
// 生成盐值
|
||||
salt, err := utils.GenerateSalt()
|
||||
|
||||
// 哈希密码
|
||||
hashedPassword := utils.HashPassword("password", salt)
|
||||
|
||||
// 验证密码
|
||||
isValid := utils.VerifyPassword("password", salt, hashedPassword)
|
||||
```
|
||||
|
||||
### JWT管理
|
||||
|
||||
```go
|
||||
// 创建JWT管理器
|
||||
jwtManager := utils.NewJWTManager("your-secret-key")
|
||||
|
||||
// 生成token
|
||||
token, err := jwtManager.GenerateToken(1, "admin", "admin")
|
||||
|
||||
// 解析token
|
||||
claims, err := jwtManager.ParseToken(token)
|
||||
|
||||
// 刷新token
|
||||
newToken, err := jwtManager.RefreshToken(token)
|
||||
```
|
||||
|
||||
### 响应处理
|
||||
|
||||
```go
|
||||
// 创建响应实例
|
||||
resp := utils.NewResponse()
|
||||
|
||||
// 成功响应
|
||||
resp.Success(c, data)
|
||||
|
||||
// 错误响应
|
||||
resp.Error(c, 400, "参数错误")
|
||||
resp.BadRequest(c, "请求参数错误")
|
||||
resp.Unauthorized(c, "未授权")
|
||||
resp.Forbidden(c, "禁止访问")
|
||||
resp.NotFound(c, "资源不存在")
|
||||
resp.InternalServerError(c, "服务器内部错误")
|
||||
|
||||
// 分页响应
|
||||
resp.Page(c, data, total, page, size)
|
||||
```
|
||||
|
||||
### 数据验证
|
||||
|
||||
```go
|
||||
// 创建验证器
|
||||
validator := utils.NewValidator()
|
||||
|
||||
// 验证邮箱
|
||||
isValid := validator.IsEmail("user@example.com")
|
||||
|
||||
// 验证手机号
|
||||
isValid := validator.IsPhone("13800138000")
|
||||
|
||||
// 验证用户名
|
||||
isValid := validator.IsUsername("admin")
|
||||
|
||||
// 验证密码强度
|
||||
isValid := validator.IsPassword("password123")
|
||||
|
||||
// 验证URL
|
||||
isValid := validator.IsURL("https://example.com")
|
||||
|
||||
// 检查是否为空
|
||||
isEmpty := validator.IsEmpty("")
|
||||
|
||||
// 验证角色
|
||||
isValid := validator.IsValidRole("admin")
|
||||
```
|
||||
|
||||
### 日志管理
|
||||
|
||||
```go
|
||||
// 创建日志实例
|
||||
logger, err := utils.NewLogger("info", "json", "stdout")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// 记录日志
|
||||
logger.Info("用户登录", zap.String("username", "admin"))
|
||||
logger.Error("登录失败", zap.String("error", "密码错误"))
|
||||
|
||||
// 添加字段
|
||||
logger.WithField("user_id", 1).Info("用户操作")
|
||||
logger.WithFields(map[string]interface{}{
|
||||
"user_id": 1,
|
||||
"action": "login",
|
||||
}).Info("用户登录")
|
||||
```
|
||||
|
||||
55
pkg/utils/crypto.go
Normal file
55
pkg/utils/crypto.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"math/big"
|
||||
)
|
||||
|
||||
// MD5Hash 计算MD5哈希
|
||||
func MD5Hash(text string) string {
|
||||
hash := md5.Sum([]byte(text))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// SHA256Hash 计算SHA256哈希
|
||||
func SHA256Hash(text string) string {
|
||||
hash := sha256.Sum256([]byte(text))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// GenerateRandomString 生成指定长度的随机字符串
|
||||
func GenerateRandomString(length int) (string, error) {
|
||||
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
b := make([]byte, length)
|
||||
for i := range b {
|
||||
num, err := rand.Int(rand.Reader, big.NewInt(int64(len(charset))))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
b[i] = charset[num.Int64()]
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
// GenerateSalt 生成盐值
|
||||
func GenerateSalt() (string, error) {
|
||||
return GenerateRandomString(16)
|
||||
}
|
||||
|
||||
// HashPassword 使用盐值哈希密码
|
||||
func HashPassword(password, salt string) string {
|
||||
return SHA256Hash(password + salt)
|
||||
}
|
||||
|
||||
// VerifyPassword 验证密码
|
||||
func VerifyPassword(password, salt, hash string) bool {
|
||||
return HashPassword(password, salt) == hash
|
||||
}
|
||||
|
||||
// GenerateToken 生成随机token
|
||||
func GenerateToken() (string, error) {
|
||||
return GenerateRandomString(32)
|
||||
}
|
||||
80
pkg/utils/jwt.go
Normal file
80
pkg/utils/jwt.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
// JWTClaims JWT声明
|
||||
type JWTClaims struct {
|
||||
UserID uint `json:"user_id"`
|
||||
Username string `json:"username"`
|
||||
Role string `json:"role"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
// JWTManager JWT管理器
|
||||
type JWTManager struct {
|
||||
secretKey string
|
||||
}
|
||||
|
||||
// NewJWTManager 创建JWT管理器
|
||||
func NewJWTManager(secretKey string) *JWTManager {
|
||||
return &JWTManager{
|
||||
secretKey: secretKey,
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateToken 生成JWT token
|
||||
func (j *JWTManager) GenerateToken(userID uint, username string) (string, error) {
|
||||
claims := JWTClaims{
|
||||
UserID: userID,
|
||||
Username: username,
|
||||
Role: "", // 不再使用role字段
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(24 * time.Hour)),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
NotBefore: jwt.NewNumericDate(time.Now()),
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString([]byte(j.secretKey))
|
||||
}
|
||||
|
||||
// ParseToken 解析JWT token
|
||||
func (j *JWTManager) ParseToken(tokenString string) (*JWTClaims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, errors.New("unexpected signing method")
|
||||
}
|
||||
return []byte(j.secretKey), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if claims, ok := token.Claims.(*JWTClaims); ok && token.Valid {
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
return nil, errors.New("invalid token")
|
||||
}
|
||||
|
||||
// RefreshToken 刷新token
|
||||
func (j *JWTManager) RefreshToken(tokenString string) (string, error) {
|
||||
claims, err := j.ParseToken(tokenString)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 检查token是否即将过期(剩余时间少于1小时)
|
||||
if time.Until(claims.ExpiresAt.Time) < time.Hour {
|
||||
return j.GenerateToken(claims.UserID, claims.Username)
|
||||
}
|
||||
|
||||
return tokenString, nil
|
||||
}
|
||||
98
pkg/utils/logger.go
Normal file
98
pkg/utils/logger.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
)
|
||||
|
||||
// Logger 日志管理器
|
||||
type Logger struct {
|
||||
*zap.Logger
|
||||
}
|
||||
|
||||
// NewLogger 创建日志实例
|
||||
func NewLogger(level, format, output string) (*Logger, error) {
|
||||
// 设置日志级别
|
||||
var zapLevel zapcore.Level
|
||||
switch level {
|
||||
case "debug":
|
||||
zapLevel = zapcore.DebugLevel
|
||||
case "info":
|
||||
zapLevel = zapcore.InfoLevel
|
||||
case "warn":
|
||||
zapLevel = zapcore.WarnLevel
|
||||
case "error":
|
||||
zapLevel = zapcore.ErrorLevel
|
||||
default:
|
||||
zapLevel = zapcore.InfoLevel
|
||||
}
|
||||
|
||||
// 设置编码格式
|
||||
var encoderConfig zapcore.EncoderConfig
|
||||
if format == "json" {
|
||||
encoderConfig = zap.NewProductionEncoderConfig()
|
||||
} else {
|
||||
encoderConfig = zap.NewDevelopmentEncoderConfig()
|
||||
}
|
||||
|
||||
// 设置时间格式
|
||||
encoderConfig.TimeKey = "timestamp"
|
||||
encoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder
|
||||
|
||||
// 创建配置
|
||||
config := zap.Config{
|
||||
Level: zap.NewAtomicLevelAt(zapLevel),
|
||||
Development: format != "json",
|
||||
Encoding: format,
|
||||
EncoderConfig: encoderConfig,
|
||||
OutputPaths: []string{output},
|
||||
ErrorOutputPaths: []string{output},
|
||||
}
|
||||
|
||||
// 创建logger
|
||||
logger, err := config.Build()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Logger{Logger: logger}, nil
|
||||
}
|
||||
|
||||
// Info 记录信息日志
|
||||
func (l *Logger) Info(msg string, fields ...zap.Field) {
|
||||
l.Logger.Info(msg, fields...)
|
||||
}
|
||||
|
||||
// Debug 记录调试日志
|
||||
func (l *Logger) Debug(msg string, fields ...zap.Field) {
|
||||
l.Logger.Debug(msg, fields...)
|
||||
}
|
||||
|
||||
// Warn 记录警告日志
|
||||
func (l *Logger) Warn(msg string, fields ...zap.Field) {
|
||||
l.Logger.Warn(msg, fields...)
|
||||
}
|
||||
|
||||
// Error 记录错误日志
|
||||
func (l *Logger) Error(msg string, fields ...zap.Field) {
|
||||
l.Logger.Error(msg, fields...)
|
||||
}
|
||||
|
||||
// Fatal 记录致命错误日志
|
||||
func (l *Logger) Fatal(msg string, fields ...zap.Field) {
|
||||
l.Logger.Fatal(msg, fields...)
|
||||
}
|
||||
|
||||
// WithField 添加字段
|
||||
func (l *Logger) WithField(key string, value interface{}) *Logger {
|
||||
return &Logger{Logger: l.Logger.With(zap.Any(key, value))}
|
||||
}
|
||||
|
||||
// WithFields 添加多个字段
|
||||
func (l *Logger) WithFields(fields map[string]interface{}) *Logger {
|
||||
zapFields := make([]zap.Field, 0, len(fields))
|
||||
for k, v := range fields {
|
||||
zapFields = append(zapFields, zap.Any(k, v))
|
||||
}
|
||||
return &Logger{Logger: l.Logger.With(zapFields...)}
|
||||
}
|
||||
61
pkg/utils/response.go
Normal file
61
pkg/utils/response.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"goalfymax-admin/internal/models"
|
||||
)
|
||||
|
||||
// Response 统一响应处理
|
||||
type Response struct{}
|
||||
|
||||
// Success 成功响应
|
||||
func (r *Response) Success(c *gin.Context, data interface{}) {
|
||||
c.JSON(http.StatusOK, models.NewSuccessResponse(data))
|
||||
}
|
||||
|
||||
// Error 错误响应
|
||||
func (r *Response) Error(c *gin.Context, code int, message string) {
|
||||
c.JSON(code, models.NewResponse(code, message, nil))
|
||||
}
|
||||
|
||||
// BadRequest 400错误
|
||||
func (r *Response) BadRequest(c *gin.Context, message string) {
|
||||
r.Error(c, http.StatusBadRequest, message)
|
||||
}
|
||||
|
||||
// Unauthorized 401错误
|
||||
func (r *Response) Unauthorized(c *gin.Context, message string) {
|
||||
r.Error(c, http.StatusUnauthorized, message)
|
||||
}
|
||||
|
||||
// Forbidden 403错误
|
||||
func (r *Response) Forbidden(c *gin.Context, message string) {
|
||||
r.Error(c, http.StatusForbidden, message)
|
||||
}
|
||||
|
||||
// NotFound 404错误
|
||||
func (r *Response) NotFound(c *gin.Context, message string) {
|
||||
r.Error(c, http.StatusNotFound, message)
|
||||
}
|
||||
|
||||
// InternalServerError 500错误
|
||||
func (r *Response) InternalServerError(c *gin.Context, message string) {
|
||||
r.Error(c, http.StatusInternalServerError, message)
|
||||
}
|
||||
|
||||
// Page 分页响应
|
||||
func (r *Response) Page(c *gin.Context, data interface{}, total int64, page, size int) {
|
||||
c.JSON(http.StatusOK, models.NewPageResponse(data, total, page, size))
|
||||
}
|
||||
|
||||
// ValidateError 验证错误响应
|
||||
func (r *Response) ValidateError(c *gin.Context, err error) {
|
||||
r.BadRequest(c, err.Error())
|
||||
}
|
||||
|
||||
// NewResponse 创建响应实例
|
||||
func NewResponse() *Response {
|
||||
return &Response{}
|
||||
}
|
||||
72
pkg/utils/validator.go
Normal file
72
pkg/utils/validator.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Validator 验证器
|
||||
type Validator struct{}
|
||||
|
||||
// IsEmail 验证邮箱格式
|
||||
func (v *Validator) IsEmail(email string) bool {
|
||||
pattern := `^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`
|
||||
matched, _ := regexp.MatchString(pattern, email)
|
||||
return matched
|
||||
}
|
||||
|
||||
// IsPhone 验证手机号格式
|
||||
func (v *Validator) IsPhone(phone string) bool {
|
||||
pattern := `^1[3-9]\d{9}$`
|
||||
matched, _ := regexp.MatchString(pattern, phone)
|
||||
return matched
|
||||
}
|
||||
|
||||
// IsUsername 验证用户名格式
|
||||
func (v *Validator) IsUsername(username string) bool {
|
||||
// 用户名只能包含字母、数字、下划线,长度3-20
|
||||
pattern := `^[a-zA-Z0-9_]{3,20}$`
|
||||
matched, _ := regexp.MatchString(pattern, username)
|
||||
return matched
|
||||
}
|
||||
|
||||
// IsPassword 验证密码强度
|
||||
func (v *Validator) IsPassword(password string) bool {
|
||||
// 密码至少6位,包含字母和数字
|
||||
if len(password) < 6 {
|
||||
return false
|
||||
}
|
||||
|
||||
hasLetter := regexp.MustCompile(`[a-zA-Z]`).MatchString(password)
|
||||
hasNumber := regexp.MustCompile(`[0-9]`).MatchString(password)
|
||||
|
||||
return hasLetter && hasNumber
|
||||
}
|
||||
|
||||
// IsURL 验证URL格式
|
||||
func (v *Validator) IsURL(url string) bool {
|
||||
pattern := `^https?://[^\s/$.?#].[^\s]*$`
|
||||
matched, _ := regexp.MatchString(pattern, url)
|
||||
return matched
|
||||
}
|
||||
|
||||
// IsEmpty 检查字符串是否为空
|
||||
func (v *Validator) IsEmpty(str string) bool {
|
||||
return strings.TrimSpace(str) == ""
|
||||
}
|
||||
|
||||
// IsValidRole 验证角色名称
|
||||
func (v *Validator) IsValidRole(role string) bool {
|
||||
validRoles := []string{"admin", "user", "guest"}
|
||||
for _, validRole := range validRoles {
|
||||
if role == validRole {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// NewValidator 创建验证器实例
|
||||
func NewValidator() *Validator {
|
||||
return &Validator{}
|
||||
}
|
||||
Reference in New Issue
Block a user