sub2api/backend/internal/pkg/lspool/integration_test.go
win 6620b56b5a
Some checks failed
CI / test (push) Failing after 3s
CI / golangci-lint (push) Failing after 2s
Security Scan / backend-security (push) Failing after 3s
Security Scan / frontend-security (push) Failing after 3s
fix: encode ls model credits topic values as base64
2026-03-31 08:34:00 +08:00

865 lines
28 KiB
Go

package lspool
import (
"bytes"
"context"
"encoding/base64"
"encoding/binary"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// readConnectFrame reads a single Connect envelope from r and returns its
// payload. An envelope is a 1-byte flags field followed by a 4-byte
// big-endian payload length and then the payload itself; the flags byte is
// read but discarded. Any short read (including a clean EOF before the
// header) is returned as an error.
func readConnectFrame(r io.Reader) ([]byte, error) {
	var envelope [5]byte
	if _, err := io.ReadFull(r, envelope[:]); err != nil {
		return nil, err
	}
	body := make([]byte, binary.BigEndian.Uint32(envelope[1:]))
	if _, err := io.ReadFull(r, body); err != nil {
		return nil, err
	}
	return body, nil
}
// walkProtoBytesFieldsForTest iterates over the top-level fields of a
// protobuf wire-format message and calls visit for every length-delimited
// (wire type 2) field with its field number and a sub-slice of data holding
// the value (the sub-slice aliases data). Varint (0), fixed64 (1) and
// fixed32 (5) fields are skipped. Iteration stops when visit returns false,
// and stops silently on malformed input: a bad varint, a length that runs
// past the end of data, or an unsupported wire type (e.g. groups).
func walkProtoBytesFieldsForTest(data []byte, visit func(fieldNum int, value []byte) bool) {
	i := 0
	for i < len(data) {
		tag, n := binary.Uvarint(data[i:])
		if n <= 0 {
			return
		}
		i += n
		fieldNum := int(tag >> 3)
		switch tag & 0x7 {
		case 0:
			_, n = binary.Uvarint(data[i:])
			if n <= 0 {
				return
			}
			i += n
		case 1:
			i += 8
		case 2:
			length, n := binary.Uvarint(data[i:])
			if n <= 0 {
				return
			}
			i += n
			// Bound-check in uint64 space: converting a huge length to int
			// first can wrap negative, pass the check, and panic on slicing.
			if length > uint64(len(data)-i) {
				return
			}
			end := i + int(length)
			if !visit(fieldNum, data[i:end]) {
				return
			}
			i = end
		case 5:
			i += 4
		default:
			return
		}
	}
}

// decodeProtoBytesField returns the value of the first length-delimited field
// numbered targetField in data, or nil if there is none (or if malformed data
// is hit before it). The result aliases data; an empty field yields an empty,
// non-nil slice.
func decodeProtoBytesField(data []byte, targetField int) []byte {
	var found []byte
	walkProtoBytesFieldsForTest(data, func(fieldNum int, value []byte) bool {
		if fieldNum == targetField {
			found = value
			return false
		}
		return true
	})
	return found
}

// decodeProtoBytesFields returns copies of every length-delimited field
// numbered targetField in data, in wire order. Decoding stops at the first
// malformed field and returns whatever was collected so far. An empty field
// is recorded as a nil entry.
func decodeProtoBytesFields(data []byte, targetField int) [][]byte {
	var values [][]byte
	walkProtoBytesFieldsForTest(data, func(fieldNum int, value []byte) bool {
		if fieldNum == targetField {
			values = append(values, append([]byte(nil), value...))
		}
		return true
	})
	return values
}
// decodeTopicRows flattens an encoded USS topic into a key → value map.
// Every repeated field-1 entry of the topic is read as a pair: the entry's
// field 1 is the key, and its field 2 is a row message whose field 1 is the
// string value. Later entries overwrite earlier ones with the same key.
func decodeTopicRows(topic []byte) map[string]string {
	entries := decodeProtoBytesFields(topic, 1)
	rows := make(map[string]string, len(entries))
	for idx := range entries {
		rowMsg := decodeProtoBytesField(entries[idx], 2)
		rows[decodeProtoString(entries[idx], 1)] = decodeProtoString(rowMsg, 1)
	}
	return rows
}
// requireBase64PrimitiveValue fails the test unless got is valid standard
// (padded) base64 whose decoded bytes are exactly want.
func requireBase64PrimitiveValue(t *testing.T, got string, want []byte) {
	t.Helper()
	raw, decodeErr := base64.StdEncoding.DecodeString(got)
	require.NoError(t, decodeErr)
	require.Equal(t, want, raw)
}
// TestMockExtensionServerTokenInjection verifies the token injection flow:
// Extension → MockExtensionServer → LS subscribes uss-oauth → gets OAuthTokenInfo.
// It stores a token, checks it landed in the server's token map, then opens a
// uss-oauth subscription and checks that the first Connect envelope is a data
// frame.
func TestMockExtensionServerTokenInjection(t *testing.T) {
	csrf := "test-csrf-token"
	srv, err := NewMockExtensionServer(csrf)
	require.NoError(t, err)
	defer srv.Close()

	// 1. Set token for an account.
	srv.SetToken("account-1", &TokenInfo{
		AccessToken:  "ya29.test-access-token",
		RefreshToken: "1//test-refresh-token",
		ExpiresAt:    time.Now().Add(1 * time.Hour),
	})

	// 2. Verify token is stored.
	srv.mu.RLock()
	info, ok := srv.tokens["account-1"]
	srv.mu.RUnlock()
	require.True(t, ok)
	require.Equal(t, "ya29.test-access-token", info.AccessToken)
	require.Equal(t, "1//test-refresh-token", info.RefreshToken)
	require.False(t, info.ExpiresAt.IsZero())

	// 3. Simulate LS subscribing to uss-oauth. The subscription is a
	// long-lived stream, so the context bounds both the request and body reads.
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	endpoint := fmt.Sprintf("http://127.0.0.1:%d/exa.extension_server_pb.ExtensionServerService/SubscribeToUnifiedStateSyncTopic", srv.Port())
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint,
		bytes.NewReader(frameConnectMessage(encodeProtoString(1, "uss-oauth"))))
	require.NoError(t, err)
	req.Header.Set("x-codeium-csrf-token", csrf)
	req.Header.Set("Content-Type", "application/connect+proto")

	// A failed subscription must fail the test rather than being skipped.
	resp, err := http.DefaultClient.Do(req)
	require.NoError(t, err)
	defer resp.Body.Close()
	require.Equal(t, http.StatusOK, resp.StatusCode)
	require.Equal(t, "application/connect+proto", resp.Header.Get("Content-Type"))

	// Read the first envelope header (initial state). io.ReadFull guards
	// against the short reads a single Body.Read is allowed to return.
	header := make([]byte, 5)
	_, err = io.ReadFull(resp.Body, header)
	require.NoError(t, err)
	require.Equal(t, byte(0x00), header[0], "first byte should be 0x00 (data frame)")
	t.Logf("Received initial state frame: flags=%d, payload_len=%d", header[0], binary.BigEndian.Uint32(header[1:5]))
}
// TestMockExtensionServerCSRF verifies CSRF token validation: a Heartbeat
// call with the wrong token is rejected with 403, and one with the server's
// token is accepted with 200.
func TestMockExtensionServerCSRF(t *testing.T) {
	csrf := "correct-csrf"
	srv, err := NewMockExtensionServer(csrf)
	require.NoError(t, err)
	defer srv.Close()
	endpoint := fmt.Sprintf("http://127.0.0.1:%d/exa.extension_server_pb.ExtensionServerService/Heartbeat", srv.Port())

	// heartbeatStatus POSTs an empty Heartbeat with the given CSRF token and
	// returns the HTTP status. The body is closed before returning so the
	// connection can be reused by the next call.
	heartbeatStatus := func(token string) int {
		t.Helper()
		req, err := http.NewRequest(http.MethodPost, endpoint, nil)
		require.NoError(t, err)
		req.Header.Set("x-codeium-csrf-token", token)
		req.Header.Set("Content-Type", "application/proto")
		resp, err := http.DefaultClient.Do(req)
		require.NoError(t, err)
		defer resp.Body.Close()
		return resp.StatusCode
	}

	// Wrong CSRF → 403.
	require.Equal(t, http.StatusForbidden, heartbeatStatus("wrong-csrf"))
	// Correct CSRF → 200.
	require.Equal(t, http.StatusOK, heartbeatStatus(csrf))
}
// TestMockExtensionServerGetSecretValue exercises the GetSecretValue fallback
// token path. It only checks that an authenticated request with a stored
// token is accepted (HTTP 200).
// NOTE(review): the response payload is not inspected, so this does not yet
// prove the stored token is returned.
func TestMockExtensionServerGetSecretValue(t *testing.T) {
	csrf := "test-csrf"
	srv, err := NewMockExtensionServer(csrf)
	require.NoError(t, err)
	defer srv.Close()
	srv.SetToken("acc", &TokenInfo{AccessToken: "ya29.secret-token"})

	endpoint := fmt.Sprintf("http://127.0.0.1:%d/exa.extension_server_pb.ExtensionServerService/GetSecretValue", srv.Port())
	req, err := http.NewRequest(http.MethodPost, endpoint, nil)
	require.NoError(t, err)
	req.Header.Set("x-codeium-csrf-token", csrf)
	req.Header.Set("Content-Type", "application/proto")

	resp, err := http.DefaultClient.Do(req)
	require.NoError(t, err)
	defer resp.Body.Close()
	require.Equal(t, http.StatusOK, resp.StatusCode)
}
// TestOAuthTokenInfoProto checks that the OAuthTokenInfo encoding carries the
// access token, the "Bearer" token type and the refresh token, and that the
// refresh token is left out when it is empty.
func TestOAuthTokenInfoProto(t *testing.T) {
	expiry := time.Date(2026, 3, 29, 19, 0, 0, 0, time.UTC)

	encoded := buildOAuthTokenInfoBinary("ya29.test", "1//refresh", expiry)
	require.True(t, len(encoded) > 0, "proto should not be empty")

	// Expected fields: access_token (1), token_type (2), refresh_token (3).
	wire := string(encoded)
	for _, fragment := range []string{"ya29.test", "Bearer", "1//refresh"} {
		require.Contains(t, wire, fragment)
	}

	// An empty refresh token must not be emitted.
	withoutRefresh := buildOAuthTokenInfoBinary("ya29.test", "", expiry)
	require.NotContains(t, string(withoutRefresh), "1//refresh")
}
// TestOAuthTokenInfoWithRealExpiry verifies expiry uses real time, not a
// hardcoded value: two different expiry instants must encode differently,
// and a zero expiry must still produce a non-empty message.
func TestOAuthTokenInfoWithRealExpiry(t *testing.T) {
	future := time.Now().Add(2 * time.Hour)
	bin := buildOAuthTokenInfoBinary("token", "refresh", future)
	later := buildOAuthTokenInfoBinary("token", "refresh", future.Add(time.Hour))
	// Zero expiry is expected to fall back to a default (~1h).
	binZero := buildOAuthTokenInfoBinary("token", "refresh", time.Time{})

	require.NotEmpty(t, bin)
	require.NotEmpty(t, binZero)
	// Without this comparison a builder that ignored its expiry argument
	// would still pass.
	require.NotEqual(t, bin, later, "encoding should depend on the supplied expiry")
}
// TestUSSTopicWithOAuth checks that a USS topic built from OAuth credentials
// is non-empty and embeds the OAuth sentinel key.
func TestUSSTopicWithOAuth(t *testing.T) {
	topic := buildUSSTopicWithOAuth("ya29.access", "1//refresh", time.Now().Add(1*time.Hour))
	require.True(t, len(topic) > 0)

	// The sentinel key must appear somewhere in the encoded topic.
	require.Contains(t, string(topic), "oauthTokenInfoSentinelKey")
}
// TestUSSTopicWithModelCredits checks that the model-credits topic contains
// one row per credits setting and that each row value is the base64 encoding
// of the matching primitive binary.
func TestUSSTopicWithModelCredits(t *testing.T) {
	available, minimum := int32(123), int32(50)
	topic := buildUSSTopicWithModelCredits(&ModelCreditsInfo{
		UseAICredits:                true,
		AvailableCredits:            &available,
		MinimumCreditAmountForUsage: &minimum,
	})
	require.True(t, len(topic) > 0)

	expected := []struct {
		key   string
		value []byte
	}{
		{useAICreditsSentinelKey, buildPrimitiveBoolBinary(true)},
		{availableCreditsSentinelKey, buildPrimitiveInt32Binary(available)},
		{minimumCreditAmountForUsageKey, buildPrimitiveInt32Binary(minimum)},
	}

	wire := string(topic)
	for _, row := range expected {
		require.Contains(t, wire, row.key)
	}

	decoded := decodeTopicRows(topic)
	for _, row := range expected {
		requireBase64PrimitiveValue(t, decoded[row.key], row.value)
	}
}
// TestMockExtensionServerModelCreditsDynamicUpdate subscribes to
// uss-modelCredits, changes the account's credits after the initial state
// has been received, and checks that the pushed updates carry all three
// credit rows with base64-encoded primitive values.
func TestMockExtensionServerModelCreditsDynamicUpdate(t *testing.T) {
	csrf := "test-csrf-token"
	srv, err := NewMockExtensionServer(csrf)
	require.NoError(t, err)
	defer srv.Close()
	srv.SetModelCredits("account-1", &ModelCreditsInfo{})

	// The subscription is a long-lived stream; the context bounds the request
	// and every body read below, so a missing update fails instead of hanging.
	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()
	endpoint := fmt.Sprintf("http://127.0.0.1:%d/exa.extension_server_pb.ExtensionServerService/SubscribeToUnifiedStateSyncTopic", srv.Port())
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint,
		bytes.NewReader(frameConnectMessage(encodeProtoString(1, "uss-modelCredits"))))
	require.NoError(t, err)
	req.Header.Set("x-codeium-csrf-token", csrf)
	req.Header.Set("Content-Type", "application/connect+proto")

	resp, err := http.DefaultClient.Do(req)
	require.NoError(t, err)
	defer resp.Body.Close()
	require.Equal(t, http.StatusOK, resp.StatusCode)

	// Drain the initial_state frame first.
	_, err = readConnectFrame(resp.Body)
	require.NoError(t, err)

	available := int32(77)
	minimum := int32(25)
	srv.SetModelCredits("account-1", &ModelCreditsInfo{
		UseAICredits:                true,
		AvailableCredits:            &available,
		MinimumCreditAmountForUsage: &minimum,
	})

	// Each following frame is read as an applied update in field 2, holding
	// a key (field 1) and a row (field 2) whose field 1 is the value.
	values := make(map[string]string, 3)
	for len(values) < 3 {
		frame, readErr := readConnectFrame(resp.Body)
		require.NoError(t, readErr)
		applied := decodeProtoBytesField(frame, 2)
		require.NotEmpty(t, applied)
		key := decodeProtoString(applied, 1)
		row := decodeProtoBytesField(applied, 2)
		values[key] = decodeProtoString(row, 1)
	}

	require.Contains(t, values, useAICreditsSentinelKey)
	require.Contains(t, values, availableCreditsSentinelKey)
	require.Contains(t, values, minimumCreditAmountForUsageKey)
	requireBase64PrimitiveValue(t, values[useAICreditsSentinelKey], buildPrimitiveBoolBinary(true))
	requireBase64PrimitiveValue(t, values[availableCreditsSentinelKey], buildPrimitiveInt32Binary(available))
	requireBase64PrimitiveValue(t, values[minimumCreditAmountForUsageKey], buildPrimitiveInt32Binary(minimum))
}
// TestBuildInitialStateUpdate verifies the USS update wrapper
func TestBuildInitialStateUpdate(t *testing.T) {
topicData := buildEmptyTopic()
update := buildInitialStateUpdate(topicData)
// Should be a valid proto bytes field (field 1 = initial_state)
require.True(t, len(update) >= 0) // empty topic is valid
topicData2 := buildUSSTopicWithOAuth("token", "refresh", time.Now().Add(1*time.Hour))
update2 := buildInitialStateUpdate(topicData2)
require.True(t, len(update2) > len(update), "non-empty topic should produce larger update")
}
// TestPoolSetAccountTokenComplete checks that Pool.SetAccountToken forwards
// the access token, refresh token and expiry to the extension server's token
// store.
func TestPoolSetAccountTokenComplete(t *testing.T) {
	srv, err := NewMockExtensionServer("pool-csrf")
	require.NoError(t, err)
	defer srv.Close()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	pool := &Pool{
		config:    DefaultConfig(),
		instances: map[string][]*Instance{},
		extServer: srv,
		ctx:       ctx,
		cancel:    cancel,
	}

	wantExpiry := time.Now().Add(1 * time.Hour)
	pool.SetAccountToken("acc-1", "ya29.full-token", "1//full-refresh", wantExpiry)

	srv.mu.RLock()
	stored := srv.tokens["acc-1"]
	srv.mu.RUnlock()

	require.NotNil(t, stored)
	require.Equal(t, "ya29.full-token", stored.AccessToken)
	require.Equal(t, "1//full-refresh", stored.RefreshToken)
	require.False(t, stored.ExpiresAt.IsZero())
	require.WithinDuration(t, wantExpiry, stored.ExpiresAt, time.Second)
}
// TestPoolSetAccountModelCreditsComplete checks that
// Pool.SetAccountModelCredits stores the use-credits flag and both credit
// amounts in the extension server's credits map.
func TestPoolSetAccountModelCreditsComplete(t *testing.T) {
	srv, err := NewMockExtensionServer("pool-csrf")
	require.NoError(t, err)
	defer srv.Close()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	pool := &Pool{
		config:    DefaultConfig(),
		instances: map[string][]*Instance{},
		extServer: srv,
		ctx:       ctx,
		cancel:    cancel,
	}

	wantAvailable, wantMinimum := int32(77), int32(25)
	pool.SetAccountModelCredits("acc-1", true, &wantAvailable, &wantMinimum)

	srv.mu.RLock()
	stored := srv.credits["acc-1"]
	srv.mu.RUnlock()

	require.NotNil(t, stored)
	require.True(t, stored.UseAICredits)
	require.NotNil(t, stored.AvailableCredits)
	require.Equal(t, wantAvailable, *stored.AvailableCredits)
	require.NotNil(t, stored.MinimumCreditAmountForUsage)
	require.Equal(t, wantMinimum, *stored.MinimumCreditAmountForUsage)
}
// TestUpstreamAdapterExtractsCredentials verifies internal LS headers are
// extracted into the mock extension server (keyed by account ID) and
// stripped before the request reaches the direct (fallback) upstream.
func TestUpstreamAdapterExtractsCredentials(t *testing.T) {
	// Record what the fallback upstream receives.
	var receivedHeaders http.Header
	var mu sync.Mutex
	fallback := &recordingUpstreamWithCallback{}
	fallback.onDo = func(req *http.Request) {
		mu.Lock()
		receivedHeaders = req.Header.Clone()
		mu.Unlock()
	}

	csrf := "test-csrf"
	srv, err := NewMockExtensionServer(csrf)
	require.NoError(t, err)
	defer srv.Close()

	pool := &Pool{
		config:    DefaultConfig(),
		instances: make(map[string][]*Instance),
		extServer: srv,
	}
	upstream := NewLSPoolUpstream(pool, fallback)

	// Non-streamGenerateContent request → should pass through to fallback.
	req, err := http.NewRequest(http.MethodPost, "https://example.com/v1beta/models/gemini:generateContent", nil)
	require.NoError(t, err)
	req.Header.Set("Authorization", "Bearer ya29.test")
	req.Header.Set("X-Antigravity-Refresh-Token", "1//secret-refresh")
	req.Header.Set("X-Antigravity-Token-Expiry", "2026-03-29T19:00:00Z")
	req.Header.Set(useAICreditsHeader, "true")
	req.Header.Set(availableCreditsHeader, "42")
	req.Header.Set(minimumCreditAmountHeader, "50")

	resp, err := upstream.Do(req, "", 1, 1)
	require.NoError(t, err)
	require.NotNil(t, resp)
	defer resp.Body.Close()

	// Snapshot under the lock, then assert without holding it.
	mu.Lock()
	forwarded := receivedHeaders
	mu.Unlock()
	// Without this check, a request that never reached the fallback would
	// make every "header absent" assertion below pass vacuously.
	require.NotNil(t, forwarded, "request should be forwarded to the fallback upstream")

	// Internal headers should never leak to the direct upstream.
	for _, name := range []string{
		"X-Antigravity-Refresh-Token",
		"X-Antigravity-Token-Expiry",
		useAICreditsHeader,
		availableCreditsHeader,
		minimumCreditAmountHeader,
	} {
		require.Empty(t, forwarded.Get(name), "header %s leaked to fallback upstream", name)
	}

	srv.mu.RLock()
	tokenInfo := srv.tokens["1"]
	creditsInfo := srv.credits["1"]
	srv.mu.RUnlock()

	require.NotNil(t, tokenInfo)
	require.Equal(t, "ya29.test", tokenInfo.AccessToken)
	require.NotNil(t, creditsInfo)
	require.True(t, creditsInfo.UseAICredits)
	require.NotNil(t, creditsInfo.AvailableCredits)
	require.Equal(t, int32(42), *creditsInfo.AvailableCredits)
	require.NotNil(t, creditsInfo.MinimumCreditAmountForUsage)
	require.Equal(t, int32(50), *creditsInfo.MinimumCreditAmountForUsage)
}
// TestExtractPromptAndModelMultiTurn checks that extractPromptAndModel
// returns the wrapped request's model and a prompt containing the system
// instruction plus every user and model turn.
func TestExtractPromptAndModelMultiTurn(t *testing.T) {
	const payload = `{
"model": "claude-sonnet-4-6",
"request": {
"systemInstruction": {"parts": [{"text": "You are helpful"}]},
"contents": [
{"role": "user", "parts": [{"text": "Hello"}]},
{"role": "model", "parts": [{"text": "Hi there!"}]},
{"role": "user", "parts": [{"text": "How are you?"}]}
]
}
}`

	prompt, model := extractPromptAndModel([]byte(payload))
	require.Equal(t, "claude-sonnet-4-6", model)

	for _, fragment := range []string{"You are helpful", "Hello", "Hi there!", "How are you?"} {
		require.Contains(t, prompt, fragment)
	}
}
// TestExtractUsageFromTrajectory checks that string-encoded model usage in a
// trajectory step is converted to Gemini usage counts, with the total being
// the sum of input and output tokens.
func TestExtractUsageFromTrajectory(t *testing.T) {
	const trajectoryJSON = `{
"trajectory": {
"steps": [{
"type": "CORTEX_STEP_TYPE_PLANNER_RESPONSE",
"status": "CORTEX_STEP_STATUS_DONE",
"plannerResponse": {"response": "OK"},
"metadata": {
"modelUsage": {
"inputTokens": "150",
"outputTokens": "5"
}
}
}]
}
}`

	usage := extractUsageFromTrajectory([]byte(trajectoryJSON))
	require.NotNil(t, usage)

	expected := []struct {
		key   string
		count int
	}{
		{"promptTokenCount", 150},
		{"candidatesTokenCount", 5},
		{"totalTokenCount", 155},
	}
	for _, e := range expected {
		require.Equal(t, e.count, usage[e.key])
	}
}
// TestSSEChunkFormat verifies the Gemini SSE output format: a single
// "data: " event terminated by a blank line whose JSON payload contains one
// candidate with the model text.
func TestSSEChunkFormat(t *testing.T) {
	chunk := buildGeminiSSEChunk("Hello world")
	require.NotEmpty(t, chunk)
	require.True(t, strings.HasPrefix(chunk, "data: "), "chunk should start with the SSE data prefix")
	require.Contains(t, chunk, `"text":"Hello world"`)
	require.Contains(t, chunk, `"role":"model"`)
	// HasSuffix is safe for short strings; slicing len(chunk)-2 would panic.
	require.True(t, strings.HasSuffix(chunk, "\n\n"), "SSE event should end with a blank line")

	// Verify it's valid JSON after stripping the SSE framing.
	jsonStr := strings.TrimSuffix(strings.TrimPrefix(chunk, "data: "), "\n\n")
	var parsed map[string]any
	require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed))

	// Checked assertions report a test failure instead of panicking.
	response, ok := parsed["response"].(map[string]any)
	require.True(t, ok, "response should be a JSON object")
	candidates, ok := response["candidates"].([]any)
	require.True(t, ok, "response.candidates should be a JSON array")
	require.Len(t, candidates, 1)
}
// TestSSEFinalChunkFormat checks that the closing SSE chunk is a data event
// carrying finishReason STOP and a usageMetadata object.
func TestSSEFinalChunkFormat(t *testing.T) {
	chunk := buildGeminiSSEFinalChunk(map[string]any{
		"promptTokenCount":     100,
		"candidatesTokenCount": 50,
		"totalTokenCount":      150,
	})

	for _, fragment := range []string{"data: ", `"finishReason":"STOP"`, `"usageMetadata"`} {
		require.Contains(t, chunk, fragment)
	}
}
func TestStreamCascadeResponsePollsImmediately(t *testing.T) {
var getCalls atomic.Int32
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "test-csrf", r.Header.Get("x-codeium-csrf-token"))
if strings.HasSuffix(r.URL.Path, "/GetCascadeTrajectory") {
getCalls.Add(1)
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"status":"CASCADE_RUN_STATUS_IDLE","trajectory":{"steps":[{"type":"CORTEX_STEP_TYPE_PLANNER_RESPONSE","status":"CORTEX_STEP_STATUS_DONE","plannerResponse":{"response":"hello from ls"}}]}}`))
return
}
http.NotFound(w, r)
}))
defer server.Close()
inst := &Instance{
AccountID: "42",
CSRF: "test-csrf",
Address: strings.TrimPrefix(server.URL, "https://"),
client: server.Client(),
healthy: true,
lastUsed: time.Now(),
}
upstream := NewLSPoolUpstream(&Pool{}, &recordingUpstream{})
pr, pw := io.Pipe()
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
done := make(chan struct{})
go func() {
upstream.streamCascadeResponse(ctx, inst, "cid-1", pw, nil, nil)
close(done)
}()
body, err := io.ReadAll(pr)
require.NoError(t, err)
<-done
require.GreaterOrEqual(t, getCalls.Load(), int32(1))
require.Contains(t, string(body), "hello from ls")
}
// TestRequestHasToolsEdgeCases covers tool detection corner cases: explicit
// null tools, a tools entry with no function declarations, and tools nested
// inside the wrapped {"request": …} envelope.
func TestRequestHasToolsEdgeCases(t *testing.T) {
	cases := []struct {
		name string
		body string
		want bool
	}{
		{"null tools", `{"contents":[],"tools":null}`, false},
		{"empty function declarations", `{"contents":[],"tools":[{"functionDeclarations":[]}]}`, true},
		{"wrapped request format", `{"model":"m","project":"p","request":{"contents":[],"tools":[{"codeExecution":{}}]}}`, true},
	}
	for _, tc := range cases {
		require.Equal(t, tc.want, requestHasTools([]byte(tc.body)), tc.name)
	}
}
// TestJSParityRouteReusesCascadeSession checks that, under the JS-parity
// strategy, an append-only follow-up in the same session reuses the existing
// cascade (one StartCascade, two SendUserCascadeMessage calls). It also checks
// the shape of the first send payload.
func TestJSParityRouteReusesCascadeSession(t *testing.T) {
	t.Setenv("ANTIGRAVITY_LS_STRATEGY", LSStrategyJSParity)
	var startCalls atomic.Int32
	var sendCalls atomic.Int32
	var getCalls atomic.Int32
	var sendBodiesMu sync.Mutex
	var sendBodies []map[string]any
	// The handler runs on a server goroutine, where require's FailNow is not
	// allowed; report problems with t.Errorf and fail the HTTP request instead.
	server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if got := r.Header.Get("x-codeium-csrf-token"); got != "test-csrf" {
			t.Errorf("unexpected csrf token %q", got)
			http.Error(w, "bad csrf token", http.StatusForbidden)
			return
		}
		switch {
		case strings.HasSuffix(r.URL.Path, "/StartCascade"):
			startCalls.Add(1)
			w.Header().Set("Content-Type", "application/json")
			_, _ = w.Write([]byte(`{"cascadeId":"cid-1"}`))
		case strings.HasSuffix(r.URL.Path, "/SendUserCascadeMessage"):
			sendCalls.Add(1)
			var payload map[string]any
			if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
				t.Errorf("decode SendUserCascadeMessage body: %v", err)
				http.Error(w, err.Error(), http.StatusBadRequest)
				return
			}
			sendBodiesMu.Lock()
			sendBodies = append(sendBodies, payload)
			sendBodiesMu.Unlock()
			w.Header().Set("Content-Type", "application/json")
			_, _ = w.Write([]byte(`{"queued":false}`))
		case strings.HasSuffix(r.URL.Path, "/GetCascadeTrajectory"):
			// First poll answers the opening turn; later polls the follow-up.
			call := getCalls.Add(1)
			w.Header().Set("Content-Type", "application/json")
			text := "hello from ls"
			if call > 1 {
				text = "follow up from ls"
			}
			_, _ = fmt.Fprintf(w, `{"status":"CASCADE_RUN_STATUS_IDLE","trajectory":{"steps":[{"type":"CORTEX_STEP_TYPE_PLANNER_RESPONSE","status":"CORTEX_STEP_STATUS_DONE","plannerResponse":{"response":"%s"}}]}}`, text)
		default:
			http.NotFound(w, r)
		}
	}))
	defer server.Close()

	inst := &Instance{
		AccountID: "42",
		CSRF:      "test-csrf",
		Address:   strings.TrimPrefix(server.URL, "https://"),
		client:    server.Client(),
		healthy:   true,
		lastUsed:  time.Now(),
	}
	inst.SetModelMappingReady(true)
	pool := &Pool{
		config:    Config{ReplicasPerAccount: 1},
		instances: map[string][]*Instance{"42": {inst}},
	}
	upstream := NewLSPoolUpstream(pool, &recordingUpstream{})

	req1Body := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"session-a","contents":[{"role":"user","parts":[{"text":"hello"}]}]}}`)
	req1, err := http.NewRequest(http.MethodPost, "https://example.com/v1internal:streamGenerateContent?alt=sse", bytes.NewReader(req1Body))
	require.NoError(t, err)
	req1.Header.Set("Authorization", "Bearer downstream-a")
	resp1, err := upstream.Do(req1, "", 42, 1)
	require.NoError(t, err)
	defer resp1.Body.Close()
	body1, err := io.ReadAll(resp1.Body)
	require.NoError(t, err)
	require.Contains(t, string(body1), `"text":"hello from ls"`)

	req2Body := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"session-a","contents":[{"role":"user","parts":[{"text":"hello"}]},{"role":"model","parts":[{"text":"hello from ls"}]},{"role":"user","parts":[{"text":"follow up"}]}]}}`)
	req2, err := http.NewRequest(http.MethodPost, "https://example.com/v1internal:streamGenerateContent?alt=sse", bytes.NewReader(req2Body))
	require.NoError(t, err)
	req2.Header.Set("Authorization", "Bearer downstream-a")
	resp2, err := upstream.Do(req2, "", 42, 1)
	require.NoError(t, err)
	defer resp2.Body.Close()
	body2, err := io.ReadAll(resp2.Body)
	require.NoError(t, err)
	require.Contains(t, string(body2), `"text":"follow up from ls"`)

	require.Equal(t, int32(1), startCalls.Load(), "cascade should be reused for append-only transcript")
	require.Equal(t, int32(2), sendCalls.Load())

	sendBodiesMu.Lock()
	require.Len(t, sendBodies, 2)
	firstSend := sendBodies[0]
	sendBodiesMu.Unlock()

	require.Equal(t, "cid-1", firstSend["cascadeId"])
	require.Equal(t, false, firstSend["blocking"])
	metadata, ok := firstSend["metadata"].(map[string]any)
	require.True(t, ok)
	require.Equal(t, "antigravity", metadata["ideName"])
	require.Equal(t, "1.107.0", metadata["ideVersion"])
	require.NotContains(t, firstSend, "clientType")
	require.NotContains(t, firstSend, "messageOrigin")
	cascadeConfig, ok := firstSend["cascadeConfig"].(map[string]any)
	require.True(t, ok)
	plannerConfig, ok := cascadeConfig["plannerConfig"].(map[string]any)
	require.True(t, ok)
	requestedModel, ok := plannerConfig["requestedModel"].(map[string]any)
	require.True(t, ok)
	require.NotEmpty(t, requestedModel["model"])
	require.Len(t, plannerConfig, 1)
	require.Len(t, cascadeConfig, 1)
}
// TestJSParityRouteFallsBackOnSystemInstructionDrift checks that when the
// system instruction changes between turns of the same session, the JS-parity
// route does not reuse the cascade. It instead hands the request to the
// fallback upstream, leaving LS call counts at one start and one send.
func TestJSParityRouteFallsBackOnSystemInstructionDrift(t *testing.T) {
	t.Setenv("ANTIGRAVITY_LS_STRATEGY", LSStrategyJSParity)
	var startCalls atomic.Int32
	var sendCalls atomic.Int32
	// The handler runs on a server goroutine, where require's FailNow is not
	// allowed; report problems with t.Errorf and fail the HTTP request instead.
	server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if got := r.Header.Get("x-codeium-csrf-token"); got != "test-csrf" {
			t.Errorf("unexpected csrf token %q", got)
			http.Error(w, "bad csrf token", http.StatusForbidden)
			return
		}
		switch {
		case strings.HasSuffix(r.URL.Path, "/StartCascade"):
			startCalls.Add(1)
			w.Header().Set("Content-Type", "application/json")
			_, _ = w.Write([]byte(`{"cascadeId":"cid-1"}`))
		case strings.HasSuffix(r.URL.Path, "/SendUserCascadeMessage"):
			sendCalls.Add(1)
			w.Header().Set("Content-Type", "application/json")
			_, _ = w.Write([]byte(`{"queued":false}`))
		case strings.HasSuffix(r.URL.Path, "/GetCascadeTrajectory"):
			w.Header().Set("Content-Type", "application/json")
			_, _ = w.Write([]byte(`{"status":"CASCADE_RUN_STATUS_IDLE","trajectory":{"steps":[{"type":"CORTEX_STEP_TYPE_PLANNER_RESPONSE","status":"CORTEX_STEP_STATUS_DONE","plannerResponse":{"response":"hello from ls"}}]}}`))
		default:
			http.NotFound(w, r)
		}
	}))
	defer server.Close()

	inst := &Instance{
		AccountID: "42",
		CSRF:      "test-csrf",
		Address:   strings.TrimPrefix(server.URL, "https://"),
		client:    server.Client(),
		healthy:   true,
		lastUsed:  time.Now(),
	}
	inst.SetModelMappingReady(true)
	fallback := &recordingUpstream{}
	pool := &Pool{
		config:    Config{ReplicasPerAccount: 1},
		instances: map[string][]*Instance{"42": {inst}},
	}
	upstream := NewLSPoolUpstream(pool, fallback)

	req1Body := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"session-a","systemInstruction":{"parts":[{"text":"You are helpful"}]},"contents":[{"role":"user","parts":[{"text":"hello"}]}]}}`)
	req1, err := http.NewRequest(http.MethodPost, "https://example.com/v1internal:streamGenerateContent?alt=sse", bytes.NewReader(req1Body))
	require.NoError(t, err)
	req1.Header.Set("Authorization", "Bearer downstream-a")
	resp1, err := upstream.Do(req1, "", 42, 1)
	require.NoError(t, err)
	defer resp1.Body.Close()
	body1, err := io.ReadAll(resp1.Body)
	require.NoError(t, err)
	require.Contains(t, string(body1), `"text":"hello from ls"`)

	// Same session, different system instruction → must not reuse the cascade.
	req2Body := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"session-a","systemInstruction":{"parts":[{"text":"You are different"}]},"contents":[{"role":"user","parts":[{"text":"hello"}]},{"role":"model","parts":[{"text":"hello from ls"}]},{"role":"user","parts":[{"text":"follow up"}]}]}}`)
	req2, err := http.NewRequest(http.MethodPost, "https://example.com/v1internal:streamGenerateContent?alt=sse", bytes.NewReader(req2Body))
	require.NoError(t, err)
	req2.Header.Set("Authorization", "Bearer downstream-a")
	resp2, err := upstream.Do(req2, "", 42, 1)
	require.NoError(t, err)
	defer resp2.Body.Close()
	body2, err := io.ReadAll(resp2.Body)
	require.NoError(t, err)

	require.Equal(t, "ok", string(body2))
	require.Equal(t, 1, fallback.doCalls)
	require.Equal(t, int32(1), startCalls.Load())
	require.Equal(t, int32(1), sendCalls.Load())
}
// TestJSParityRouteErrorsWhenModelMappingPending checks that when the
// instance's model mapping is not ready, the JS-parity route returns
// errLSModelMapPending without contacting the LS or the fallback upstream.
func TestJSParityRouteErrorsWhenModelMappingPending(t *testing.T) {
	t.Setenv("ANTIGRAVITY_LS_STRATEGY", LSStrategyJSParity)
	var startCalls, sendCalls atomic.Int32

	server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var reply string
		switch {
		case strings.HasSuffix(r.URL.Path, "/StartCascade"):
			startCalls.Add(1)
			reply = `{"cascadeId":"cid-1"}`
		case strings.HasSuffix(r.URL.Path, "/SendUserCascadeMessage"):
			sendCalls.Add(1)
			reply = `{"queued":false}`
		default:
			http.NotFound(w, r)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write([]byte(reply))
	}))
	defer server.Close()

	// Note: SetModelMappingReady is deliberately not called on this instance.
	inst := &Instance{
		AccountID: "42",
		CSRF:      "test-csrf",
		Address:   strings.TrimPrefix(server.URL, "https://"),
		client:    server.Client(),
		healthy:   true,
		lastUsed:  time.Now(),
	}
	fallback := &recordingUpstream{}
	upstream := NewLSPoolUpstream(&Pool{
		config:    Config{ReplicasPerAccount: 1},
		instances: map[string][]*Instance{"42": {inst}},
	}, fallback)

	reqBody := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"session-a","contents":[{"role":"user","parts":[{"text":"hello"}]}]}}`)
	req, err := http.NewRequest(http.MethodPost, "https://example.com/v1internal:streamGenerateContent?alt=sse", bytes.NewReader(reqBody))
	require.NoError(t, err)
	req.Header.Set("Authorization", "Bearer downstream-a")

	resp, err := upstream.Do(req, "", 42, 1)
	require.Nil(t, resp)
	require.ErrorIs(t, err, errLSModelMapPending)
	require.Equal(t, int32(0), startCalls.Load())
	require.Equal(t, int32(0), sendCalls.Load())
	require.Equal(t, 0, fallback.doCalls)
}
// recordingUpstreamWithCallback wraps recordingUpstream so a test can observe
// each request passed to Do before it is delegated to the embedded recorder.
type recordingUpstreamWithCallback struct {
	recordingUpstream
	// onDo, when non-nil, is invoked with every request handed to Do.
	onDo func(req *http.Request)
}

// Do runs the onDo hook (if set) and then forwards to recordingUpstream.Do
// with the same arguments, returning its result unchanged.
func (r *recordingUpstreamWithCallback) Do(req *http.Request, proxyURL string, accountID int64, c int) (*http.Response, error) {
	if hook := r.onDo; hook != nil {
		hook(req)
	}
	return r.recordingUpstream.Do(req, proxyURL, accountID, c)
}