sub2api/backend/internal/service/windsurf_tier_access_service_test.go
win de048fad25 chore(wip): save Windsurf/Antigravity/ops customizations before upstream merge
WIP commit保存以下定制工作以便后续合并 upstream v0.1.124-125:
- Windsurf: tier access service, NLU extractor, cold threshold, Google login
- Antigravity: client/oauth 调整
- Ops: log stream handler/broadcaster/middleware, OpsLogStreamView
- Frontend: WindsurfLoginModal Google, GoogleIcon, AccountsView, sidebar/router/i18n
2026-05-09 00:41:19 +08:00

267 lines
9.7 KiB
Go

package service
import (
"context"
"errors"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
// tierAccessRepoStub is a hand-written AccountRepository test double. Only
// ListByPlatform is functional; each remaining method panics so that an
// accidental call fails loudly instead of going unnoticed.
type tierAccessRepoStub struct {
	// accounts is the canned result handed back by ListByPlatform.
	accounts []Account
	// err is the canned error handed back by ListByPlatform.
	err error
}
// ListByPlatform hands back the stub's canned accounts and error verbatim;
// the context and platform arguments are ignored.
func (s *tierAccessRepoStub) ListByPlatform(_ context.Context, _ string) ([]Account, error) {
	accounts, err := s.accounts, s.err
	return accounts, err
}
// Create is not expected to be called; it panics with "unexpected".
func (*tierAccessRepoStub) Create(_ context.Context, _ *Account) error {
	panic("unexpected")
}

// GetByID is not expected to be called; it panics with "unexpected".
func (*tierAccessRepoStub) GetByID(_ context.Context, _ int64) (*Account, error) {
	panic("unexpected")
}

// GetByIDs is not expected to be called; it panics with "unexpected".
func (*tierAccessRepoStub) GetByIDs(_ context.Context, _ []int64) ([]*Account, error) {
	panic("unexpected")
}

// ExistsByID is not expected to be called; it panics with "unexpected".
func (*tierAccessRepoStub) ExistsByID(_ context.Context, _ int64) (bool, error) {
	panic("unexpected")
}

// GetByCRSAccountID is not expected to be called; it panics with "unexpected".
func (*tierAccessRepoStub) GetByCRSAccountID(_ context.Context, _ string) (*Account, error) {
	panic("unexpected")
}

// FindByExtraField is not expected to be called; it panics with "unexpected".
func (*tierAccessRepoStub) FindByExtraField(_ context.Context, _ string, _ any) ([]Account, error) {
	panic("unexpected")
}

// FindByCredentialField is not expected to be called; it panics with
// "unexpected".
func (*tierAccessRepoStub) FindByCredentialField(_ context.Context, _ string, _ string, _ string) ([]Account, error) {
	panic("unexpected")
}

// ListCRSAccountIDs is not expected to be called; it panics with "unexpected".
func (*tierAccessRepoStub) ListCRSAccountIDs(_ context.Context) (map[string]int64, error) {
	panic("unexpected")
}

// Update is not expected to be called; it panics with "unexpected".
func (*tierAccessRepoStub) Update(_ context.Context, _ *Account) error {
	panic("unexpected")
}

// Delete is not expected to be called; it panics with "unexpected".
func (*tierAccessRepoStub) Delete(_ context.Context, _ int64) error {
	panic("unexpected")
}
// List is not expected to be called; it panics with "unexpected".
func (*tierAccessRepoStub) List(_ context.Context, _ pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
	panic("unexpected")
}

// ListWithFilters is not expected to be called; it panics with "unexpected".
func (*tierAccessRepoStub) ListWithFilters(_ context.Context, _ pagination.PaginationParams, _ string, _ string, _ string, _ string, _ int64, _ string) ([]Account, *pagination.PaginationResult, error) {
	panic("unexpected")
}

// ListByGroup is not expected to be called; it panics with "unexpected".
func (*tierAccessRepoStub) ListByGroup(_ context.Context, _ int64) ([]Account, error) {
	panic("unexpected")
}

// ListActive is not expected to be called; it panics with "unexpected".
func (*tierAccessRepoStub) ListActive(_ context.Context) ([]Account, error) {
	panic("unexpected")
}
// UpdateLastUsed is not expected to be called; it panics with "unexpected".
func (*tierAccessRepoStub) UpdateLastUsed(_ context.Context, _ int64) error {
	panic("unexpected")
}

// BatchUpdateLastUsed is not expected to be called; it panics with
// "unexpected".
func (*tierAccessRepoStub) BatchUpdateLastUsed(_ context.Context, _ map[int64]time.Time) error {
	panic("unexpected")
}

// SetError is not expected to be called; it panics with "unexpected".
func (*tierAccessRepoStub) SetError(_ context.Context, _ int64, _ string) error {
	panic("unexpected")
}

// ClearError is not expected to be called; it panics with "unexpected".
func (*tierAccessRepoStub) ClearError(_ context.Context, _ int64) error {
	panic("unexpected")
}

// SetSchedulable is not expected to be called; it panics with "unexpected".
func (*tierAccessRepoStub) SetSchedulable(_ context.Context, _ int64, _ bool) error {
	panic("unexpected")
}

// AutoPauseExpiredAccounts is not expected to be called; it panics with
// "unexpected".
func (*tierAccessRepoStub) AutoPauseExpiredAccounts(_ context.Context, _ time.Time) (int64, error) {
	panic("unexpected")
}

// BindGroups is not expected to be called; it panics with "unexpected".
func (*tierAccessRepoStub) BindGroups(_ context.Context, _ int64, _ []int64) error {
	panic("unexpected")
}
// ListSchedulable is not expected to be called; it panics with "unexpected".
func (*tierAccessRepoStub) ListSchedulable(_ context.Context) ([]Account, error) {
	panic("unexpected")
}

// ListSchedulableByGroupID is not expected to be called; it panics with
// "unexpected".
func (*tierAccessRepoStub) ListSchedulableByGroupID(_ context.Context, _ int64) ([]Account, error) {
	panic("unexpected")
}

// ListSchedulableByPlatform is not expected to be called; it panics with
// "unexpected".
func (*tierAccessRepoStub) ListSchedulableByPlatform(_ context.Context, _ string) ([]Account, error) {
	panic("unexpected")
}

// ListSchedulableByGroupIDAndPlatform is not expected to be called; it panics
// with "unexpected".
func (*tierAccessRepoStub) ListSchedulableByGroupIDAndPlatform(_ context.Context, _ int64, _ string) ([]Account, error) {
	panic("unexpected")
}

// ListSchedulableByPlatforms is not expected to be called; it panics with
// "unexpected".
func (*tierAccessRepoStub) ListSchedulableByPlatforms(_ context.Context, _ []string) ([]Account, error) {
	panic("unexpected")
}

// ListSchedulableByGroupIDAndPlatforms is not expected to be called; it
// panics with "unexpected".
func (*tierAccessRepoStub) ListSchedulableByGroupIDAndPlatforms(_ context.Context, _ int64, _ []string) ([]Account, error) {
	panic("unexpected")
}

// ListSchedulableUngroupedByPlatform is not expected to be called; it panics
// with "unexpected".
func (*tierAccessRepoStub) ListSchedulableUngroupedByPlatform(_ context.Context, _ string) ([]Account, error) {
	panic("unexpected")
}

// ListSchedulableUngroupedByPlatforms is not expected to be called; it panics
// with "unexpected".
func (*tierAccessRepoStub) ListSchedulableUngroupedByPlatforms(_ context.Context, _ []string) ([]Account, error) {
	panic("unexpected")
}
// SetRateLimited is not expected to be called; it panics with "unexpected".
func (*tierAccessRepoStub) SetRateLimited(_ context.Context, _ int64, _ time.Time) error {
	panic("unexpected")
}

// SetModelRateLimit is not expected to be called; it panics with "unexpected".
func (*tierAccessRepoStub) SetModelRateLimit(_ context.Context, _ int64, _ string, _ time.Time) error {
	panic("unexpected")
}

// SetOverloaded is not expected to be called; it panics with "unexpected".
func (*tierAccessRepoStub) SetOverloaded(_ context.Context, _ int64, _ time.Time) error {
	panic("unexpected")
}

// SetTempUnschedulable is not expected to be called; it panics with
// "unexpected".
func (*tierAccessRepoStub) SetTempUnschedulable(_ context.Context, _ int64, _ time.Time, _ string) error {
	panic("unexpected")
}

// ClearTempUnschedulable is not expected to be called; it panics with
// "unexpected".
func (*tierAccessRepoStub) ClearTempUnschedulable(_ context.Context, _ int64) error {
	panic("unexpected")
}

// ClearRateLimit is not expected to be called; it panics with "unexpected".
func (*tierAccessRepoStub) ClearRateLimit(_ context.Context, _ int64) error {
	panic("unexpected")
}

// ClearAntigravityQuotaScopes is not expected to be called; it panics with
// "unexpected".
func (*tierAccessRepoStub) ClearAntigravityQuotaScopes(_ context.Context, _ int64) error {
	panic("unexpected")
}

// ClearModelRateLimits is not expected to be called; it panics with
// "unexpected".
func (*tierAccessRepoStub) ClearModelRateLimits(_ context.Context, _ int64) error {
	panic("unexpected")
}
// UpdateSessionWindow is not expected to be called; it panics with
// "unexpected".
func (*tierAccessRepoStub) UpdateSessionWindow(_ context.Context, _ int64, _ *time.Time, _ *time.Time, _ string) error {
	panic("unexpected")
}

// UpdateExtra is not expected to be called; it panics with "unexpected".
func (*tierAccessRepoStub) UpdateExtra(_ context.Context, _ int64, _ map[string]any) error {
	panic("unexpected")
}

// BulkUpdate is not expected to be called; it panics with "unexpected".
func (*tierAccessRepoStub) BulkUpdate(_ context.Context, _ []int64, _ AccountBulkUpdate) (int64, error) {
	panic("unexpected")
}

// IncrementQuotaUsed is not expected to be called; it panics with
// "unexpected".
func (*tierAccessRepoStub) IncrementQuotaUsed(_ context.Context, _ int64, _ float64) error {
	panic("unexpected")
}

// ResetQuotaUsed is not expected to be called; it panics with "unexpected".
func (*tierAccessRepoStub) ResetQuotaUsed(_ context.Context, _ int64) error {
	panic("unexpected")
}
// mkAccount builds a Windsurf Account fixture for the tier access tests.
//
// The API key is derived from tier ("key-" + tier), the account is flagged
// schedulable only when status equals StatusActive, and allowed/caps are
// written into the account's Windsurf extra payload.
func mkAccount(id int64, tier string, status string, allowed []WindsurfAllowedModel, caps map[string]WindsurfModelCapability) Account {
	var acct Account
	acct.ID = id
	acct.Platform = domain.PlatformWindsurf
	acct.Status = status
	acct.Schedulable = status == StatusActive
	acct.Credentials = StoreWindsurfCredentials(WindsurfCredentials{
		APIKey: "key-" + tier,
		Tier:   tier,
	})
	acct.Extra = StoreWindsurfExtra(WindsurfExtra{
		UserStatus:   WindsurfUserStatusSnapshot{AllowedModels: allowed},
		Capabilities: caps,
	})
	return acct
}
func TestWindsurfTierAccessService_Snapshot_HappyPath(t *testing.T) {
repo := &tierAccessRepoStub{
accounts: []Account{
mkAccount(1, "free", StatusActive,
[]WindsurfAllowedModel{{ModelKey: "gemini-2.5-flash"}, {ModelKey: "kimi-k2"}},
nil),
mkAccount(2, "pro", StatusActive,
[]WindsurfAllowedModel{{ModelKey: "gemini-2.5-flash"}, {ModelKey: "claude-sonnet-4.6"}},
nil),
mkAccount(3, "trial", StatusActive,
[]WindsurfAllowedModel{{ModelKey: "claude-sonnet-4.6"}},
nil),
},
}
svc := NewWindsurfTierAccessService(repo)
snap, err := svc.Snapshot(context.Background())
if err != nil {
t.Fatalf("snapshot: %v", err)
}
if snap.Accounts != 3 {
t.Fatalf("expected 3 accounts considered, got %d", snap.Accounts)
}
rowsByModel := make(map[string]WindsurfTierAccessRow)
for _, r := range snap.Rows {
rowsByModel[r.Model] = r
}
if got := rowsByModel["gemini-2.5-flash"]; got.Free != 1 || got.Pro != 1 || got.Trial != 0 || got.Total != 2 {
t.Fatalf("gemini-2.5-flash unexpected counts: %+v", got)
}
if got := rowsByModel["claude-sonnet-4.6"]; got.Free != 0 || got.Pro != 1 || got.Trial != 1 || got.Total != 2 {
t.Fatalf("claude-sonnet-4.6 unexpected counts: %+v", got)
}
if got := rowsByModel["kimi-k2"]; got.Free != 1 {
t.Fatalf("kimi-k2 unexpected counts: %+v", got)
}
}
// TestWindsurfTierAccessService_Snapshot_BlockedAccountsCounted sets up two
// free-tier accounts that should not add to a model's Total: an active one
// whose capability map marks gemini-2.5-flash unavailable ("not_entitled"),
// and a "paused" one that lists the model as allowed. The test expects the
// gemini-2.5-flash row to report Blocked == 2 and Total == 0.
func TestWindsurfTierAccessService_Snapshot_BlockedAccountsCounted(t *testing.T) {
	caps := map[string]WindsurfModelCapability{
		"gemini-2.5-flash": {Available: false, Reason: "not_entitled"},
	}
	repo := &tierAccessRepoStub{
		accounts: []Account{
			mkAccount(1, "free", StatusActive, nil, caps),
			mkAccount(2, "free", "paused", []WindsurfAllowedModel{{ModelKey: "gemini-2.5-flash"}}, nil),
		},
	}
	svc := NewWindsurfTierAccessService(repo)

	// Check the error instead of discarding it: a failed Snapshot would
	// otherwise show up as a nil-pointer panic inside findTierRow rather
	// than a readable test failure.
	snap, err := svc.Snapshot(context.Background())
	if err != nil {
		t.Fatalf("snapshot: %v", err)
	}
	if snap == nil {
		t.Fatal("snapshot: got nil snapshot without error")
	}

	row := findTierRow(snap, "gemini-2.5-flash")
	if row == nil {
		t.Fatal("expected gemini-2.5-flash row")
	}
	if row.Blocked != 2 || row.Total != 0 {
		t.Fatalf("expected blocked=2 total=0, got %+v", row)
	}
}
// TestWindsurfTierAccessService_Snapshot_SkipsUnregisteredAccounts checks
// that an active, schedulable Windsurf account whose credentials carry no API
// key is left out of the snapshot: zero accounts considered and no rows.
func TestWindsurfTierAccessService_Snapshot_SkipsUnregisteredAccounts(t *testing.T) {
	acct := Account{
		ID:          1,
		Platform:    domain.PlatformWindsurf,
		Status:      StatusActive,
		Schedulable: true,
		Credentials: StoreWindsurfCredentials(WindsurfCredentials{Email: "a@b.c"}), // no APIKey
	}
	svc := NewWindsurfTierAccessService(&tierAccessRepoStub{accounts: []Account{acct}})

	// Check the error instead of discarding it so a failed Snapshot is
	// reported as such rather than as a nil-pointer panic on snap.Accounts.
	snap, err := svc.Snapshot(context.Background())
	if err != nil {
		t.Fatalf("snapshot: %v", err)
	}
	if snap == nil {
		t.Fatal("snapshot: got nil snapshot without error")
	}

	if snap.Accounts != 0 {
		t.Fatalf("expected accounts considered=0, got %d", snap.Accounts)
	}
	if len(snap.Rows) != 0 {
		t.Fatalf("expected no rows, got %+v", snap.Rows)
	}
}
func TestWindsurfTierAccessService_Snapshot_PropagatesRepoError(t *testing.T) {
svc := NewWindsurfTierAccessService(&tierAccessRepoStub{err: errors.New("db down")})
if _, err := svc.Snapshot(context.Background()); err == nil {
t.Fatal("expected error")
}
}
// TestWindsurfTierAccessService_Snapshot_CachesWithinTTL calls Snapshot
// twice, emptying the repository in between, and expects the second call to
// return the very same pointer as the first, meaning no rebuild happened.
func TestWindsurfTierAccessService_Snapshot_CachesWithinTTL(t *testing.T) {
	repo := &tierAccessRepoStub{
		accounts: []Account{
			mkAccount(1, "free", StatusActive, []WindsurfAllowedModel{{ModelKey: "x"}}, nil),
		},
	}
	svc := NewWindsurfTierAccessService(repo)

	// The errors and the nil case must be checked explicitly: if both calls
	// failed and returned nil, the pointer comparison below would succeed
	// (nil == nil) and the test would pass without exercising the cache.
	first, err := svc.Snapshot(context.Background())
	if err != nil {
		t.Fatalf("first snapshot: %v", err)
	}
	if first == nil {
		t.Fatal("first snapshot: got nil snapshot without error")
	}

	// Pointer equality is the cache-hit signal: atomic.Pointer.Store fires
	// only on rebuild, so a returned pointer identical to the prior call
	// proves the build() path was skipped. We mutate the underlying repo
	// to a state that would yield a different snapshot if the rebuild
	// actually ran — the assertion below catches a regression in either
	// direction (TTL gate broken, or sync.Once style misuse). The 60s
	// default TTL is large enough that this test never sees an expiry
	// during normal CI runs.
	repo.accounts = nil

	second, err := svc.Snapshot(context.Background())
	if err != nil {
		t.Fatalf("second snapshot: %v", err)
	}
	if first != second {
		t.Fatal("expected cached pointer reuse")
	}
}
// findTierRow returns a pointer to the first row in snap.Rows whose Model
// equals model, or nil when there is no such row. The pointer aliases the
// snapshot's own slice element rather than a copy.
//
// A nil snap yields nil instead of a nil-pointer panic, so a caller that was
// handed no snapshot fails through its usual "row missing" assertion.
func findTierRow(snap *WindsurfTierAccessSnapshot, model string) *WindsurfTierAccessRow {
	if snap == nil {
		return nil
	}
	for i := range snap.Rows {
		if snap.Rows[i].Model == model {
			return &snap.Rows[i]
		}
	}
	return nil
}