package lspool

import (
	"bytes"
	"context"
	"encoding/base64"
	"encoding/binary"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
	"strings"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/stretchr/testify/require"
)

// readConnectFrame reads a single envelope from r and returns its payload.
// The envelope is a 5-byte header (one flag byte followed by a big-endian
// uint32 payload length) and then the payload bytes. The flag byte is not
// inspected.
func readConnectFrame(r io.Reader) ([]byte, error) {
	var header [5]byte
	if _, err := io.ReadFull(r, header[:]); err != nil {
		return nil, err
	}
	payload := make([]byte, binary.BigEndian.Uint32(header[1:]))
	if _, err := io.ReadFull(r, payload); err != nil {
		return nil, err
	}
	return payload, nil
}

// decodeProtoBytesField scans data as a sequence of protobuf wire-format
// fields and returns the value of the first length-delimited (wire type 2)
// field whose number equals targetField. The returned slice aliases data.
// It returns nil when the field is absent, when a varint cannot be decoded,
// when a declared length runs past the end of data, or when an unsupported
// wire type is encountered.
func decodeProtoBytesField(data []byte, targetField int) []byte {
	pos := 0
	for pos < len(data) {
		tag, n := binary.Uvarint(data[pos:])
		if n <= 0 {
			return nil
		}
		pos += n
		fieldNum := int(tag >> 3)

		switch tag & 0x7 {
		case 0: // varint: decode only to learn how many bytes to skip
			_, n = binary.Uvarint(data[pos:])
			if n <= 0 {
				return nil
			}
			pos += n
		case 1: // fixed 64-bit
			pos += 8
		case 2: // length-delimited
			length, n := binary.Uvarint(data[pos:])
			if n <= 0 {
				return nil
			}
			pos += n
			// Compare against the remaining byte count rather than computing
			// pos+int(length), which can overflow for a huge declared length
			// and slip past the check.
			if length > uint64(len(data)-pos) {
				return nil
			}
			end := pos + int(length)
			if fieldNum == targetField {
				return data[pos:end]
			}
			pos = end
		case 5: // fixed 32-bit
			pos += 4
		default:
			return nil
		}
	}
	return nil
}

// decodeProtoBytesFields is the repeated-field counterpart of
// decodeProtoBytesField: it returns a copy of every length-delimited field
// numbered targetField, in the order encountered. When the input is malformed
// or uses an unsupported wire type, the values collected up to that point are
// returned.
func decodeProtoBytesFields(data []byte, targetField int) [][]byte {
	var values [][]byte
	pos := 0
	for pos < len(data) {
		tag, n := binary.Uvarint(data[pos:])
		if n <= 0 {
			return values
		}
		pos += n
		fieldNum := int(tag >> 3)

		switch tag & 0x7 {
		case 0: // varint
			_, n = binary.Uvarint(data[pos:])
			if n <= 0 {
				return values
			}
			pos += n
		case 1: // fixed 64-bit
			pos += 8
		case 2: // length-delimited
			length, n := binary.Uvarint(data[pos:])
			if n <= 0 {
				return values
			}
			pos += n
			// Overflow-safe bounds check (see decodeProtoBytesField).
			if length > uint64(len(data)-pos) {
				return values
			}
			end := pos + int(length)
			if fieldNum == targetField {
				values = append(values, append([]byte(nil), data[pos:end]...))
			}
			pos = end
		case 5: // fixed 32-bit
			pos += 4
		default:
			return values
		}
	}
	return values
}

// decodeTopicRows decodes every field-1 entry of topic into a key → value map.
// For each entry the key is string field 1, and the value is string field 1
// of the nested message stored in field 2.
func decodeTopicRows(topic []byte) map[string]string {
	rows := make(map[string]string)
	for _, entry := range decodeProtoBytesFields(topic, 1) {
		row := decodeProtoBytesField(entry, 2)
		rows[decodeProtoString(entry, 1)] = decodeProtoString(row, 1)
	}
	return rows
}

// requireBase64PrimitiveValue fails the test unless got is valid standard
// base64 that decodes to exactly want.
func requireBase64PrimitiveValue(t *testing.T, got string, want []byte) {
	t.Helper()
	raw, err := base64.StdEncoding.DecodeString(got)
	require.NoError(t, err)
	require.Equal(t, want, raw)
}

// TestMockExtensionServerTokenInjection verifies the token injection flow:
// Extension → MockExtensionServer → LS subscribes uss-oauth → gets OAuthTokenInfo
func TestMockExtensionServerTokenInjection(t *testing.T) {
	csrf := "test-csrf-token"
	srv, err := NewMockExtensionServer(csrf)
	require.NoError(t, err)
	defer srv.Close()

	// 1. Set token for an account
	srv.SetToken("account-1", &TokenInfo{
		AccessToken:  "ya29.test-access-token",
		RefreshToken: "1//test-refresh-token",
		ExpiresAt:    time.Now().Add(1 * time.Hour),
	})

	// 2. Verify token is stored
	srv.mu.RLock()
	info, ok := srv.tokens["account-1"]
	srv.mu.RUnlock()
	require.True(t, ok)
	require.Equal(t, "ya29.test-access-token", info.AccessToken)
	require.Equal(t, "1//test-refresh-token", info.RefreshToken)
	require.False(t, info.ExpiresAt.IsZero())

	// 3. Simulate LS subscribing to uss-oauth (HTTP request to mock server).
	// The stream handler keeps the response open, so bound the whole exchange
	// with a timeout.
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	url := fmt.Sprintf("http://127.0.0.1:%d/exa.extension_server_pb.ExtensionServerService/SubscribeToUnifiedStateSyncTopic", srv.Port())
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url,
		bytes.NewReader(frameConnectMessage(encodeProtoString(1, "uss-oauth"))))
	require.NoError(t, err)
	req.Header.Set("x-codeium-csrf-token", csrf)
	req.Header.Set("Content-Type", "application/connect+proto")

	// A failed request must fail the test; previously it was skipped silently.
	resp, err := http.DefaultClient.Do(req)
	require.NoError(t, err)
	defer resp.Body.Close()
	require.Equal(t, http.StatusOK, resp.StatusCode)
	require.Equal(t, "application/connect+proto", resp.Header.Get("Content-Type"))

	// Read the first envelope header (initial state). io.ReadFull is used
	// because a single Read may legitimately return fewer than 5 bytes.
	header := make([]byte, 5)
	_, err = io.ReadFull(resp.Body, header)
	require.NoError(t, err)
	require.Equal(t, byte(0x00), header[0], "first byte should be 0x00 (data frame)")
	t.Logf("Received initial state frame: flags=%d, payload_len=%d",
		header[0], binary.BigEndian.Uint32(header[1:5]))
}

// TestMockExtensionServerCSRF verifies CSRF token validation
func TestMockExtensionServerCSRF(t *testing.T) {
	csrf := "correct-csrf"
	srv, err := NewMockExtensionServer(csrf)
	require.NoError(t, err)
	defer srv.Close()

	base := fmt.Sprintf("http://127.0.0.1:%d/exa.extension_server_pb.ExtensionServerService/Heartbeat", srv.Port())

	// statusFor posts an empty Heartbeat carrying the given CSRF token and
	// returns the HTTP status code of the reply.
	statusFor := func(token string) int {
		t.Helper()
		req, err := http.NewRequest(http.MethodPost, base, nil)
		require.NoError(t, err)
		req.Header.Set("x-codeium-csrf-token", token)
		req.Header.Set("Content-Type", "application/proto")
		resp, err := http.DefaultClient.Do(req)
		require.NoError(t, err)
		defer resp.Body.Close()
		return resp.StatusCode
	}

	// Wrong CSRF → 403
	require.Equal(t, http.StatusForbidden, statusFor("wrong-csrf"))
	// Correct CSRF → 200
	require.Equal(t, http.StatusOK, statusFor(csrf))
}

// TestMockExtensionServerGetSecretValue verifies the fallback token path
func TestMockExtensionServerGetSecretValue(t *testing.T) {
	csrf := "test-csrf"
	srv, err := NewMockExtensionServer(csrf)
	require.NoError(t, err)
	defer srv.Close()

	srv.SetToken("acc", &TokenInfo{AccessToken: "ya29.secret-token"})

	// GetSecretValue must answer 200 once a token is stored. Only the status
	// code is checked here; the response body is not inspected.
	url := fmt.Sprintf("http://127.0.0.1:%d/exa.extension_server_pb.ExtensionServerService/GetSecretValue", srv.Port())
	req, err := http.NewRequest(http.MethodPost, url, nil)
	require.NoError(t, err)
	req.Header.Set("x-codeium-csrf-token", csrf)
	req.Header.Set("Content-Type", "application/proto")

	resp, err := http.DefaultClient.Do(req)
	require.NoError(t, err)
	defer resp.Body.Close()
	require.Equal(t, http.StatusOK, resp.StatusCode)
}

// TestOAuthTokenInfoProto verifies the proto encoding matches real IDE format
func TestOAuthTokenInfoProto(t *testing.T) {
	expiry := time.Date(2026, 3, 29, 19, 0, 0, 0, time.UTC)

	// Verify fields are present by checking the raw wire bytes.
	bin := buildOAuthTokenInfoBinary("ya29.test", "1//refresh", expiry)
	require.NotEmpty(t, bin, "proto should not be empty")

	encoded := string(bin)
	// Field 1 (access_token): tag=0x0a, value="ya29.test"
	require.Contains(t, encoded, "ya29.test")
	// Field 2 (token_type): tag=0x12, value="Bearer"
	require.Contains(t, encoded, "Bearer")
	// Field 3 (refresh_token): tag=0x1a, value="1//refresh"
	require.Contains(t, encoded, "1//refresh")

	// Without refresh_token
	binNoRefresh := buildOAuthTokenInfoBinary("ya29.test", "", expiry)
	require.NotContains(t, string(binNoRefresh), "1//refresh")
}

// TestOAuthTokenInfoWithRealExpiry verifies expiry uses real time, not hardcoded
func TestOAuthTokenInfoWithRealExpiry(t *testing.T) {
	future := time.Now().Add(2 * time.Hour)
	bin := buildOAuthTokenInfoBinary("token", "refresh", future)

	// Zero expiry should default to ~1h
	binZero := buildOAuthTokenInfoBinary("token", "refresh", time.Time{})

	// The two encodings are expected to carry different expiry timestamps;
	// this test only asserts that both are produced (non-empty).
	require.NotEmpty(t, bin)
	require.NotEmpty(t, binZero)
}

// TestUSSTopicWithOAuth verifies the full USS topic proto structure
func TestUSSTopicWithOAuth(t *testing.T) {
	expiry := time.Now().Add(1 * time.Hour)
	topic := buildUSSTopicWithOAuth("ya29.access", "1//refresh", expiry)
	require.NotEmpty(t, topic)

	// The topic should contain the sentinel key
	require.Contains(t, string(topic), "oauthTokenInfoSentinelKey")
}

// TestUSSTopicWithModelCredits checks that a model-credits topic contains all
// three sentinel keys and that each row decodes to the base64-encoded
// primitive built from the corresponding input value.
func TestUSSTopicWithModelCredits(t *testing.T) {
	available := int32(123)
	minimum := int32(50)

	topic := buildUSSTopicWithModelCredits(&ModelCreditsInfo{
		UseAICredits:                true,
		AvailableCredits:            &available,
		MinimumCreditAmountForUsage: &minimum,
	})
	require.NotEmpty(t, topic)

	raw := string(topic)
	require.Contains(t, raw, useAICreditsSentinelKey)
	require.Contains(t, raw, availableCreditsSentinelKey)
	require.Contains(t, raw, minimumCreditAmountForUsageKey)

	rows := decodeTopicRows(topic)
	requireBase64PrimitiveValue(t, rows[useAICreditsSentinelKey], buildPrimitiveBoolBinary(true))
	requireBase64PrimitiveValue(t, rows[availableCreditsSentinelKey], buildPrimitiveInt32Binary(available))
	requireBase64PrimitiveValue(t, rows[minimumCreditAmountForUsageKey], buildPrimitiveInt32Binary(minimum))
}

// TestMockExtensionServerModelCreditsDynamicUpdate subscribes to the
// uss-modelCredits topic, updates the account's credits while the stream is
// open, and checks that the follow-up frames carry the three credit rows.
func TestMockExtensionServerModelCreditsDynamicUpdate(t *testing.T) {
	csrf := "test-csrf-token"
	srv, err := NewMockExtensionServer(csrf)
	require.NoError(t, err)
	defer srv.Close()

	srv.SetModelCredits("account-1", &ModelCreditsInfo{})

	// The context timeout bounds the whole streaming exchange, including the
	// frame-reading loop below.
	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	url := fmt.Sprintf("http://127.0.0.1:%d/exa.extension_server_pb.ExtensionServerService/SubscribeToUnifiedStateSyncTopic", srv.Port())
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url,
		bytes.NewReader(frameConnectMessage(encodeProtoString(1, "uss-modelCredits"))))
	require.NoError(t, err)
	req.Header.Set("x-codeium-csrf-token", csrf)
	req.Header.Set("Content-Type", "application/connect+proto")

	resp, err := http.DefaultClient.Do(req)
	require.NoError(t, err)
	defer resp.Body.Close()
	require.Equal(t, http.StatusOK, resp.StatusCode)

	// Drain the initial_state frame first.
	_, err = readConnectFrame(resp.Body)
	require.NoError(t, err)

	available := int32(77)
	minimum := int32(25)
	srv.SetModelCredits("account-1", &ModelCreditsInfo{
		UseAICredits:                true,
		AvailableCredits:            &available,
		MinimumCreditAmountForUsage: &minimum,
	})

	// Each subsequent frame is read as an update whose field 2 holds one
	// key/row pair; keep reading until three distinct keys have been seen.
	values := make(map[string]string, 3)
	for len(values) < 3 {
		frame, readErr := readConnectFrame(resp.Body)
		require.NoError(t, readErr)

		applied := decodeProtoBytesField(frame, 2)
		require.NotEmpty(t, applied)

		row := decodeProtoBytesField(applied, 2)
		values[decodeProtoString(applied, 1)] = decodeProtoString(row, 1)
	}

	require.Contains(t, values, useAICreditsSentinelKey)
	require.Contains(t, values, availableCreditsSentinelKey)
	require.Contains(t, values, minimumCreditAmountForUsageKey)
	requireBase64PrimitiveValue(t, values[useAICreditsSentinelKey], buildPrimitiveBoolBinary(true))
	requireBase64PrimitiveValue(t, values[availableCreditsSentinelKey], buildPrimitiveInt32Binary(available))
	requireBase64PrimitiveValue(t, values[minimumCreditAmountForUsageKey], buildPrimitiveInt32Binary(minimum))
}

// TestBuildInitialStateUpdate verifies the USS update wrapper
func TestBuildInitialStateUpdate(t *testing.T) {
	// An empty topic is valid input and may produce an empty update, so no
	// size assertion is made on it alone; it serves as the baseline below.
	emptyUpdate := buildInitialStateUpdate(buildEmptyTopic())

	topic := buildUSSTopicWithOAuth("token", "refresh", time.Now().Add(1*time.Hour))
	fullUpdate := buildInitialStateUpdate(topic)
	require.True(t, len(fullUpdate) > len(emptyUpdate), "non-empty topic should produce larger update")
}

// TestPoolSetAccountTokenComplete verifies pool accepts full credential set
func TestPoolSetAccountTokenComplete(t *testing.T) {
	csrf := "pool-csrf"
	srv, err := NewMockExtensionServer(csrf)
	require.NoError(t, err)
	defer srv.Close()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	pool := &Pool{
		config:    DefaultConfig(),
		instances: make(map[string][]*Instance),
		extServer: srv,
		ctx:       ctx,
		cancel:    cancel,
	}

	expiry := time.Now().Add(1 * time.Hour)
	pool.SetAccountToken("acc-1", "ya29.full-token", "1//full-refresh", expiry)

	// The pool is expected to forward the credentials to the extension server.
	srv.mu.RLock()
	info := srv.tokens["acc-1"]
	srv.mu.RUnlock()

	require.NotNil(t, info)
	require.Equal(t, "ya29.full-token", info.AccessToken)
	require.Equal(t, "1//full-refresh", info.RefreshToken)
	require.False(t, info.ExpiresAt.IsZero())
	require.WithinDuration(t, expiry, info.ExpiresAt, time.Second)
}

// TestPoolSetAccountModelCreditsComplete verifies that model-credit settings
// passed to the pool end up stored on the extension server for that account.
func TestPoolSetAccountModelCreditsComplete(t *testing.T) {
	csrf := "pool-csrf"
	srv, err := NewMockExtensionServer(csrf)
	require.NoError(t, err)
	defer srv.Close()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	pool := &Pool{
		config:    DefaultConfig(),
		instances: make(map[string][]*Instance),
		extServer: srv,
		ctx:       ctx,
		cancel:    cancel,
	}

	available := int32(77)
	minimum := int32(25)
	pool.SetAccountModelCredits("acc-1", true, &available, &minimum)

	srv.mu.RLock()
	info := srv.credits["acc-1"]
	srv.mu.RUnlock()

	require.NotNil(t, info)
	require.True(t, info.UseAICredits)
	require.NotNil(t, info.AvailableCredits)
	require.Equal(t, available, *info.AvailableCredits)
	require.NotNil(t, info.MinimumCreditAmountForUsage)
	require.Equal(t, minimum, *info.MinimumCreditAmountForUsage)
}

// TestUpstreamAdapterExtractsCredentials verifies internal LS headers are extracted and stripped.
func TestUpstreamAdapterExtractsCredentials(t *testing.T) { // Create a mock upstream that records what it receives var receivedHeaders http.Header var mu sync.Mutex fallback := &recordingUpstreamWithCallback{} fallback.onDo = func(req *http.Request) { mu.Lock() receivedHeaders = req.Header.Clone() mu.Unlock() } csrf := "test-csrf" srv, err := NewMockExtensionServer(csrf) require.NoError(t, err) defer srv.Close() pool := &Pool{ config: DefaultConfig(), instances: make(map[string][]*Instance), extServer: srv, } upstream := NewLSPoolUpstream(pool, fallback) // Non-streamGenerateContent request → should pass through to fallback req, _ := http.NewRequest("POST", "https://example.com/v1beta/models/gemini:generateContent", nil) req.Header.Set("Authorization", "Bearer ya29.test") req.Header.Set("X-Antigravity-Refresh-Token", "1//secret-refresh") req.Header.Set("X-Antigravity-Token-Expiry", "2026-03-29T19:00:00Z") req.Header.Set(useAICreditsHeader, "true") req.Header.Set(availableCreditsHeader, "42") req.Header.Set(minimumCreditAmountHeader, "50") resp, err := upstream.Do(req, "", 1, 1) require.NoError(t, err) require.NotNil(t, resp) // Internal headers should never leak to the direct upstream. 
mu.Lock() require.Empty(t, receivedHeaders.Get("X-Antigravity-Refresh-Token")) require.Empty(t, receivedHeaders.Get("X-Antigravity-Token-Expiry")) require.Empty(t, receivedHeaders.Get(useAICreditsHeader)) require.Empty(t, receivedHeaders.Get(availableCreditsHeader)) require.Empty(t, receivedHeaders.Get(minimumCreditAmountHeader)) mu.Unlock() srv.mu.RLock() tokenInfo := srv.tokens["1"] creditsInfo := srv.credits["1"] srv.mu.RUnlock() require.NotNil(t, tokenInfo) require.Equal(t, "ya29.test", tokenInfo.AccessToken) require.NotNil(t, creditsInfo) require.True(t, creditsInfo.UseAICredits) require.NotNil(t, creditsInfo.AvailableCredits) require.Equal(t, int32(42), *creditsInfo.AvailableCredits) require.NotNil(t, creditsInfo.MinimumCreditAmountForUsage) require.Equal(t, int32(50), *creditsInfo.MinimumCreditAmountForUsage) } // TestExtractPromptAndModelMultiTurn verifies multi-turn prompt extraction func TestExtractPromptAndModelMultiTurn(t *testing.T) { body := `{ "model": "claude-sonnet-4-6", "request": { "systemInstruction": {"parts": [{"text": "You are helpful"}]}, "contents": [ {"role": "user", "parts": [{"text": "Hello"}]}, {"role": "model", "parts": [{"text": "Hi there!"}]}, {"role": "user", "parts": [{"text": "How are you?"}]} ] } }` prompt, model := extractPromptAndModel([]byte(body)) require.Equal(t, "claude-sonnet-4-6", model) require.Contains(t, prompt, "You are helpful") require.Contains(t, prompt, "Hello") require.Contains(t, prompt, "Hi there!") require.Contains(t, prompt, "How are you?") } // TestExtractUsageFromTrajectory verifies token usage extraction func TestExtractUsageFromTrajectory(t *testing.T) { resp := `{ "trajectory": { "steps": [{ "type": "CORTEX_STEP_TYPE_PLANNER_RESPONSE", "status": "CORTEX_STEP_STATUS_DONE", "plannerResponse": {"response": "OK"}, "metadata": { "modelUsage": { "inputTokens": "150", "outputTokens": "5" } } }] } }` usage := extractUsageFromTrajectory([]byte(resp)) require.NotNil(t, usage) require.Equal(t, 150, 
usage["promptTokenCount"]) require.Equal(t, 5, usage["candidatesTokenCount"]) require.Equal(t, 155, usage["totalTokenCount"]) } // TestSSEChunkFormat verifies the Gemini SSE output format func TestSSEChunkFormat(t *testing.T) { chunk := buildGeminiSSEChunk("Hello world") require.True(t, len(chunk) > 0) require.Contains(t, chunk, "data: ") require.Contains(t, chunk, `"text":"Hello world"`) require.Contains(t, chunk, `"role":"model"`) require.True(t, chunk[len(chunk)-2:] == "\n\n") // Verify it's valid JSON after stripping "data: " prefix jsonStr := chunk[len("data: ") : len(chunk)-2] var parsed map[string]any err := json.Unmarshal([]byte(jsonStr), &parsed) require.NoError(t, err) response := parsed["response"].(map[string]any) candidates := response["candidates"].([]any) require.Len(t, candidates, 1) } // TestSSEFinalChunkFormat verifies the final SSE chunk with usage func TestSSEFinalChunkFormat(t *testing.T) { usage := map[string]any{ "promptTokenCount": 100, "candidatesTokenCount": 50, "totalTokenCount": 150, } chunk := buildGeminiSSEFinalChunk(usage) require.Contains(t, chunk, "data: ") require.Contains(t, chunk, `"finishReason":"STOP"`) require.Contains(t, chunk, `"usageMetadata"`) } func TestStreamCascadeResponsePollsImmediately(t *testing.T) { var getCalls atomic.Int32 server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { require.Equal(t, "test-csrf", r.Header.Get("x-codeium-csrf-token")) if strings.HasSuffix(r.URL.Path, "/GetCascadeTrajectory") { getCalls.Add(1) w.Header().Set("Content-Type", "application/json") _, _ = w.Write([]byte(`{"status":"CASCADE_RUN_STATUS_IDLE","trajectory":{"steps":[{"type":"CORTEX_STEP_TYPE_PLANNER_RESPONSE","status":"CORTEX_STEP_STATUS_DONE","plannerResponse":{"response":"hello from ls"}}]}}`)) return } http.NotFound(w, r) })) defer server.Close() inst := &Instance{ AccountID: "42", CSRF: "test-csrf", Address: strings.TrimPrefix(server.URL, "https://"), client: server.Client(), healthy: 
true, lastUsed: time.Now(), } upstream := NewLSPoolUpstream(&Pool{}, &recordingUpstream{}) pr, pw := io.Pipe() ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) defer cancel() done := make(chan struct{}) go func() { upstream.streamCascadeResponse(ctx, inst, "cid-1", pw, nil, nil) close(done) }() body, err := io.ReadAll(pr) require.NoError(t, err) <-done require.GreaterOrEqual(t, getCalls.Load(), int32(1)) require.Contains(t, string(body), "hello from ls") } // TestRequestHasToolsEdgeCases verifies tool detection edge cases func TestRequestHasToolsEdgeCases(t *testing.T) { // null tools require.False(t, requestHasTools([]byte(`{"contents":[],"tools":null}`))) // tools with empty function declarations require.True(t, requestHasTools([]byte(`{"contents":[],"tools":[{"functionDeclarations":[]}]}`))) // deeply nested wrapped format require.True(t, requestHasTools([]byte(`{"model":"m","project":"p","request":{"contents":[],"tools":[{"codeExecution":{}}]}}`))) } func TestJSParityRouteReusesCascadeSession(t *testing.T) { t.Setenv("ANTIGRAVITY_LS_STRATEGY", LSStrategyJSParity) var startCalls atomic.Int32 var sendCalls atomic.Int32 var getCalls atomic.Int32 var sendBodiesMu sync.Mutex var sendBodies []map[string]any server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { require.Equal(t, "test-csrf", r.Header.Get("x-codeium-csrf-token")) switch { case strings.HasSuffix(r.URL.Path, "/StartCascade"): startCalls.Add(1) w.Header().Set("Content-Type", "application/json") _, _ = w.Write([]byte(`{"cascadeId":"cid-1"}`)) case strings.HasSuffix(r.URL.Path, "/SendUserCascadeMessage"): sendCalls.Add(1) var payload map[string]any err := json.NewDecoder(r.Body).Decode(&payload) require.NoError(t, err) sendBodiesMu.Lock() sendBodies = append(sendBodies, payload) sendBodiesMu.Unlock() w.Header().Set("Content-Type", "application/json") _, _ = w.Write([]byte(`{"queued":false}`)) case strings.HasSuffix(r.URL.Path, 
"/GetCascadeTrajectory"): call := getCalls.Add(1) w.Header().Set("Content-Type", "application/json") text := "hello from ls" if call > 1 { text = "follow up from ls" } _, _ = w.Write([]byte(fmt.Sprintf(`{"status":"CASCADE_RUN_STATUS_IDLE","trajectory":{"steps":[{"type":"CORTEX_STEP_TYPE_PLANNER_RESPONSE","status":"CORTEX_STEP_STATUS_DONE","plannerResponse":{"response":"%s"}}]}}`, text))) default: http.NotFound(w, r) } })) defer server.Close() inst := &Instance{ AccountID: "42", CSRF: "test-csrf", Address: strings.TrimPrefix(server.URL, "https://"), client: server.Client(), healthy: true, lastUsed: time.Now(), } inst.SetModelMappingReady(true) pool := &Pool{ config: Config{ReplicasPerAccount: 1}, instances: map[string][]*Instance{"42": []*Instance{inst}}, } upstream := NewLSPoolUpstream(pool, &recordingUpstream{}) req1Body := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"session-a","contents":[{"role":"user","parts":[{"text":"hello"}]}]}}`) req1, err := http.NewRequest(http.MethodPost, "https://example.com/v1internal:streamGenerateContent?alt=sse", bytes.NewReader(req1Body)) require.NoError(t, err) req1.Header.Set("Authorization", "Bearer downstream-a") resp1, err := upstream.Do(req1, "", 42, 1) require.NoError(t, err) body1, err := io.ReadAll(resp1.Body) require.NoError(t, err) require.Contains(t, string(body1), `"text":"hello from ls"`) req2Body := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"session-a","contents":[{"role":"user","parts":[{"text":"hello"}]},{"role":"model","parts":[{"text":"hello from ls"}]},{"role":"user","parts":[{"text":"follow up"}]}]}}`) req2, err := http.NewRequest(http.MethodPost, "https://example.com/v1internal:streamGenerateContent?alt=sse", bytes.NewReader(req2Body)) require.NoError(t, err) req2.Header.Set("Authorization", "Bearer downstream-a") resp2, err := upstream.Do(req2, "", 42, 1) require.NoError(t, err) body2, err := io.ReadAll(resp2.Body) require.NoError(t, err) require.Contains(t, string(body2), 
`"text":"follow up from ls"`) require.Equal(t, int32(1), startCalls.Load(), "cascade should be reused for append-only transcript") require.Equal(t, int32(2), sendCalls.Load()) sendBodiesMu.Lock() require.Len(t, sendBodies, 2) firstSend := sendBodies[0] sendBodiesMu.Unlock() require.Equal(t, "cid-1", firstSend["cascadeId"]) require.Equal(t, false, firstSend["blocking"]) metadata, ok := firstSend["metadata"].(map[string]any) require.True(t, ok) require.Equal(t, "antigravity", metadata["ideName"]) require.Equal(t, "1.107.0", metadata["ideVersion"]) require.NotContains(t, firstSend, "clientType") require.NotContains(t, firstSend, "messageOrigin") cascadeConfig, ok := firstSend["cascadeConfig"].(map[string]any) require.True(t, ok) plannerConfig, ok := cascadeConfig["plannerConfig"].(map[string]any) require.True(t, ok) requestedModel, ok := plannerConfig["requestedModel"].(map[string]any) require.True(t, ok) require.NotEmpty(t, requestedModel["model"]) require.Len(t, plannerConfig, 1) require.Len(t, cascadeConfig, 1) } func TestJSParityRouteFallsBackOnSystemInstructionDrift(t *testing.T) { t.Setenv("ANTIGRAVITY_LS_STRATEGY", LSStrategyJSParity) var startCalls atomic.Int32 var sendCalls atomic.Int32 server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { require.Equal(t, "test-csrf", r.Header.Get("x-codeium-csrf-token")) switch { case strings.HasSuffix(r.URL.Path, "/StartCascade"): startCalls.Add(1) w.Header().Set("Content-Type", "application/json") _, _ = w.Write([]byte(`{"cascadeId":"cid-1"}`)) case strings.HasSuffix(r.URL.Path, "/SendUserCascadeMessage"): sendCalls.Add(1) w.Header().Set("Content-Type", "application/json") _, _ = w.Write([]byte(`{"queued":false}`)) case strings.HasSuffix(r.URL.Path, "/GetCascadeTrajectory"): w.Header().Set("Content-Type", "application/json") _, _ = 
w.Write([]byte(`{"status":"CASCADE_RUN_STATUS_IDLE","trajectory":{"steps":[{"type":"CORTEX_STEP_TYPE_PLANNER_RESPONSE","status":"CORTEX_STEP_STATUS_DONE","plannerResponse":{"response":"hello from ls"}}]}}`)) default: http.NotFound(w, r) } })) defer server.Close() inst := &Instance{ AccountID: "42", CSRF: "test-csrf", Address: strings.TrimPrefix(server.URL, "https://"), client: server.Client(), healthy: true, lastUsed: time.Now(), } inst.SetModelMappingReady(true) fallback := &recordingUpstream{} pool := &Pool{ config: Config{ReplicasPerAccount: 1}, instances: map[string][]*Instance{"42": []*Instance{inst}}, } upstream := NewLSPoolUpstream(pool, fallback) req1Body := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"session-a","systemInstruction":{"parts":[{"text":"You are helpful"}]},"contents":[{"role":"user","parts":[{"text":"hello"}]}]}}`) req1, err := http.NewRequest(http.MethodPost, "https://example.com/v1internal:streamGenerateContent?alt=sse", bytes.NewReader(req1Body)) require.NoError(t, err) req1.Header.Set("Authorization", "Bearer downstream-a") resp1, err := upstream.Do(req1, "", 42, 1) require.NoError(t, err) body1, err := io.ReadAll(resp1.Body) require.NoError(t, err) require.Contains(t, string(body1), `"text":"hello from ls"`) req2Body := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"session-a","systemInstruction":{"parts":[{"text":"You are different"}]},"contents":[{"role":"user","parts":[{"text":"hello"}]},{"role":"model","parts":[{"text":"hello from ls"}]},{"role":"user","parts":[{"text":"follow up"}]}]}}`) req2, err := http.NewRequest(http.MethodPost, "https://example.com/v1internal:streamGenerateContent?alt=sse", bytes.NewReader(req2Body)) require.NoError(t, err) req2.Header.Set("Authorization", "Bearer downstream-a") resp2, err := upstream.Do(req2, "", 42, 1) require.NoError(t, err) body2, err := io.ReadAll(resp2.Body) require.NoError(t, err) require.Equal(t, "ok", string(body2)) require.Equal(t, 1, fallback.doCalls) 
require.Equal(t, int32(1), startCalls.Load()) require.Equal(t, int32(1), sendCalls.Load()) } func TestJSParityRouteErrorsWhenModelMappingPending(t *testing.T) { t.Setenv("ANTIGRAVITY_LS_STRATEGY", LSStrategyJSParity) var startCalls atomic.Int32 var sendCalls atomic.Int32 server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { case strings.HasSuffix(r.URL.Path, "/StartCascade"): startCalls.Add(1) w.Header().Set("Content-Type", "application/json") _, _ = w.Write([]byte(`{"cascadeId":"cid-1"}`)) case strings.HasSuffix(r.URL.Path, "/SendUserCascadeMessage"): sendCalls.Add(1) w.Header().Set("Content-Type", "application/json") _, _ = w.Write([]byte(`{"queued":false}`)) default: http.NotFound(w, r) } })) defer server.Close() inst := &Instance{ AccountID: "42", CSRF: "test-csrf", Address: strings.TrimPrefix(server.URL, "https://"), client: server.Client(), healthy: true, lastUsed: time.Now(), } fallback := &recordingUpstream{} pool := &Pool{ config: Config{ReplicasPerAccount: 1}, instances: map[string][]*Instance{"42": []*Instance{inst}}, } upstream := NewLSPoolUpstream(pool, fallback) reqBody := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"session-a","contents":[{"role":"user","parts":[{"text":"hello"}]}]}}`) req, err := http.NewRequest(http.MethodPost, "https://example.com/v1internal:streamGenerateContent?alt=sse", bytes.NewReader(reqBody)) require.NoError(t, err) req.Header.Set("Authorization", "Bearer downstream-a") resp, err := upstream.Do(req, "", 42, 1) require.Nil(t, resp) require.ErrorIs(t, err, errLSModelMapPending) require.Equal(t, int32(0), startCalls.Load()) require.Equal(t, int32(0), sendCalls.Load()) require.Equal(t, 0, fallback.doCalls) } // recordingUpstreamWithCallback extends the base recordingUpstream with a callback type recordingUpstreamWithCallback struct { recordingUpstream onDo func(req *http.Request) } func (r *recordingUpstreamWithCallback) Do(req *http.Request, proxyURL string, 
accountID int64, c int) (*http.Response, error) { if r.onDo != nil { r.onDo(req) } return r.recordingUpstream.Do(req, proxyURL, accountID, c) }