package service

import (
	"context"
	"errors"
	"testing"
	"time"

	"github.com/Wei-Shaw/sub2api/internal/domain"
	"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)

// tierAccessRepoStub satisfies AccountRepository with a hand-rolled
// ListByPlatform; every other method panics so accidental calls are loud.
type tierAccessRepoStub struct {
	accounts []Account
	err      error
}

// ListByPlatform ignores its arguments and returns the canned accounts/err
// pair configured on the stub.
func (s *tierAccessRepoStub) ListByPlatform(_ context.Context, _ string) ([]Account, error) {
	return s.accounts, s.err
}

func (*tierAccessRepoStub) Create(context.Context, *Account) error { panic("unexpected") }
func (*tierAccessRepoStub) GetByID(context.Context, int64) (*Account, error) { panic("unexpected") }
func (*tierAccessRepoStub) GetByIDs(context.Context, []int64) ([]*Account, error) {
	panic("unexpected")
}
func (*tierAccessRepoStub) ExistsByID(context.Context, int64) (bool, error) { panic("unexpected") }
func (*tierAccessRepoStub) GetByCRSAccountID(context.Context, string) (*Account, error) {
	panic("unexpected")
}
func (*tierAccessRepoStub) FindByExtraField(context.Context, string, any) ([]Account, error) {
	panic("unexpected")
}
func (*tierAccessRepoStub) FindByCredentialField(context.Context, string, string, string) ([]Account, error) {
	panic("unexpected")
}
func (*tierAccessRepoStub) ListCRSAccountIDs(context.Context) (map[string]int64, error) {
	panic("unexpected")
}
func (*tierAccessRepoStub) Update(context.Context, *Account) error { panic("unexpected") }
func (*tierAccessRepoStub) Delete(context.Context, int64) error    { panic("unexpected") }
func (*tierAccessRepoStub) List(context.Context, pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
	panic("unexpected")
}
func (*tierAccessRepoStub) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, string, int64, string) ([]Account, *pagination.PaginationResult, error) {
	panic("unexpected")
}
func (*tierAccessRepoStub) ListByGroup(context.Context, int64) ([]Account, error) {
	panic("unexpected")
}
func (*tierAccessRepoStub) ListActive(context.Context) ([]Account, error) { panic("unexpected") }
func (*tierAccessRepoStub) UpdateLastUsed(context.Context, int64) error  { panic("unexpected") }
func (*tierAccessRepoStub) BatchUpdateLastUsed(context.Context, map[int64]time.Time) error {
	panic("unexpected")
}
func (*tierAccessRepoStub) SetError(context.Context, int64, string) error { panic("unexpected") }
func (*tierAccessRepoStub) ClearError(context.Context, int64) error       { panic("unexpected") }
func (*tierAccessRepoStub) SetSchedulable(context.Context, int64, bool) error {
	panic("unexpected")
}
func (*tierAccessRepoStub) AutoPauseExpiredAccounts(context.Context, time.Time) (int64, error) {
	panic("unexpected")
}
func (*tierAccessRepoStub) BindGroups(context.Context, int64, []int64) error { panic("unexpected") }
func (*tierAccessRepoStub) ListSchedulable(context.Context) ([]Account, error) {
	panic("unexpected")
}
func (*tierAccessRepoStub) ListSchedulableByGroupID(context.Context, int64) ([]Account, error) {
	panic("unexpected")
}
func (*tierAccessRepoStub) ListSchedulableByPlatform(context.Context, string) ([]Account, error) {
	panic("unexpected")
}
func (*tierAccessRepoStub) ListSchedulableByGroupIDAndPlatform(context.Context, int64, string) ([]Account, error) {
	panic("unexpected")
}
func (*tierAccessRepoStub) ListSchedulableByPlatforms(context.Context, []string) ([]Account, error) {
	panic("unexpected")
}
func (*tierAccessRepoStub) ListSchedulableByGroupIDAndPlatforms(context.Context, int64, []string) ([]Account, error) {
	panic("unexpected")
}
func (*tierAccessRepoStub) ListSchedulableUngroupedByPlatform(context.Context, string) ([]Account, error) {
	panic("unexpected")
}
func (*tierAccessRepoStub) ListSchedulableUngroupedByPlatforms(context.Context, []string) ([]Account, error) {
	panic("unexpected")
}
func (*tierAccessRepoStub) SetRateLimited(context.Context, int64, time.Time) error {
	panic("unexpected")
}
func (*tierAccessRepoStub) SetModelRateLimit(context.Context, int64, string, time.Time, ...string) error {
	panic("unexpected")
}
func (*tierAccessRepoStub) SetOverloaded(context.Context, int64, time.Time) error {
	panic("unexpected")
}
func (*tierAccessRepoStub) SetTempUnschedulable(context.Context, int64, time.Time, string) error {
	panic("unexpected")
}
func (*tierAccessRepoStub) ClearTempUnschedulable(context.Context, int64) error {
	panic("unexpected")
}
func (*tierAccessRepoStub) ClearRateLimit(context.Context, int64) error { panic("unexpected") }
func (*tierAccessRepoStub) ClearAntigravityQuotaScopes(context.Context, int64) error {
	panic("unexpected")
}
func (*tierAccessRepoStub) ClearModelRateLimits(context.Context, int64) error {
	panic("unexpected")
}
func (*tierAccessRepoStub) UpdateSessionWindow(context.Context, int64, *time.Time, *time.Time, string) error {
	panic("unexpected")
}
func (*tierAccessRepoStub) UpdateExtra(context.Context, int64, map[string]any) error {
	panic("unexpected")
}
func (*tierAccessRepoStub) BulkUpdate(context.Context, []int64, AccountBulkUpdate) (int64, error) {
	panic("unexpected")
}
func (*tierAccessRepoStub) IncrementQuotaUsed(context.Context, int64, float64) error {
	panic("unexpected")
}
func (*tierAccessRepoStub) ResetQuotaUsed(context.Context, int64) error { panic("unexpected") }

// mkAccount builds a Windsurf account fixture with a non-empty API key
// ("key-" + tier) and the given tier and status. Schedulable is true only when
// status is StatusActive. allowed seeds the user-status AllowedModels list and
// caps the per-model capability map; either may be nil.
func mkAccount(id int64, tier, status string, allowed []WindsurfAllowedModel, caps map[string]WindsurfModelCapability) Account {
	creds := WindsurfCredentials{
		APIKey: "key-" + tier,
		Tier:   tier,
	}
	extra := WindsurfExtra{
		UserStatus:   WindsurfUserStatusSnapshot{AllowedModels: allowed},
		Capabilities: caps,
	}
	return Account{
		ID:          id,
		Platform:    domain.PlatformWindsurf,
		Status:      status,
		Schedulable: status == StatusActive,
		Credentials: StoreWindsurfCredentials(creds),
		Extra:       StoreWindsurfExtra(extra),
	}
}

// TestWindsurfTierAccessService_Snapshot_HappyPath checks per-tier counts for
// three active accounts (free, pro, trial) with overlapping allowed models.
func TestWindsurfTierAccessService_Snapshot_HappyPath(t *testing.T) {
	repo := &tierAccessRepoStub{
		accounts: []Account{
			mkAccount(1, "free", StatusActive, []WindsurfAllowedModel{{ModelKey: "gemini-2.5-flash"}, {ModelKey: "kimi-k2"}}, nil),
			mkAccount(2, "pro", StatusActive, []WindsurfAllowedModel{{ModelKey: "gemini-2.5-flash"}, {ModelKey: "claude-sonnet-4.6"}}, nil),
			mkAccount(3, "trial", StatusActive, []WindsurfAllowedModel{{ModelKey: "claude-sonnet-4.6"}}, nil),
		},
	}
	svc := NewWindsurfTierAccessService(repo)
	snap, err := svc.Snapshot(context.Background())
	if err != nil {
		t.Fatalf("snapshot: %v", err)
	}
	if snap.Accounts != 3 {
		t.Fatalf("expected 3 accounts considered, got %d", snap.Accounts)
	}
	rowsByModel := make(map[string]WindsurfTierAccessRow, len(snap.Rows))
	for _, r := range snap.Rows {
		rowsByModel[r.Model] = r
	}
	if got := rowsByModel["gemini-2.5-flash"]; got.Free != 1 || got.Pro != 1 || got.Trial != 0 || got.Total != 2 {
		t.Fatalf("gemini-2.5-flash unexpected counts: %+v", got)
	}
	if got := rowsByModel["claude-sonnet-4.6"]; got.Free != 0 || got.Pro != 1 || got.Trial != 1 || got.Total != 2 {
		t.Fatalf("claude-sonnet-4.6 unexpected counts: %+v", got)
	}
	// Only the free account lists kimi-k2, so every other tier must be zero.
	if got := rowsByModel["kimi-k2"]; got.Free != 1 || got.Pro != 0 || got.Trial != 0 || got.Total != 1 {
		t.Fatalf("kimi-k2 unexpected counts: %+v", got)
	}
}

// TestWindsurfTierAccessService_Snapshot_BlockedAccountsCounted checks that an
// account whose capabilities mark the model unavailable, and a paused account,
// are both reported as Blocked rather than contributing to Total.
func TestWindsurfTierAccessService_Snapshot_BlockedAccountsCounted(t *testing.T) {
	caps := map[string]WindsurfModelCapability{
		"gemini-2.5-flash": {Available: false, Reason: "not_entitled"},
	}
	repo := &tierAccessRepoStub{
		accounts: []Account{
			mkAccount(1, "free", StatusActive, nil, caps),
			mkAccount(2, "free", "paused", []WindsurfAllowedModel{{ModelKey: "gemini-2.5-flash"}}, nil),
		},
	}
	svc := NewWindsurfTierAccessService(repo)
	snap, err := svc.Snapshot(context.Background())
	if err != nil {
		t.Fatalf("snapshot: %v", err)
	}
	row := findTierRow(snap, "gemini-2.5-flash")
	if row == nil {
		t.Fatal("expected gemini-2.5-flash row")
	}
	if row.Blocked != 2 || row.Total != 0 {
		t.Fatalf("expected blocked=2 total=0, got %+v", row)
	}
}

// TestWindsurfTierAccessService_Snapshot_SkipsUnregisteredAccounts checks that
// an account without an APIKey is neither counted nor produces rows.
func TestWindsurfTierAccessService_Snapshot_SkipsUnregisteredAccounts(t *testing.T) {
	acct := Account{
		ID:          1,
		Platform:    domain.PlatformWindsurf,
		Status:      StatusActive,
		Schedulable: true,
		Credentials: StoreWindsurfCredentials(WindsurfCredentials{Email: "a@b.c"}), // no APIKey
	}
	svc := NewWindsurfTierAccessService(&tierAccessRepoStub{accounts: []Account{acct}})
	snap, err := svc.Snapshot(context.Background())
	if err != nil {
		t.Fatalf("snapshot: %v", err)
	}
	if snap.Accounts != 0 {
		t.Fatalf("expected accounts considered=0, got %d", snap.Accounts)
	}
	if len(snap.Rows) != 0 {
		t.Fatalf("expected no rows, got %+v", snap.Rows)
	}
}

// TestWindsurfTierAccessService_Snapshot_PropagatesRepoError checks that a
// repository failure surfaces as a non-nil error from Snapshot.
func TestWindsurfTierAccessService_Snapshot_PropagatesRepoError(t *testing.T) {
	svc := NewWindsurfTierAccessService(&tierAccessRepoStub{err: errors.New("db down")})
	if _, err := svc.Snapshot(context.Background()); err == nil {
		t.Fatal("expected error")
	}
}

// TestWindsurfTierAccessService_Snapshot_CachesWithinTTL checks that a second
// Snapshot call within the TTL returns the same pointer as the first.
func TestWindsurfTierAccessService_Snapshot_CachesWithinTTL(t *testing.T) {
	repo := &tierAccessRepoStub{
		accounts: []Account{
			mkAccount(1, "free", StatusActive, []WindsurfAllowedModel{{ModelKey: "x"}}, nil),
		},
	}
	svc := NewWindsurfTierAccessService(repo)
	first, err := svc.Snapshot(context.Background())
	if err != nil {
		t.Fatalf("first snapshot: %v", err)
	}
	// Without this guard, two (nil, err) results would compare equal below
	// and the cache assertion would pass vacuously.
	if first == nil {
		t.Fatal("first snapshot returned nil")
	}
	// Pointer equality is the cache-hit signal: atomic.Pointer.Store fires
	// only on rebuild, so a returned pointer identical to the prior call
	// proves the build() path was skipped. We mutate the underlying repo
	// to a state that would yield a different snapshot if the rebuild
	// actually ran — the assertion below catches a regression in either
	// direction (TTL gate broken, or sync.Once style misuse). The 60s
	// default TTL is large enough that this test never sees an expiry
	// during normal CI runs.
	repo.accounts = nil
	second, err := svc.Snapshot(context.Background())
	if err != nil {
		t.Fatalf("second snapshot: %v", err)
	}
	if first != second {
		t.Fatal("expected cached pointer reuse")
	}
}

// findTierRow returns a pointer to the row for model inside snap, or nil when
// snap is nil or has no such row. The pointer aliases snap.Rows.
func findTierRow(snap *WindsurfTierAccessSnapshot, model string) *WindsurfTierAccessRow {
	if snap == nil {
		return nil
	}
	for i := range snap.Rows {
		if snap.Rows[i].Model == model {
			return &snap.Rows[i]
		}
	}
	return nil
}