package lspool import ( "bytes" "context" "encoding/binary" "errors" "io" "net/http" "strings" "sync/atomic" "testing" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint" "github.com/stretchr/testify/require" ) func TestBuildLSEnvKeepsExistingSSLValues(t *testing.T) { env := buildLSEnv([]string{ "SSL_CERT_FILE=/custom/ca.pem", "SSL_CERT_DIR=/custom/certs", }, "/opt/antigravity", "") require.Contains(t, env, "ANTIGRAVITY_EDITOR_APP_ROOT=/opt/antigravity") require.Contains(t, env, "SSL_CERT_FILE=/custom/ca.pem") require.Contains(t, env, "SSL_CERT_DIR=/custom/certs") } func TestBuildLSEnvClearsInheritedProxyWhenUnset(t *testing.T) { env := buildLSEnv([]string{ "HTTPS_PROXY=http://old-proxy:8080", "HTTP_PROXY=http://old-proxy:8080", "ALL_PROXY=socks5://old-proxy:1080", "https_proxy=http://old-proxy:8080", "http_proxy=http://old-proxy:8080", "all_proxy=socks5://old-proxy:1080", }, "/opt/antigravity", "") require.Contains(t, env, "HTTPS_PROXY=") require.Contains(t, env, "HTTP_PROXY=") require.Contains(t, env, "ALL_PROXY=") require.Contains(t, env, "https_proxy=") require.Contains(t, env, "http_proxy=") require.Contains(t, env, "all_proxy=") } func TestShortAccountID(t *testing.T) { require.Equal(t, "9", shortAccountID("9")) require.Equal(t, "12345678", shortAccountID("12345678")) require.Equal(t, "12345678", shortAccountID("123456789")) } func TestFrameConnectMessage(t *testing.T) { framed := frameConnectMessage([]byte(`{"x":1}`)) require.Len(t, framed, 5+len(`{"x":1}`)) require.Equal(t, byte(0), framed[0]) require.Equal(t, uint32(len(`{"x":1}`)), binary.BigEndian.Uint32(framed[1:5])) require.Equal(t, `{"x":1}`, string(framed[5:])) } func TestConnectEnvelope(t *testing.T) { payload := []byte("hello") env := connectEnvelope(0x00, payload) require.Len(t, env, 5+len(payload)) require.Equal(t, byte(0x00), env[0]) require.Equal(t, uint32(5), binary.BigEndian.Uint32(env[1:5])) require.Equal(t, "hello", string(env[5:])) } func TestUnwrapConnectEnvelope(t *testing.T) { payload := []byte("test data") env := connectEnvelope(0x00, payload) unwrapped := unwrapConnectEnvelope(env) require.Equal(t, payload, unwrapped) short := []byte{1, 2} require.Equal(t, short, unwrapConnectEnvelope(short)) } func TestExtractPromptAndModel(t *testing.T) { body := `{"model":"gemini-2.5-pro","project":"p","request":{"contents":[{"role":"user","parts":[{"text":"hello world"}]}]}}` prompt, model := extractPromptAndModel([]byte(body)) require.Equal(t, "hello world", prompt) require.Equal(t, "gemini-2.5-pro", model) body2 := `{"contents":[{"role":"user","parts":[{"text":"test prompt"}]}]}` prompt2, _ := extractPromptAndModel([]byte(body2)) require.Equal(t, "test prompt", prompt2) } func TestResolveModelEnum(t *testing.T) { // Without dynamic mapping loaded, should return fallback (312 = gemini-2.5-flash) require.True(t, resolveModelEnum("gemini-2.5-flash") > 0) require.True(t, resolveModelEnum("models/gemini-2.5-flash") > 0) require.True(t, resolveModelEnum("claude-sonnet-4-6") > 0) require.True(t, resolveModelEnum("unknown-model") > 0) } func TestBuildCascadeConfigIncludesRequestedModel(t *testing.T) { cfg := buildCascadeConfig("models/gemini-2.5-flash") require.NotNil(t, cfg) plannerConfig, ok := cfg["plannerConfig"].(map[string]any) require.True(t, ok) requestedModel, ok := plannerConfig["requestedModel"].(map[string]any) require.True(t, ok) require.NotEmpty(t, requestedModel["model"]) require.Len(t, plannerConfig, 1) } func TestBuildCascadeConfigClaudeIncludesRequestedModel(t *testing.T) { cfg := buildCascadeConfig("claude-sonnet-4-6") require.NotNil(t, cfg) plannerConfig, ok := cfg["plannerConfig"].(map[string]any) require.True(t, ok) requestedModel, ok := plannerConfig["requestedModel"].(map[string]any) require.True(t, ok) require.NotEmpty(t, requestedModel["model"]) require.Len(t, plannerConfig, 1) } func TestDoNonStreamGeneratePassesThrough(t *testing.T) { fallback := &recordingUpstream{} upstream := NewLSPoolUpstream(&Pool{}, fallback) req, _ := http.NewRequest("POST", "https://example.com/v1beta/models/gemini:generateContent", bytes.NewReader([]byte(`{}`))) resp, err := upstream.Do(req, "", 1, 1) require.NoError(t, err) require.NotNil(t, resp) require.Equal(t, 1, fallback.doCalls) } func TestExtractPlannerResponseText(t *testing.T) { resp := `{"status":"CASCADE_RUN_STATUS_IDLE","trajectory":{"steps":[ {"type":"CORTEX_STEP_TYPE_USER_INPUT","status":"CORTEX_STEP_STATUS_DONE"}, {"type":"CORTEX_STEP_TYPE_PLANNER_RESPONSE","status":"CORTEX_STEP_STATUS_DONE", "plannerResponse":{"response":"Hello world"}} ]}}` text, generating, status := extractPlannerResponseText([]byte(resp)) require.Equal(t, "Hello world", text) require.False(t, generating) require.Equal(t, "CASCADE_RUN_STATUS_IDLE", status) } func TestExtractPlannerResponseState_ErrorDetails(t *testing.T) { resp := `{ "status":"CASCADE_RUN_STATUS_IDLE", "trajectory":{ "steps":[ {"type":"CORTEX_STEP_TYPE_USER_INPUT","status":"CORTEX_STEP_STATUS_DONE"} ], "executorMetadata":{ "terminationReason":"ERROR", "errorDetails":{ "errorCode":429, "shortError":"Model quota reached", "details":"You have exhausted your capacity on this model. Your quota will reset after 1h59m40s." } } } }` state := extractPlannerResponseState([]byte(resp)) require.Equal(t, "CASCADE_RUN_STATUS_IDLE", state.Status) require.False(t, state.Generating) require.Empty(t, state.Text) require.Contains(t, state.ErrorMessage, "Model quota reached") require.Contains(t, state.ErrorMessage, "quota will reset after") } func TestBuildGeminiSSEChunk(t *testing.T) { sse := buildGeminiSSEChunk("hello") require.Contains(t, sse, "data: ") require.Contains(t, sse, `"text":"hello"`) require.Contains(t, sse, `"role":"model"`) require.True(t, strings.HasSuffix(sse, "\n\n")) } func TestRequestHasTools(t *testing.T) { // Wrapped format with tools require.True(t, requestHasTools([]byte(`{"model":"m","project":"p","request":{"contents":[],"tools":[{"functionDeclarations":[{"name":"get_weather"}]}]}}`))) // Direct format with tools require.True(t, requestHasTools([]byte(`{"contents":[],"tools":[{"functionDeclarations":[{"name":"f"}]}]}`))) // No tools require.False(t, requestHasTools([]byte(`{"model":"m","project":"p","request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`))) // Empty tools array require.False(t, requestHasTools([]byte(`{"contents":[],"tools":[]}`))) } func TestCurrentLSStrategy(t *testing.T) { t.Setenv("ANTIGRAVITY_LS_STRATEGY", "js-parity") require.Equal(t, LSStrategyJSParity, CurrentLSStrategy()) t.Setenv("ANTIGRAVITY_LS_STRATEGY", "unknown") require.Equal(t, LSStrategyDirect, CurrentLSStrategy()) } func TestParseLSReplicaCountDefaultAndEnv(t *testing.T) { t.Setenv("ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT", "") require.Equal(t, 5, parseLSReplicaCount()) t.Setenv("ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT", "3") require.Equal(t, 3, parseLSReplicaCount()) t.Setenv("ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT", "0") require.Equal(t, 5, parseLSReplicaCount()) } func TestPoolGetUsesStickyReplicaSlot(t *testing.T) { pool := &Pool{ config: Config{ReplicasPerAccount: 5}, instances: map[string][]*Instance{ "acc-1": { {AccountID: "acc-1", Replica: 0, healthy: true}, {AccountID: "acc-1", Replica: 1, healthy: true}, {AccountID: "acc-1", Replica: 2, healthy: true}, {AccountID: "acc-1", Replica: 3, healthy: true}, {AccountID: "acc-1", Replica: 4, healthy: true}, }, }, } routingKey := "acc-1:user-a:session-1" slot := replicaSlotIndex(routingKey, pool.replicaCount()) inst := pool.Get("acc-1", routingKey) require.NotNil(t, inst) require.Equal(t, slot, inst.Replica) } func TestPoolGetWithoutRoutingKeyPrefersLeastBusyReplica(t *testing.T) { busy := &Instance{AccountID: "acc-1", Replica: 0, healthy: true} atomic.StoreInt64(&busy.inflight, 4) idle := &Instance{AccountID: "acc-1", Replica: 1, healthy: true} atomic.StoreInt64(&idle.inflight, 1) pool := &Pool{ config: Config{ReplicasPerAccount: 5}, instances: map[string][]*Instance{ "acc-1": {busy, idle}, }, } inst := pool.Get("acc-1", "") require.NotNil(t, inst) require.Equal(t, 1, inst.Replica) } func TestWaitForInstanceReadyProbesImmediately(t *testing.T) { startedAt := time.Now() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() attempts, err := waitForInstanceReady(ctx, 200*time.Millisecond, func(context.Context) error { return nil }) require.NoError(t, err) require.Equal(t, 1, attempts) require.Less(t, time.Since(startedAt), 100*time.Millisecond) } func TestWaitForInstanceReadyRetriesUntilSuccess(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() calls := 0 attempts, err := waitForInstanceReady(ctx, 10*time.Millisecond, func(context.Context) error { calls++ if calls < 3 { return errors.New("not ready") } return nil }) require.NoError(t, err) require.Equal(t, 3, attempts) require.Equal(t, 3, calls) } func TestDecideJSParityRoute(t *testing.T) { body := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"s1","contents":[{"role":"user","parts":[{"text":"hello"}]}]}}`) parsed, err := parseGeminiRequest(body) require.NoError(t, err) decision := decideJSParityRoute(parsed, body) require.True(t, decision.UseLS) imageBody := []byte(`{"model":"gemini-2.5-flash-image","request":{"sessionId":"s1","contents":[{"role":"user","parts":[{"text":"draw"}]}],"generationConfig":{"responseModalities":["TEXT","IMAGE"]}}}`) parsedImage, err := parseGeminiRequest(imageBody) require.NoError(t, err) decisionImage := decideJSParityRoute(parsedImage, imageBody) require.False(t, decisionImage.UseLS) require.Contains(t, strings.ToLower(decisionImage.Reason), "image") noSessionBody := []byte(`{"model":"gemini-2.5-flash","request":{"contents":[{"role":"user","parts":[{"text":"hello"}]}]}}`) parsedNoSession, err := parseGeminiRequest(noSessionBody) require.NoError(t, err) require.False(t, decideJSParityRoute(parsedNoSession, noSessionBody).UseLS) } func TestUserNamespacePrefersExplicitHeader(t *testing.T) { req, err := http.NewRequest(http.MethodPost, "https://example.com", nil) require.NoError(t, err) req.Header.Set(userNamespaceHeader, "tenant-a") req.Header.Set("Authorization", "Bearer oauth-token") nsWithExplicit := userNamespace(req) require.NotEqual(t, "anon", nsWithExplicit) req.Header.Del(userNamespaceHeader) nsWithAuth := userNamespace(req) require.NotEqual(t, "anon", nsWithAuth) require.NotEqual(t, nsWithExplicit, nsWithAuth) } func TestConversationPrefixEqual(t *testing.T) { prefix := []geminiConversationTurn{ {Role: "user", Parts: []geminiConversationPart{{Kind: "text", Text: "hello"}}}, {Role: "model", Parts: []geminiConversationPart{{Kind: "text", Text: "world"}}}, } full := append(cloneConversationTurns(prefix), geminiConversationTurn{ Role: "user", Parts: []geminiConversationPart{{Kind: "text", Text: "follow up"}}, }) require.True(t, conversationPrefixEqual(full, prefix)) require.False(t, conversationPrefixEqual(prefix, full)) } func TestSystemTextCompatible(t *testing.T) { require.True(t, systemTextCompatible("You are helpful", "")) require.True(t, systemTextCompatible("You are helpful", "You are helpful")) require.False(t, systemTextCompatible("", "You are helpful")) require.False(t, systemTextCompatible("You are helpful", "You are different")) } type recordingUpstream struct { doCalls int } func (r *recordingUpstream) Do(req *http.Request, _ string, _ int64, _ int) (*http.Response, error) { r.doCalls++ return &http.Response{StatusCode: 200, Body: io.NopCloser(bytes.NewBufferString("ok")), Header: make(http.Header), Request: req}, nil } func (r *recordingUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, c int, _ *tlsfingerprint.Profile) (*http.Response, error) { return r.Do(req, proxyURL, accountID, c) }