Files

156 lines
4.4 KiB
Go

package storage
import (
"time"
"goalfymax-admin/internal/models"
"gorm.io/gorm"
)
// InviteCodeStorage is the persistence API for admin invite codes.
type InviteCodeStorage interface {
	// Create inserts a new invite code.
	Create(inviteCode *models.InviteCode) error
	// Update saves the given invite code.
	Update(inviteCode *models.InviteCode) error
	// Delete removes the invite code with the given ID.
	Delete(id uint) error

	// GetByID looks up a non-deleted invite code by primary key.
	GetByID(id uint) (*models.InviteCode, error)
	// GetByCode looks up a non-deleted invite code by its code string.
	GetByCode(code string) (*models.InviteCode, error)
	// List returns one page of invite codes matching req plus the total number of matches.
	List(req *models.InviteCodeListRequest) ([]models.InviteCode, int64, error)
	// GetStatistics returns aggregate counters over non-deleted invite codes.
	GetStatistics() (*models.InviteCodeStatistics, error)

	// IsExpired reports whether inviteCode is past its expiry time.
	IsExpired(inviteCode *models.InviteCode) bool
}
// inviteCodeStorage is the GORM-backed implementation of InviteCodeStorage.
type inviteCodeStorage struct {
	db *gorm.DB // handle used for every query issued by this storage
}

// NewInviteCodeStorage returns an InviteCodeStorage that uses the package-level
// DB handle. The handle is copied at construction time, so a later
// reassignment of DB does not affect an already-built storage.
// NOTE(review): if DB has not been initialized yet, the returned storage holds
// a nil handle — confirm callers construct it after DB setup.
func NewInviteCodeStorage() InviteCodeStorage {
	storage := &inviteCodeStorage{db: DB}
	return storage
}
// Create inserts inviteCode. On schemas whose admin_invite_codes table has no
// is_used column, that column is left out of the INSERT so the statement does
// not fail with "Unknown column".
func (s *inviteCodeStorage) Create(inviteCode *models.InviteCode) error {
	tx := s.db
	if !columnExistsIsUsed(tx) {
		tx = tx.Omit("is_used")
	}
	return tx.Create(inviteCode).Error
}
// GetByID loads the invite code whose primary key is id, ignoring rows with a
// non-NULL deleted_at. Any error from GORM (including a not-found error from
// First) is returned unchanged with a nil result.
func (s *inviteCodeStorage) GetByID(id uint) (*models.InviteCode, error) {
	found := new(models.InviteCode)
	if err := s.db.Where("deleted_at IS NULL").First(found, id).Error; err != nil {
		return nil, err
	}
	return found, nil
}
// GetByCode loads the invite code whose code column equals code exactly,
// ignoring rows with a non-NULL deleted_at. Any error from GORM (including a
// not-found error from First) is returned unchanged with a nil result.
func (s *inviteCodeStorage) GetByCode(code string) (*models.InviteCode, error) {
	found := new(models.InviteCode)
	if err := s.db.Where("code = ? AND deleted_at IS NULL", code).First(found).Error; err != nil {
		return nil, err
	}
	return found, nil
}
// IsExpired reports whether inviteCode's expiry time lies in the past relative
// to time.Now(). A code with no expiry (nil ExpiresAt) never expires.
// inviteCode itself must not be nil.
func (s *inviteCodeStorage) IsExpired(inviteCode *models.InviteCode) bool {
	deadline := inviteCode.ExpiresAt
	return deadline != nil && time.Now().After(*deadline)
}
// List returns one page of non-deleted invite codes matching req, newest
// first, together with the total number of matching rows (before paging).
//
// Filters, each applied only when set on req:
//   - Code: substring match on the code column (LIKE %Code%).
//   - IsUsed: exact match on is_used; silently skipped when the table has no
//     is_used column.
//   - StartTime / EndTime: inclusive bounds on created_at, passed to the
//     database as given.
//
// NOTE(review): if EndTime is a bare date such as "2024-01-31", the database
// compares it as midnight, so rows later that day are excluded — confirm the
// format callers send.
// NOTE(review): Page and Size are used as-is; the offset is (Page-1)*Size, so
// callers are presumably expected to pass Page >= 1 and Size > 0.
func (s *inviteCodeStorage) List(req *models.InviteCodeListRequest) ([]models.InviteCode, int64, error) {
	tx := s.db.Model(&models.InviteCode{}).Where("deleted_at IS NULL")

	if req.Code != "" {
		tx = tx.Where("code LIKE ?", "%"+req.Code+"%")
	}
	if req.IsUsed != nil && columnExistsIsUsed(s.db) {
		tx = tx.Where("is_used = ?", *req.IsUsed)
	}
	if req.StartTime != "" {
		tx = tx.Where("created_at >= ?", req.StartTime)
	}
	if req.EndTime != "" {
		tx = tx.Where("created_at <= ?", req.EndTime)
	}

	var total int64
	if err := tx.Count(&total).Error; err != nil {
		return nil, 0, err
	}

	var rows []models.InviteCode
	if err := tx.Order("created_at DESC").
		Offset((req.Page - 1) * req.Size).
		Limit(req.Size).
		Find(&rows).Error; err != nil {
		return nil, 0, err
	}
	return rows, total, nil
}
// Update saves inviteCode using GORM's Save, which writes every field rather
// than only the changed ones.
//
// Like Create, it leaves the is_used column out of the statement when the
// admin_invite_codes table does not have it. Without this, Save would include
// is_used and fail with "Unknown column" on exactly the schemas Create supports.
// NOTE: per GORM's documented Save semantics, a zero primary key results in an
// INSERT instead of an UPDATE.
func (s *inviteCodeStorage) Update(inviteCode *models.InviteCode) error {
	tx := s.db
	if !columnExistsIsUsed(tx) {
		tx = tx.Omit("is_used")
	}
	return tx.Save(inviteCode).Error
}
// Delete removes the invite code with primary key id via GORM's Delete. No
// error is returned when no row matches.
// NOTE(review): the read paths filter on deleted_at IS NULL, which suggests the
// model carries a soft-delete field, in which case GORM sets deleted_at instead
// of removing the row — confirm against models.InviteCode.
func (s *inviteCodeStorage) Delete(id uint) error {
	result := s.db.Delete(new(models.InviteCode), id)
	return result.Error
}
// GetStatistics computes counters over non-deleted invite codes:
//   - Total: all non-deleted rows.
//   - Used / Unused: rows with is_used true / false. When the is_used column
//     does not exist, Used is 0 and Unused equals Total.
//   - TodayCreated: rows whose DATE(created_at) equals today's date as
//     formatted by the Go process (time.Now, local zone).
//
// The first failing query aborts the whole computation and its error is
// returned with a nil result.
// NOTE(review): if the database session time zone differs from the process's
// local zone, TodayCreated can be off around midnight — confirm DSN loc/time_zone.
func (s *inviteCodeStorage) GetStatistics() (*models.InviteCodeStatistics, error) {
	// count runs COUNT(*) on the invite-code table with the given WHERE clause.
	count := func(where string, args ...interface{}) (int64, error) {
		var n int64
		err := s.db.Model(&models.InviteCode{}).Where(where, args...).Count(&n).Error
		return n, err
	}

	stats := new(models.InviteCodeStatistics)

	total, err := count("deleted_at IS NULL")
	if err != nil {
		return nil, err
	}
	stats.Total = int(total)

	if !columnExistsIsUsed(s.db) {
		// Without the column nothing can be marked used, so every code counts as unused.
		stats.Used = 0
		stats.Unused = int(total)
	} else {
		var used, unused int64
		if used, err = count("deleted_at IS NULL AND is_used = ?", true); err != nil {
			return nil, err
		}
		stats.Used = int(used)
		if unused, err = count("deleted_at IS NULL AND is_used = ?", false); err != nil {
			return nil, err
		}
		stats.Unused = int(unused)
	}

	today := time.Now().Format("2006-01-02")
	created, err := count("deleted_at IS NULL AND DATE(created_at) = ?", today)
	if err != nil {
		return nil, err
	}
	stats.TodayCreated = int(created)

	return stats, nil
}
// columnExistsIsUsed reports whether the admin_invite_codes table in the
// connection's current schema (MySQL DATABASE()) has an is_used column.
//
// It is deliberately best-effort: if the INFORMATION_SCHEMA lookup fails, it
// reports false. Callers then take their "no is_used column" code paths
// instead of failing outright.
// NOTE: this issues one INFORMATION_SCHEMA query on every call, and
// DATABASE() ties it to MySQL-compatible servers.
func columnExistsIsUsed(db *gorm.DB) bool {
	var count int64
	err := db.Raw(
		"SELECT COUNT(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = ? AND COLUMN_NAME = 'is_used'",
		"admin_invite_codes",
	).Scan(&count).Error
	if err != nil {
		// Treat a failed lookup as "column absent" rather than propagating;
		// the column-less paths are safe on both schemas for reads.
		return false
	}
	return count > 0
}