package service

import (
	"context"
	"errors"
	"strings"
	"testing"
	"time"

	"github.com/Wei-Shaw/sub2api/internal/domain"
	infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
	"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
	"github.com/Wei-Shaw/sub2api/internal/pkg/windsurf"
)

// tokenLoginRepoStub is a bare-bones account-repository double for the
// TestWindsurfAuthService_TokenLogin_* tests.
//
// Only FindByCredentialField does anything: it hands back the canned
// existing/findErr pair. Every other method panics with "unexpected", so a
// test that reaches past TokenLogin's validation and dedup checks fails
// loudly instead of quietly passing.
type tokenLoginRepoStub struct {
	existing []Account // returned as-is by FindByCredentialField
	findErr  error     // returned as-is by FindByCredentialField
}

// FindByCredentialField ignores its lookup arguments and returns the canned
// result configured on the stub.
func (r *tokenLoginRepoStub) FindByCredentialField(context.Context, string, string, string) ([]Account, error) {
	return r.existing, r.findErr
}

// ---- Lookups: none of these are expected to be reached. ----

func (*tokenLoginRepoStub) GetByID(context.Context, int64) (*Account, error) {
	panic("unexpected")
}

func (*tokenLoginRepoStub) GetByIDs(context.Context, []int64) ([]*Account, error) {
	panic("unexpected")
}

func (*tokenLoginRepoStub) ExistsByID(context.Context, int64) (bool, error) {
	panic("unexpected")
}

func (*tokenLoginRepoStub) GetByCRSAccountID(context.Context, string) (*Account, error) {
	panic("unexpected")
}

func (*tokenLoginRepoStub) FindByExtraField(context.Context, string, any) ([]Account, error) {
	panic("unexpected")
}

func (*tokenLoginRepoStub) ListCRSAccountIDs(context.Context) (map[string]int64, error) {
	panic("unexpected")
}

func (*tokenLoginRepoStub) List(context.Context, pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
	panic("unexpected")
}

func (*tokenLoginRepoStub) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, string, int64, string) ([]Account, *pagination.PaginationResult, error) {
	panic("unexpected")
}

func (*tokenLoginRepoStub) ListByGroup(context.Context, int64) ([]Account, error) {
	panic("unexpected")
}

func (*tokenLoginRepoStub) ListActive(context.Context) ([]Account, error) {
	panic("unexpected")
}

func (*tokenLoginRepoStub) ListByPlatform(context.Context, string) ([]Account, error) {
	panic("unexpected")
}

// ---- Writes: none of these are expected to be reached. ----

func (*tokenLoginRepoStub) Create(context.Context, *Account) error {
	panic("unexpected")
}

func (*tokenLoginRepoStub) Update(context.Context, *Account) error {
	panic("unexpected")
}

func (*tokenLoginRepoStub) Delete(context.Context, int64) error {
	panic("unexpected")
}

func (*tokenLoginRepoStub) UpdateLastUsed(context.Context, int64) error {
	panic("unexpected")
}

func (*tokenLoginRepoStub) BatchUpdateLastUsed(context.Context, map[int64]time.Time) error {
	panic("unexpected")
}

func (*tokenLoginRepoStub) SetError(context.Context, int64, string) error {
	panic("unexpected")
}

func (*tokenLoginRepoStub) ClearError(context.Context, int64) error {
	panic("unexpected")
}

func (*tokenLoginRepoStub) SetSchedulable(context.Context, int64, bool) error {
	panic("unexpected")
}

func (*tokenLoginRepoStub) AutoPauseExpiredAccounts(context.Context, time.Time) (int64, error) {
	panic("unexpected")
}

func (*tokenLoginRepoStub) BindGroups(context.Context, int64, []int64) error {
	panic("unexpected")
}

func (*tokenLoginRepoStub) UpdateSessionWindow(context.Context, int64, *time.Time, *time.Time, string) error {
	panic("unexpected")
}

func (*tokenLoginRepoStub) UpdateExtra(context.Context, int64, map[string]any) error {
	panic("unexpected")
}

func (*tokenLoginRepoStub) BulkUpdate(context.Context, []int64, AccountBulkUpdate) (int64, error) {
	panic("unexpected")
}

func (*tokenLoginRepoStub) IncrementQuotaUsed(context.Context, int64, float64) error {
	panic("unexpected")
}

func (*tokenLoginRepoStub) ResetQuotaUsed(context.Context, int64) error {
	panic("unexpected")
}

// ---- Scheduling queries: none of these are expected to be reached. ----

func (*tokenLoginRepoStub) ListSchedulable(context.Context) ([]Account, error) {
	panic("unexpected")
}

func (*tokenLoginRepoStub) ListSchedulableByGroupID(context.Context, int64) ([]Account, error) {
	panic("unexpected")
}

func (*tokenLoginRepoStub) ListSchedulableByPlatform(context.Context, string) ([]Account, error) {
	panic("unexpected")
}

func (*tokenLoginRepoStub) ListSchedulableByGroupIDAndPlatform(context.Context, int64, string) ([]Account, error) {
	panic("unexpected")
}

func (*tokenLoginRepoStub) ListSchedulableByPlatforms(context.Context, []string) ([]Account, error) {
	panic("unexpected")
}

func (*tokenLoginRepoStub) ListSchedulableByGroupIDAndPlatforms(context.Context, int64, []string) ([]Account, error) {
	panic("unexpected")
}

func (*tokenLoginRepoStub) ListSchedulableUngroupedByPlatform(context.Context, string) ([]Account, error) {
	panic("unexpected")
}

func (*tokenLoginRepoStub) ListSchedulableUngroupedByPlatforms(context.Context, []string) ([]Account, error) {
	panic("unexpected")
}

// ---- Rate-limit / overload state: none of these are expected to be reached. ----

func (*tokenLoginRepoStub) SetRateLimited(context.Context, int64, time.Time) error {
	panic("unexpected")
}

func (*tokenLoginRepoStub) SetModelRateLimit(context.Context, int64, string, time.Time) error {
	panic("unexpected")
}

func (*tokenLoginRepoStub) SetOverloaded(context.Context, int64, time.Time) error {
	panic("unexpected")
}

func (*tokenLoginRepoStub) SetTempUnschedulable(context.Context, int64, time.Time, string) error {
	panic("unexpected")
}

func (*tokenLoginRepoStub) ClearTempUnschedulable(context.Context, int64) error {
	panic("unexpected")
}

func (*tokenLoginRepoStub) ClearRateLimit(context.Context, int64) error {
	panic("unexpected")
}

func (*tokenLoginRepoStub) ClearAntigravityQuotaScopes(context.Context, int64) error {
	panic("unexpected")
}

func (*tokenLoginRepoStub) ClearModelRateLimits(context.Context, int64) error {
	panic("unexpected")
}

// TestWindsurfAuthService_TokenLogin_Validation drives TokenLogin through its
// input-validation and duplicate-account checks. The service is built with the
// stub repository above and a zero-value auth client, so each case is expected
// to fail before any network-backed dependency would be used.
func TestWindsurfAuthService_TokenLogin_Validation(t *testing.T) { tests := []struct { name string input *WindsurfTokenLoginInput repo *tokenLoginRepoStub wantErr string }{ { name: "empty token rejected", input: &WindsurfTokenLoginInput{Email: "user@example.com"}, repo: &tokenLoginRepoStub{}, wantErr: "token required", }, { name: "duplicate email rejected with conflict error", input: &WindsurfTokenLoginInput{Token: "fake-token", Email: "dup@example.com"}, repo: &tokenLoginRepoStub{existing: []Account{{ID: 42}}}, wantErr: "already exists", }, { name: "find error propagated", input: &WindsurfTokenLoginInput{Token: "fake-token", Email: "boom@example.com"}, repo: &tokenLoginRepoStub{findErr: errors.New("db down")}, wantErr: "check existing account", }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { svc := &WindsurfAuthService{ accountRepo: tc.repo, authClient: &windsurf.AuthClient{}, } _, err := svc.TokenLogin(context.Background(), tc.input) if err == nil { t.Fatalf("expected error containing %q, got nil", tc.wantErr) } if !strings.Contains(err.Error(), tc.wantErr) { t.Fatalf("expected error containing %q, got %q", tc.wantErr, err.Error()) } }) } } // TestWindsurfAuthService_TokenLogin_PlatformConst guards against accidental // drift in the platform/type constants used to persist the account. func TestWindsurfAuthService_TokenLogin_PlatformConst(t *testing.T) { if domain.PlatformWindsurf == "" { t.Fatal("PlatformWindsurf constant is empty") } if domain.AccountTypeWindsurfSession == "" { t.Fatal("AccountTypeWindsurfSession constant is empty") } } // TestWindsurfAuthService_TokenLogin_TypedErrors verifies that validation // failures surface as ApplicationError with the right HTTP code, so the // handler maps them to 4xx instead of 500. 
func TestWindsurfAuthService_TokenLogin_TypedErrors(t *testing.T) { cases := []struct { name string input *WindsurfTokenLoginInput repo *tokenLoginRepoStub wantCode int }{ { name: "missing token returns 400", input: &WindsurfTokenLoginInput{Email: "x@y.z"}, repo: &tokenLoginRepoStub{}, wantCode: 400, }, { name: "duplicate email returns 409", input: &WindsurfTokenLoginInput{Token: "tok", Email: "dup@example.com"}, repo: &tokenLoginRepoStub{existing: []Account{{ID: 1}}}, wantCode: 409, }, } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { svc := &WindsurfAuthService{ accountRepo: tc.repo, authClient: &windsurf.AuthClient{}, } _, err := svc.TokenLogin(context.Background(), tc.input) if err == nil { t.Fatalf("expected error, got nil") } if got := infraerrors.Code(err); got != tc.wantCode { t.Fatalf("expected HTTP code %d, got %d (err=%v)", tc.wantCode, got, err) } }) } }