package storage import ( "time" "goalfymax-admin/internal/models" "gorm.io/gorm" ) type InviteCodeStorage interface { Create(inviteCode *models.InviteCode) error GetByID(id uint) (*models.InviteCode, error) GetByCode(code string) (*models.InviteCode, error) List(req *models.InviteCodeListRequest) ([]models.InviteCode, int64, error) Update(inviteCode *models.InviteCode) error Delete(id uint) error GetStatistics() (*models.InviteCodeStatistics, error) IsExpired(inviteCode *models.InviteCode) bool } type inviteCodeStorage struct { db *gorm.DB } func NewInviteCodeStorage() InviteCodeStorage { return &inviteCodeStorage{db: DB} } func (s *inviteCodeStorage) Create(inviteCode *models.InviteCode) error { // 若目标库缺少 is_used 列,则在插入时省略该列,避免 Unknown column 错误 if columnExistsIsUsed(s.db) { return s.db.Create(inviteCode).Error } return s.db.Omit("is_used").Create(inviteCode).Error } func (s *inviteCodeStorage) GetByID(id uint) (*models.InviteCode, error) { var inviteCode models.InviteCode err := s.db.Where("deleted_at IS NULL").First(&inviteCode, id).Error if err != nil { return nil, err } return &inviteCode, nil } func (s *inviteCodeStorage) GetByCode(code string) (*models.InviteCode, error) { var inviteCode models.InviteCode err := s.db.Where("code = ? AND deleted_at IS NULL", code).First(&inviteCode).Error if err != nil { return nil, err } return &inviteCode, nil } // IsExpired 检查邀请码是否过期 func (s *inviteCodeStorage) IsExpired(inviteCode *models.InviteCode) bool { if inviteCode.ExpiresAt == nil { return false // 没有设置过期时间,永不过期 } return time.Now().After(*inviteCode.ExpiresAt) } func (s *inviteCodeStorage) List(req *models.InviteCodeListRequest) ([]models.InviteCode, int64, error) { var inviteCodes []models.InviteCode var total int64 query := s.db.Model(&models.InviteCode{}).Where("deleted_at IS NULL") // 筛选条件 if req.Code != "" { query = query.Where("code LIKE ?", "%"+req.Code+"%") } // 仅当存在 is_used 列时才应用过滤 if req.IsUsed != nil { if columnExistsIsUsed(s.db) { query = query.Where("is_used = ?", *req.IsUsed) } } if req.StartTime != "" { query = query.Where("created_at >= ?", req.StartTime) } if req.EndTime != "" { query = query.Where("created_at <= ?", req.EndTime) } // 获取总数 if err := query.Count(&total).Error; err != nil { return nil, 0, err } // 分页查询 offset := (req.Page - 1) * req.Size err := query.Order("created_at DESC").Offset(offset).Limit(req.Size).Find(&inviteCodes).Error if err != nil { return nil, 0, err } return inviteCodes, total, nil } func (s *inviteCodeStorage) Update(inviteCode *models.InviteCode) error { return s.db.Save(inviteCode).Error } func (s *inviteCodeStorage) Delete(id uint) error { return s.db.Delete(&models.InviteCode{}, id).Error } func (s *inviteCodeStorage) GetStatistics() (*models.InviteCodeStatistics, error) { var stats models.InviteCodeStatistics // 总数 var total int64 if err := s.db.Model(&models.InviteCode{}).Where("deleted_at IS NULL").Count(&total).Error; err != nil { return nil, err } stats.Total = int(total) // is_used 列可能不存在,存在时统计已使用/未使用 if columnExistsIsUsed(s.db) { var used int64 if err := s.db.Model(&models.InviteCode{}).Where("deleted_at IS NULL AND is_used = ?", true).Count(&used).Error; err != nil { return nil, err } stats.Used = int(used) var unused int64 if err := s.db.Model(&models.InviteCode{}).Where("deleted_at IS NULL AND is_used = ?", false).Count(&unused).Error; err != nil { return nil, err } stats.Unused = int(unused) } else { // 列不存在时,给出合理默认值 stats.Used = 0 stats.Unused = int(total) } // 今日新增 today := time.Now().Format("2006-01-02") var todayCreated int64 if err := s.db.Model(&models.InviteCode{}).Where("deleted_at IS NULL AND DATE(created_at) = ?", today).Count(&todayCreated).Error; err != nil { return nil, err } stats.TodayCreated = int(todayCreated) return &stats, nil } // columnExistsIsUsed 检查当前数据库中 admin_invite_codes 表是否存在 is_used 列 func columnExistsIsUsed(db *gorm.DB) bool { var count int64 // 使用当前连接的数据库名 db.Raw("SELECT COUNT(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = ? AND COLUMN_NAME = 'is_used'", "admin_invite_codes").Scan(&count) return count > 0 }