Files

125 lines
3.7 KiB
Go

package storage
import (
"goalfymax-admin/internal/models"
"gorm.io/gorm"
)
// PKCEStateStorage PKCE状态存储接口
type PKCEStateStorage interface {
Create(pkceState *models.PKCEState) error
GetByState(state string) (*models.PKCEState, error)
DeleteByState(state string) error
CleanExpired() error
}
// LoginInfoStorage 登录信息存储接口
type LoginInfoStorage interface {
Create(loginInfo *models.LoginInfo) error
GetByUserID(userID int) (*models.LoginInfo, error)
GetByUserIDAndUUID(userID int, uuid string) (*models.LoginInfo, error)
Update(loginInfo *models.LoginInfo) error
SetUserOffline(userID int) error
ListOnlineUsers() ([]*models.LoginInfo, error)
CountOnlineUsers() (int64, error)
DeleteByUserID(userID int) error
}
type pkceStateStorage struct {
db *gorm.DB
}
// NewPKCEStateStorage 创建PKCE状态存储实例
func NewPKCEStateStorage() PKCEStateStorage {
return &pkceStateStorage{db: DB}
}
// Create 创建PKCE状态
func (s *pkceStateStorage) Create(pkceState *models.PKCEState) error {
return s.db.Create(pkceState).Error
}
// GetByState 根据状态获取PKCE状态
func (s *pkceStateStorage) GetByState(state string) (*models.PKCEState, error) {
var pkceState models.PKCEState
err := s.db.Where("state = ?", state).First(&pkceState).Error
if err != nil {
return nil, err
}
return &pkceState, nil
}
// DeleteByState 根据状态删除PKCE状态
func (s *pkceStateStorage) DeleteByState(state string) error {
return s.db.Where("state = ?", state).Delete(&models.PKCEState{}).Error
}
// CleanExpired 清理过期的PKCE状态
func (s *pkceStateStorage) CleanExpired() error {
// 删除创建时间超过1小时的记录
return s.db.Where("created_at < ?", "NOW() - INTERVAL 1 HOUR").Delete(&models.PKCEState{}).Error
}
type loginInfoStorage struct {
db *gorm.DB
}
// NewLoginInfoStorage 创建登录信息存储实例
func NewLoginInfoStorage() LoginInfoStorage {
return &loginInfoStorage{db: DB}
}
// Create 创建登录信息
func (s *loginInfoStorage) Create(loginInfo *models.LoginInfo) error {
return s.db.Create(loginInfo).Error
}
// GetByUserID 根据用户ID获取登录信息
func (s *loginInfoStorage) GetByUserID(userID int) (*models.LoginInfo, error) {
var loginInfo models.LoginInfo
err := s.db.Where("user_id = ?", userID).First(&loginInfo).Error
if err != nil {
return nil, err
}
return &loginInfo, nil
}
// GetByUserIDAndUUID 根据用户ID和UUID获取登录信息
func (s *loginInfoStorage) GetByUserIDAndUUID(userID int, uuid string) (*models.LoginInfo, error) {
var loginInfo models.LoginInfo
err := s.db.Where("user_id = ? AND uuid = ?", userID, uuid).First(&loginInfo).Error
if err != nil {
return nil, err
}
return &loginInfo, nil
}
// Update 更新登录信息
func (s *loginInfoStorage) Update(loginInfo *models.LoginInfo) error {
return s.db.Save(loginInfo).Error
}
// SetUserOffline 设置用户离线
func (s *loginInfoStorage) SetUserOffline(userID int) error {
return s.db.Model(&models.LoginInfo{}).Where("user_id = ?", userID).Update("is_online", false).Error
}
// ListOnlineUsers 获取在线用户列表
func (s *loginInfoStorage) ListOnlineUsers() ([]*models.LoginInfo, error) {
var loginInfos []*models.LoginInfo
err := s.db.Where("is_online = ?", true).Find(&loginInfos).Error
return loginInfos, err
}
// CountOnlineUsers 获取在线用户数量
func (s *loginInfoStorage) CountOnlineUsers() (int64, error) {
var count int64
err := s.db.Model(&models.LoginInfo{}).Where("is_online = ?", true).Count(&count).Error
return count, err
}
// DeleteByUserID 根据用户ID删除登录信息
func (s *loginInfoStorage) DeleteByUserID(userID int) error {
return s.db.Where("user_id = ?", userID).Delete(&models.LoginInfo{}).Error
}