From a4884b4e758bda7076074f17df2f2dd7376996bd Mon Sep 17 00:00:00 2001 From: benjamin Date: Mon, 18 May 2026 21:09:11 +0800 Subject: [PATCH] =?UTF-8?q?fix(subscription):=20=E5=B0=86=E6=97=A5?= =?UTF-8?q?=E5=8D=A1=E6=94=B9=E4=B8=BA=E4=B8=80=E6=AC=A1=E6=80=A7=E6=AF=8F?= =?UTF-8?q?=E6=97=A5=E9=85=8D=E9=A2=9D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent) Co-authored-by: Sisyphus --- .../subscription_assign_idempotency_test.go | 18 ++ .../subscription_calculate_progress_test.go | 24 +++ .../internal/service/subscription_service.go | 133 +++++++++---- backend/internal/service/user_subscription.go | 20 +- .../user_subscription_daily_quota_test.go | 178 ++++++++++++++++++ 5 files changed, 334 insertions(+), 39 deletions(-) create mode 100644 backend/internal/service/user_subscription_daily_quota_test.go diff --git a/backend/internal/service/subscription_assign_idempotency_test.go b/backend/internal/service/subscription_assign_idempotency_test.go index 40bab206..c8ace613 100644 --- a/backend/internal/service/subscription_assign_idempotency_test.go +++ b/backend/internal/service/subscription_assign_idempotency_test.go @@ -199,6 +199,24 @@ func (s *subscriptionUserSubRepoStub) GetByID(_ context.Context, id int64) (*Use return &cp, nil } +func (s *subscriptionUserSubRepoStub) Update(_ context.Context, sub *UserSubscription) error { + if sub == nil { + return ErrSubscriptionNilInput + } + existing := s.byID[sub.ID] + if existing == nil { + return ErrSubscriptionNotFound + } + oldKey := s.key(existing.UserID, existing.GroupID) + cp := *sub + s.byID[cp.ID] = &cp + if oldKey != s.key(cp.UserID, cp.GroupID) { + delete(s.byUserGroup, oldKey) + } + s.byUserGroup[s.key(cp.UserID, cp.GroupID)] = &cp + return nil +} + func TestAssignSubscriptionReuseWhenSemanticsMatch(t *testing.T) { start := time.Date(2026, 2, 20, 10, 0, 0, 0, time.UTC) groupRepo := &subscriptionGroupRepoStub{ diff --git a/backend/internal/service/subscription_calculate_progress_test.go b/backend/internal/service/subscription_calculate_progress_test.go index 53e5c568..650522d5 100644 --- a/backend/internal/service/subscription_calculate_progress_test.go +++ b/backend/internal/service/subscription_calculate_progress_test.go @@ -66,6 +66,30 @@ func TestCalculateProgress_DailyUsage(t *testing.T) { assert.Equal(t, dailyStart, progress.Daily.WindowStart) } +func TestCalculateProgress_DailyCardUsesExpiryAsDailyResetTime(t *testing.T) { + svc := newTestSubscriptionService() + startsAt := time.Now().Add(-12 * time.Hour) + dailyStart := time.Date(startsAt.Year(), startsAt.Month(), startsAt.Day(), 0, 0, 0, 0, startsAt.Location()) + expiresAt := startsAt.Add(24 * time.Hour) + + sub := &UserSubscription{ + ID: 1, + StartsAt: startsAt, + ExpiresAt: expiresAt, + DailyUsageUSD: 3.0, + DailyWindowStart: ptrTime(dailyStart), + } + group := &Group{ + Name: "Daily", + DailyLimitUSD: ptrFloat64(10.0), + } + + progress := svc.calculateProgress(sub, group) + + require.NotNil(t, progress.Daily, "日卡有日限额和窗口时 Daily 不应为 nil") + assert.Equal(t, expiresAt, progress.Daily.ResetsAt, "日卡的一次性日额度结束时间应为订阅过期时间") +} + func TestCalculateProgress_WeeklyUsage(t *testing.T) { svc := newTestSubscriptionService() now := time.Now() diff --git a/backend/internal/service/subscription_service.go b/backend/internal/service/subscription_service.go index f0a5540e..9905e6a1 100644 --- a/backend/internal/service/subscription_service.go +++ b/backend/internal/service/subscription_service.go @@ -196,7 +196,8 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in now := time.Now() var newExpiresAt time.Time - if existingSub.ExpiresAt.After(now) { + isExpired := !existingSub.ExpiresAt.After(now) + if !isExpired { // 未过期:从当前过期时间累加 newExpiresAt = existingSub.ExpiresAt.AddDate(0, 0, validityDays) } else { @@ -209,43 +210,8 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in newExpiresAt = MaxExpiresAt } - // 开启事务:ExtendExpiry + UpdateStatus + UpdateNotes 在同一事务中完成 - tx, err := s.entClient.Tx(ctx) - if err != nil { - return nil, false, fmt.Errorf("begin transaction: %w", err) - } - txCtx := dbent.NewTxContext(ctx, tx) - - // 更新过期时间 - if err := s.userSubRepo.ExtendExpiry(txCtx, existingSub.ID, newExpiresAt); err != nil { - _ = tx.Rollback() - return nil, false, fmt.Errorf("extend subscription: %w", err) - } - - // 如果订阅已过期或被暂停,恢复为active状态 - if existingSub.Status != SubscriptionStatusActive { - if err := s.userSubRepo.UpdateStatus(txCtx, existingSub.ID, SubscriptionStatusActive); err != nil { - _ = tx.Rollback() - return nil, false, fmt.Errorf("update subscription status: %w", err) - } - } - - // 追加备注 - if input.Notes != "" { - newNotes := existingSub.Notes - if newNotes != "" { - newNotes += "\n" - } - newNotes += input.Notes - if err := s.userSubRepo.UpdateNotes(txCtx, existingSub.ID, newNotes); err != nil { - _ = tx.Rollback() - return nil, false, fmt.Errorf("update subscription notes: %w", err) - } - } - - // 提交事务 - if err := tx.Commit(); err != nil { - return nil, false, fmt.Errorf("commit transaction: %w", err) + if err := s.updateExistingSubscriptionTerm(ctx, existingSub, input.Notes, now, newExpiresAt, isExpired); err != nil { + return nil, false, err } // 失效订阅缓存 @@ -284,6 +250,94 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in return sub, false, nil // false 表示是新建 } +func (s *SubscriptionService) updateExistingSubscriptionTerm( + ctx context.Context, + existingSub *UserSubscription, + notes string, + startsAt time.Time, + newExpiresAt time.Time, + isExpired bool, +) error { + return s.withSubscriptionUpdateTx(ctx, func(txCtx context.Context) error { + if isExpired { + renewed := renewedSubscriptionTerm(existingSub, notes, startsAt, newExpiresAt) + if err := s.userSubRepo.Update(txCtx, renewed); err != nil { + return fmt.Errorf("renew expired subscription: %w", err) + } + return nil + } + + // 更新过期时间 + if err := s.userSubRepo.ExtendExpiry(txCtx, existingSub.ID, newExpiresAt); err != nil { + return fmt.Errorf("extend subscription: %w", err) + } + + // 如果订阅被暂停,恢复为 active 状态 + if existingSub.Status != SubscriptionStatusActive { + if err := s.userSubRepo.UpdateStatus(txCtx, existingSub.ID, SubscriptionStatusActive); err != nil { + return fmt.Errorf("update subscription status: %w", err) + } + } + + // 追加备注 + if notes != "" { + if err := s.userSubRepo.UpdateNotes(txCtx, existingSub.ID, appendSubscriptionNotes(existingSub.Notes, notes)); err != nil { + return fmt.Errorf("update subscription notes: %w", err) + } + } + + return nil + }) +} + +func (s *SubscriptionService) withSubscriptionUpdateTx(ctx context.Context, fn func(context.Context) error) error { + if s.entClient == nil { + return fn(ctx) + } + + tx, err := s.entClient.Tx(ctx) + if err != nil { + return fmt.Errorf("begin transaction: %w", err) + } + txCtx := dbent.NewTxContext(ctx, tx) + + if err := fn(txCtx); err != nil { + _ = tx.Rollback() + return err + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("commit transaction: %w", err) + } + return nil +} + +func renewedSubscriptionTerm(existingSub *UserSubscription, notes string, startsAt, expiresAt time.Time) *UserSubscription { + renewed := *existingSub + windowStart := startOfDay(startsAt) + renewed.StartsAt = startsAt + renewed.ExpiresAt = expiresAt + renewed.Status = SubscriptionStatusActive + renewed.DailyWindowStart = &windowStart + renewed.WeeklyWindowStart = &windowStart + renewed.MonthlyWindowStart = &windowStart + renewed.DailyUsageUSD = 0 + renewed.WeeklyUsageUSD = 0 + renewed.MonthlyUsageUSD = 0 + renewed.Notes = appendSubscriptionNotes(existingSub.Notes, notes) + return &renewed +} + +func appendSubscriptionNotes(existingNotes, newNotes string) string { + if newNotes == "" { + return existingNotes + } + if existingNotes == "" { + return newNotes + } + return existingNotes + "\n" + newNotes +} + // createSubscription 创建新订阅(内部方法) func (s *SubscriptionService) createSubscription(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, error) { validityDays := input.ValidityDays @@ -945,6 +999,9 @@ func (s *SubscriptionService) calculateProgress(sub *UserSubscription, group *Gr if group.HasDailyLimit() && sub.DailyWindowStart != nil { limit := *group.DailyLimitUSD resetsAt := sub.DailyWindowStart.Add(24 * time.Hour) + if dailyResetTime := sub.DailyResetTime(); dailyResetTime != nil { + resetsAt = *dailyResetTime + } progress.Daily = &UsageWindowProgress{ LimitUSD: limit, UsedUSD: sub.DailyUsageUSD, diff --git a/backend/internal/service/user_subscription.go b/backend/internal/service/user_subscription.go index ec547d81..6303e6e3 100644 --- a/backend/internal/service/user_subscription.go +++ b/backend/internal/service/user_subscription.go @@ -50,11 +50,25 @@ func (s *UserSubscription) IsWindowActivated() bool { return s.DailyWindowStart != nil || s.WeeklyWindowStart != nil || s.MonthlyWindowStart != nil } +func (s *UserSubscription) HasOneTimeDailyQuota() bool { + if s == nil || s.StartsAt.IsZero() || s.ExpiresAt.IsZero() { + return false + } + return !s.ExpiresAt.After(s.StartsAt.AddDate(0, 0, 1)) +} + func (s *UserSubscription) NeedsDailyReset() bool { + return s.NeedsDailyResetAt(time.Now()) +} + +func (s *UserSubscription) NeedsDailyResetAt(now time.Time) bool { if s.DailyWindowStart == nil { return false } - return time.Since(*s.DailyWindowStart) >= 24*time.Hour + if s.HasOneTimeDailyQuota() { + return false + } + return !now.Before(s.DailyWindowStart.Add(24 * time.Hour)) } func (s *UserSubscription) NeedsWeeklyReset() bool { @@ -75,6 +89,10 @@ func (s *UserSubscription) DailyResetTime() *time.Time { if s.DailyWindowStart == nil { return nil } + if s.HasOneTimeDailyQuota() { + t := s.ExpiresAt + return &t + } t := s.DailyWindowStart.Add(24 * time.Hour) return &t } diff --git a/backend/internal/service/user_subscription_daily_quota_test.go b/backend/internal/service/user_subscription_daily_quota_test.go new file mode 100644 index 00000000..3738bdd6 --- /dev/null +++ b/backend/internal/service/user_subscription_daily_quota_test.go @@ -0,0 +1,178 @@ +package service + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +type dailyResetTrackingUserSubRepo struct { + userSubRepoNoop + + resetDailyCalled bool +} + +func (r *dailyResetTrackingUserSubRepo) ResetDailyUsage(context.Context, int64, time.Time) error { + r.resetDailyCalled = true + return nil +} + +func TestAssignOrExtendSubscription_ExpiredDailyCardStartsNewOneTimeQuota(t *testing.T) { + groupRepo := &subscriptionGroupRepoStub{ + group: &Group{ID: 1, SubscriptionType: SubscriptionTypeSubscription}, + } + subRepo := newSubscriptionUserSubRepoStub() + oldStart := time.Now().AddDate(0, 0, -3) + oldWindowStart := startOfDay(oldStart) + subRepo.seed(&UserSubscription{ + ID: 100, + UserID: 200, + GroupID: 1, + StartsAt: oldStart, + ExpiresAt: oldStart.AddDate(0, 0, 1), + Status: SubscriptionStatusExpired, + DailyWindowStart: &oldWindowStart, + WeeklyWindowStart: &oldWindowStart, + MonthlyWindowStart: &oldWindowStart, + DailyUsageUSD: 10, + WeeklyUsageUSD: 20, + MonthlyUsageUSD: 30, + Notes: "old", + }) + svc := NewSubscriptionService(groupRepo, subRepo, nil, nil, nil) + + renewed, reused, err := svc.AssignOrExtendSubscription(context.Background(), &AssignSubscriptionInput{ + UserID: 200, + GroupID: 1, + ValidityDays: 1, + Notes: "new", + }) + + require.NoError(t, err) + require.True(t, reused) + require.True(t, renewed.HasOneTimeDailyQuota(), "过期后重新购买 1 日卡仍应被识别为一次性日额度") + require.Equal(t, SubscriptionStatusActive, renewed.Status) + require.True(t, renewed.StartsAt.After(oldStart), "重新购买过期订阅时应重置当前周期 StartsAt") + require.False(t, renewed.ExpiresAt.After(renewed.StartsAt.AddDate(0, 0, 1))) + require.NotNil(t, renewed.DailyWindowStart) + require.Equal(t, startOfDay(renewed.StartsAt), *renewed.DailyWindowStart) + require.Equal(t, 0.0, renewed.DailyUsageUSD) + require.Equal(t, 0.0, renewed.WeeklyUsageUSD) + require.Equal(t, 0.0, renewed.MonthlyUsageUSD) + require.Equal(t, "old\nnew", renewed.Notes) +} + +func TestUserSubscriptionNeedsDailyReset_DailyCardKeepsOneTimeQuota(t *testing.T) { + start := time.Date(2026, 5, 18, 12, 0, 0, 0, time.UTC) + dailyWindowStart := time.Date(2026, 5, 18, 0, 0, 0, 0, time.UTC) + sub := &UserSubscription{ + StartsAt: start, + ExpiresAt: start.Add(24 * time.Hour), + DailyWindowStart: &dailyWindowStart, + DailyUsageUSD: 10, + } + + require.True(t, sub.HasOneTimeDailyQuota()) + require.False(t, sub.NeedsDailyResetAt(dailyWindowStart.Add(25*time.Hour)), "日卡应作为一次性配额,跨 0 点后不再刷新日额度") +} + +func TestUserSubscriptionNeedsDailyReset_MultiDaySubscriptionStillRefreshes(t *testing.T) { + start := time.Date(2026, 5, 18, 12, 0, 0, 0, time.UTC) + dailyWindowStart := time.Date(2026, 5, 18, 0, 0, 0, 0, time.UTC) + sub := &UserSubscription{ + StartsAt: start, + ExpiresAt: start.AddDate(0, 0, 2), + DailyWindowStart: &dailyWindowStart, + } + + require.False(t, sub.HasOneTimeDailyQuota()) + require.True(t, sub.NeedsDailyResetAt(dailyWindowStart.Add(24*time.Hour)), "多日订阅仍应按 24 小时日窗口刷新") +} + +func TestUserSubscriptionDailyResetTime_DailyCardReturnsExpiry(t *testing.T) { + start := time.Date(2026, 5, 18, 12, 0, 0, 0, time.UTC) + dailyWindowStart := time.Date(2026, 5, 18, 0, 0, 0, 0, time.UTC) + expiresAt := start.Add(24 * time.Hour) + sub := &UserSubscription{ + StartsAt: start, + ExpiresAt: expiresAt, + DailyWindowStart: &dailyWindowStart, + } + + resetAt := sub.DailyResetTime() + require.NotNil(t, resetAt) + require.Equal(t, expiresAt, *resetAt, "日卡展示的日额度结束时间应为订阅过期时间") +} + +func TestCheckAndResetWindows_DailyCardDoesNotResetDailyUsage(t *testing.T) { + now := time.Now() + startsAt := now.Add(-23 * time.Hour) + dailyWindowStart := now.Add(-25 * time.Hour) + repo := &dailyResetTrackingUserSubRepo{} + svc := NewSubscriptionService(groupRepoNoop{}, repo, nil, nil, nil) + sub := &UserSubscription{ + ID: 1, + UserID: 10, + GroupID: 20, + StartsAt: startsAt, + ExpiresAt: startsAt.Add(24 * time.Hour), + DailyUsageUSD: 10, + DailyWindowStart: &dailyWindowStart, + } + + err := svc.CheckAndResetWindows(context.Background(), sub) + + require.NoError(t, err) + require.False(t, repo.resetDailyCalled, "日卡作为一次性配额,过了 24 小时日窗口也不应重置 daily usage") + require.Equal(t, 10.0, sub.DailyUsageUSD) +} + +func TestCheckAndResetWindows_MultiDaySubscriptionStillResetsDailyUsage(t *testing.T) { + now := time.Now() + startsAt := now.Add(-48 * time.Hour) + dailyWindowStart := now.Add(-25 * time.Hour) + repo := &dailyResetTrackingUserSubRepo{} + svc := NewSubscriptionService(groupRepoNoop{}, repo, nil, nil, nil) + sub := &UserSubscription{ + ID: 1, + UserID: 10, + GroupID: 20, + StartsAt: startsAt, + ExpiresAt: startsAt.AddDate(0, 0, 2), + DailyUsageUSD: 10, + DailyWindowStart: &dailyWindowStart, + } + + err := svc.CheckAndResetWindows(context.Background(), sub) + + require.NoError(t, err) + require.True(t, repo.resetDailyCalled, "多日订阅仍应重置过期 daily window") + require.Equal(t, 0.0, sub.DailyUsageUSD) +} + +func TestValidateAndCheckLimits_DailyCardDoesNotAllowSecondQuotaAfterMidnight(t *testing.T) { + start := time.Now().Add(-23 * time.Hour) + dailyWindowStart := time.Now().Add(-25 * time.Hour) + dailyLimit := 10.0 + sub := &UserSubscription{ + Status: SubscriptionStatusActive, + StartsAt: start, + ExpiresAt: start.Add(24 * time.Hour), + DailyWindowStart: &dailyWindowStart, + DailyUsageUSD: dailyLimit + 0.01, + } + group := &Group{ + SubscriptionType: SubscriptionTypeSubscription, + DailyLimitUSD: &dailyLimit, + } + svc := NewSubscriptionService(groupRepoNoop{}, userSubRepoNoop{}, nil, nil, nil) + + needsMaintenance, err := svc.ValidateAndCheckLimits(sub, group) + + require.False(t, needsMaintenance, "日卡跨过日窗口后不应触发 daily reset 维护") + require.True(t, errors.Is(err, ErrDailyLimitExceeded)) + require.Equal(t, dailyLimit+0.01, sub.DailyUsageUSD, "热路径不应清零日卡已用额度") +}