package service

import (
	"context"
	"errors"
	"sync"
	"testing"
	"time"

	"github.com/Wei-Shaw/sub2api/internal/config"
	"github.com/Wei-Shaw/sub2api/internal/pkg/lspool"
	"github.com/stretchr/testify/require"
)

// fakeLSBootstrapAccountReader is a test double for the account source passed
// to NewLSPoolBootstrapService. ListByPlatform returns a copy of accounts (and
// err) regardless of the requested platform, so any platform filtering must
// happen in the service; each requested platform is recorded in platforms.
type fakeLSBootstrapAccountReader struct {
	mu        sync.Mutex
	accounts  []Account
	err       error
	platforms []string
}

// ListByPlatform records platform and returns a defensive copy of the
// configured accounts together with the configured error.
func (f *fakeLSBootstrapAccountReader) ListByPlatform(_ context.Context, platform string) ([]Account, error) {
	f.mu.Lock()
	defer f.mu.Unlock()
	f.platforms = append(f.platforms, platform)
	return append([]Account(nil), f.accounts...), f.err
}

// fakeLSPoolBackend is a recording test double for the LS pool backend passed
// to NewLSPoolBootstrapService. All recording happens under mu. tokenCalls and
// creditCalls keep only the latest call per account ID, getCalls keeps every
// GetOrCreate call in order, and getErrs maps an account ID to the error that
// GetOrCreate should return for it.
type fakeLSPoolBackend struct {
	mu          sync.Mutex
	tokenCalls  map[string]fakeLSPoolTokenCall
	creditCalls map[string]fakeLSPoolCreditCall
	getCalls    []fakeLSPoolGetCall
	getErrs     map[string]error
}

// fakeLSPoolTokenCall captures the arguments of one SetAccountToken call.
type fakeLSPoolTokenCall struct {
	AccessToken  string
	RefreshToken string
	ExpiresAt    time.Time
}

// fakeLSPoolCreditCall captures the arguments of one SetAccountModelCredits
// call. The pointer fields hold copies, so later mutation of the caller's
// values cannot change what was recorded.
type fakeLSPoolCreditCall struct {
	UseAICredits        bool
	AvailableCredits    *int32
	MinimumCreditAmount *int32
}

// fakeLSPoolGetCall captures the arguments of one GetOrCreate call. ProxyURL is
// the first variadic proxy argument, or "" when none was passed.
type fakeLSPoolGetCall struct {
	AccountID  string
	RoutingKey string
	ProxyURL   string
}

// newFakeLSPoolBackend returns a fakeLSPoolBackend with every map initialised,
// so tests can populate getErrs directly.
func newFakeLSPoolBackend() *fakeLSPoolBackend {
	return &fakeLSPoolBackend{
		tokenCalls:  make(map[string]fakeLSPoolTokenCall),
		creditCalls: make(map[string]fakeLSPoolCreditCall),
		getErrs:     make(map[string]error),
	}
}

// GetOrCreate records the call. It returns the error configured in getErrs
// for accountID, or otherwise a bare *lspool.Instance carrying only AccountID.
func (f *fakeLSPoolBackend) GetOrCreate(accountID, routingKey string, proxyURL ...string) (*lspool.Instance, error) {
	rawProxy := ""
	if len(proxyURL) > 0 {
		rawProxy = proxyURL[0]
	}
	f.mu.Lock()
	defer f.mu.Unlock()
	f.getCalls = append(f.getCalls, fakeLSPoolGetCall{
		AccountID:  accountID,
		RoutingKey: routingKey,
		ProxyURL:   rawProxy,
	})
	if err := f.getErrs[accountID]; err != nil {
		return nil, err
	}
	return &lspool.Instance{AccountID: accountID}, nil
}

// SetAccountToken records the latest token triple pushed for accountID.
func (f *fakeLSPoolBackend) SetAccountToken(accountID, accessToken, refreshToken string, expiresAt time.Time) {
	f.mu.Lock()
	defer f.mu.Unlock()
	f.tokenCalls[accountID] = fakeLSPoolTokenCall{
		AccessToken:  accessToken,
		RefreshToken: refreshToken,
		ExpiresAt:    expiresAt,
	}
}

// SetAccountModelCredits records the latest AI-credit settings pushed for
// accountID, copying the pointer arguments.
func (f *fakeLSPoolBackend) SetAccountModelCredits(accountID string, useAICredits bool, availableCredits, minimumCreditAmountForUsage *int32) {
	f.mu.Lock()
	defer f.mu.Unlock()
	f.creditCalls[accountID] = fakeLSPoolCreditCall{
		UseAICredits:        useAICredits,
		AvailableCredits:    copyInt32Ptr(availableCredits),
		MinimumCreditAmount: copyInt32Ptr(minimumCreditAmountForUsage),
	}
}

// Stats is a stub that always returns nil.
func (f *fakeLSPoolBackend) Stats() map[string]any { return nil }

// Close is a no-op stub.
func (f *fakeLSPoolBackend) Close() {}

// getCallAccountIDs returns, under the lock, the account IDs passed to
// GetOrCreate in call order.
func (f *fakeLSPoolBackend) getCallAccountIDs() []string {
	f.mu.Lock()
	defer f.mu.Unlock()
	ids := make([]string, 0, len(f.getCalls))
	for _, call := range f.getCalls {
		ids = append(ids, call.AccountID)
	}
	return ids
}

// copyInt32Ptr returns a pointer to a fresh copy of *v, or nil when v is nil.
func copyInt32Ptr(v *int32) *int32 {
	if v == nil {
		return nil
	}
	cp := *v
	return &cp
}

// TestLSPoolBootstrapServiceBootstrapEligibleAccounts verifies that bootstrap
// queries only the Antigravity platform and creates an LS worker solely for
// account 101. For that account it pushes the proxy URL, the token triple and
// the AI-credit settings; no token is pushed for any other account.
func TestLSPoolBootstrapServiceBootstrapEligibleAccounts(t *testing.T) {
	expiresAt := time.Now().Add(2 * time.Hour).UTC().Truncate(time.Second)
	expiredAt := time.Now().Add(-2 * time.Hour)
	reader := &fakeLSBootstrapAccountReader{
		accounts: []Account{
			// 101: expected to be bootstrapped — schedulable OAuth account with
			// project_id, refresh token, expiry, AI credits and a proxy.
			{
				ID:          101,
				Platform:    PlatformAntigravity,
				Type:        AccountTypeOAuth,
				Status:      StatusActive,
				Schedulable: true,
				Credentials: map[string]any{
					"access_token":  "token-101",
					"refresh_token": "refresh-101",
					"expires_at":    expiresAt.Format(time.RFC3339),
					"project_id":    "proj-101",
				},
				Extra: map[string]any{
					"allow_overages": true,
					"ai_credits": []any{
						map[string]any{
							"credit_type":     "GOOGLE_ONE_AI",
							"amount":          120,
							"minimum_balance": 55,
						},
					},
				},
				Proxy: &Proxy{
					Protocol: "socks5h",
					Host:     "127.0.0.1",
					Port:     1080,
					Username: "alice",
					Password: "secret",
				},
			},
			// 102: expected to be skipped — Schedulable is false.
			{
				ID:          102,
				Platform:    PlatformAntigravity,
				Type:        AccountTypeOAuth,
				Status:      StatusActive,
				Schedulable: false,
				Credentials: map[string]any{"access_token": "token-102", "project_id": "proj-102"},
			},
			// 103: expected to be skipped — credentials carry no project_id.
			{
				ID:          103,
				Platform:    PlatformAntigravity,
				Type:        AccountTypeOAuth,
				Status:      StatusActive,
				Schedulable: true,
				Credentials: map[string]any{"access_token": "token-103"},
			},
			// 104: expected to be skipped — already expired with AutoPauseOnExpired set.
			{
				ID:                 104,
				Platform:           PlatformAntigravity,
				Type:               AccountTypeOAuth,
				Status:             StatusActive,
				Schedulable:        true,
				AutoPauseOnExpired: true,
				ExpiresAt:          &expiredAt,
				Credentials:        map[string]any{"access_token": "token-104", "project_id": "proj-104"},
			},
			// 106: expected to be skipped — upstream account type rather than OAuth.
			{
				ID:          106,
				Platform:    PlatformAntigravity,
				Type:        AccountTypeUpstream,
				Status:      StatusActive,
				Schedulable: true,
				Credentials: map[string]any{"access_token": "token-106", "project_id": "proj-106"},
			},
			// 105: expected to be skipped — other platform. The fake reader does
			// not filter by platform, so the service itself must reject it.
			{
				ID:          105,
				Platform:    PlatformOpenAI,
				Status:      StatusActive,
				Schedulable: true,
				Credentials: map[string]any{"access_token": "token-105"},
			},
		},
	}
	backend := newFakeLSPoolBackend()

	svc := NewLSPoolBootstrapService(reader, backend, &config.Config{
		Gateway: config.GatewayConfig{
			AntigravityLSWorker: config.GatewayAntigravityLSWorkerConfig{MaxActive: 3},
		},
	})
	svc.bootstrap(context.Background())

	require.Equal(t, []string{PlatformAntigravity}, reader.platforms)

	require.Len(t, backend.getCalls, 1)
	require.Equal(t, fakeLSPoolGetCall{
		AccountID:  "101",
		RoutingKey: "",
		ProxyURL:   "socks5h://alice:secret@127.0.0.1:1080",
	}, backend.getCalls[0])

	tokenCall, ok := backend.tokenCalls["101"]
	require.True(t, ok)
	require.Equal(t, "token-101", tokenCall.AccessToken)
	require.Equal(t, "refresh-101", tokenCall.RefreshToken)
	// Compare instants, not time.Time representations: require.Equal would
	// also compare the Location and fail on an equal instant in another zone.
	require.WithinDuration(t, expiresAt, tokenCall.ExpiresAt, 0)

	creditCall, ok := backend.creditCalls["101"]
	require.True(t, ok)
	require.True(t, creditCall.UseAICredits)
	require.NotNil(t, creditCall.AvailableCredits)
	require.Equal(t, int32(120), *creditCall.AvailableCredits)
	require.NotNil(t, creditCall.MinimumCreditAmount)
	require.Equal(t, int32(55), *creditCall.MinimumCreditAmount)

	// No skipped account, including the other-platform account 105, may
	// receive a token push.
	for _, id := range []string{"102", "103", "104", "105", "106"} {
		require.NotContains(t, backend.tokenCalls, id)
	}
}

// TestLSPoolBootstrapServiceBootstrapContinuesOnWorkerFailure verifies that a
// GetOrCreate failure for one account does not stop bootstrap from processing
// the remaining accounts.
func TestLSPoolBootstrapServiceBootstrapContinuesOnWorkerFailure(t *testing.T) {
	reader := &fakeLSBootstrapAccountReader{
		accounts: []Account{
			{
				ID:          201,
				Platform:    PlatformAntigravity,
				Type:        AccountTypeOAuth,
				Status:      StatusActive,
				Schedulable: true,
				Credentials: map[string]any{"access_token": "token-201", "project_id": "proj-201"},
			},
			{
				ID:          202,
				Platform:    PlatformAntigravity,
				Type:        AccountTypeOAuth,
				Status:      StatusActive,
				Schedulable: true,
				Credentials: map[string]any{"access_token": "token-202", "project_id": "proj-202"},
			},
		},
	}
	backend := newFakeLSPoolBackend()
	backend.getErrs["201"] = errors.New("create failed")

	svc := NewLSPoolBootstrapService(reader, backend, &config.Config{})
	svc.bootstrap(context.Background())

	// Each account must be attempted exactly once, in any order. A bare length
	// check would also pass if the failing account were retried and 202 skipped.
	require.ElementsMatch(t, []string{"201", "202"}, backend.getCallAccountIDs())
	// Tokens are expected for both accounts, including 201 whose worker
	// creation failed.
	require.Contains(t, backend.tokenCalls, "201")
	require.Contains(t, backend.tokenCalls, "202")
}