865 lines
28 KiB
Go
865 lines
28 KiB
Go
package lspool
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/base64"
|
|
"encoding/binary"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// readConnectFrame reads one Connect streaming envelope from r and returns
// its payload. The envelope is a 5-byte prefix — one flags byte followed by a
// big-endian uint32 payload length — and then the payload itself. The flags
// byte is consumed but not returned. Any read error (including a short read)
// is returned unchanged from io.ReadFull.
func readConnectFrame(r io.Reader) ([]byte, error) {
	var prefix [5]byte
	if _, err := io.ReadFull(r, prefix[:]); err != nil {
		return nil, err
	}

	body := make([]byte, binary.BigEndian.Uint32(prefix[1:]))
	if _, err := io.ReadFull(r, body); err != nil {
		return nil, err
	}
	return body, nil
}
|
|
|
|
// decodeProtoBytesField scans a protobuf wire-format message and returns the
// payload of the first length-delimited (wire type 2) field whose number is
// targetField. The returned slice aliases data; it is not copied.
//
// It returns nil when the field is absent, when the message is malformed
// (truncated varint or a length running past the end of data), or when an
// unsupported wire type such as a group (3/4) is encountered first.
func decodeProtoBytesField(data []byte, targetField int) []byte {
	for i := 0; i < len(data); {
		tag, n := binary.Uvarint(data[i:])
		if n <= 0 {
			return nil
		}
		i += n
		fieldNum := int(tag >> 3)

		switch tag & 0x7 {
		case 0: // varint: skip the value
			if _, n = binary.Uvarint(data[i:]); n <= 0 {
				return nil
			}
			i += n
		case 1: // fixed64
			i += 8
		case 2: // length-delimited
			length, n := binary.Uvarint(data[i:])
			if n <= 0 {
				return nil
			}
			i += n
			// Compare in uint64 space. Converting a huge length to int first
			// can wrap negative, slip past the bounds check, and panic when
			// slicing.
			if length > uint64(len(data)-i) {
				return nil
			}
			end := i + int(length)
			if fieldNum == targetField {
				return data[i:end]
			}
			i = end
		case 5: // fixed32
			i += 4
		default:
			return nil
		}
	}
	return nil
}
|
|
|
|
// decodeProtoBytesFields scans a protobuf wire-format message and returns a
// copy of the payload of every length-delimited (wire type 2) field whose
// number is targetField, in wire order (i.e. a repeated bytes/message field).
//
// Parsing stops at the first malformed entry (truncated varint or a length
// running past the end of data) or unsupported wire type. It then returns
// whatever was collected before that point; a nil result means nothing
// matched.
func decodeProtoBytesFields(data []byte, targetField int) [][]byte {
	var values [][]byte
	for i := 0; i < len(data); {
		tag, n := binary.Uvarint(data[i:])
		if n <= 0 {
			return values
		}
		i += n
		fieldNum := int(tag >> 3)

		switch tag & 0x7 {
		case 0: // varint: skip the value
			if _, n = binary.Uvarint(data[i:]); n <= 0 {
				return values
			}
			i += n
		case 1: // fixed64
			i += 8
		case 2: // length-delimited
			length, n := binary.Uvarint(data[i:])
			if n <= 0 {
				return values
			}
			i += n
			// Compare in uint64 space. Converting a huge length to int first
			// can wrap negative, slip past the bounds check, and panic when
			// slicing.
			if length > uint64(len(data)-i) {
				return values
			}
			end := i + int(length)
			if fieldNum == targetField {
				// Copy so the results never alias (or pin) the input buffer.
				values = append(values, append([]byte(nil), data[i:end]...))
			}
			i = end
		case 5: // fixed32
			i += 4
		default:
			return values
		}
	}
	return values
}
|
|
|
|
func decodeTopicRows(topic []byte) map[string]string {
|
|
rows := make(map[string]string)
|
|
for _, entry := range decodeProtoBytesFields(topic, 1) {
|
|
key := decodeProtoString(entry, 1)
|
|
row := decodeProtoBytesField(entry, 2)
|
|
rows[key] = decodeProtoString(row, 1)
|
|
}
|
|
return rows
|
|
}
|
|
|
|
func requireBase64PrimitiveValue(t *testing.T, got string, want []byte) {
|
|
t.Helper()
|
|
decoded, err := base64.StdEncoding.DecodeString(got)
|
|
require.NoError(t, err)
|
|
require.Equal(t, want, decoded)
|
|
}
|
|
|
|
// TestMockExtensionServerTokenInjection verifies the token injection flow:
|
|
// Extension → MockExtensionServer → LS subscribes uss-oauth → gets OAuthTokenInfo
|
|
func TestMockExtensionServerTokenInjection(t *testing.T) {
|
|
csrf := "test-csrf-token"
|
|
srv, err := NewMockExtensionServer(csrf)
|
|
require.NoError(t, err)
|
|
defer srv.Close()
|
|
|
|
// 1. Set token for an account
|
|
srv.SetToken("account-1", &TokenInfo{
|
|
AccessToken: "ya29.test-access-token",
|
|
RefreshToken: "1//test-refresh-token",
|
|
ExpiresAt: time.Now().Add(1 * time.Hour),
|
|
})
|
|
|
|
// 2. Verify token is stored
|
|
srv.mu.RLock()
|
|
info, ok := srv.tokens["account-1"]
|
|
srv.mu.RUnlock()
|
|
require.True(t, ok)
|
|
require.Equal(t, "ya29.test-access-token", info.AccessToken)
|
|
require.Equal(t, "1//test-refresh-token", info.RefreshToken)
|
|
require.False(t, info.ExpiresAt.IsZero())
|
|
|
|
// 3. Simulate LS subscribing to uss-oauth (HTTP request to mock server)
|
|
req, _ := http.NewRequest("POST",
|
|
fmt.Sprintf("http://127.0.0.1:%d/exa.extension_server_pb.ExtensionServerService/SubscribeToUnifiedStateSyncTopic", srv.Port()),
|
|
bytes.NewReader(frameConnectMessage(encodeProtoString(1, "uss-oauth"))))
|
|
req.Header.Set("x-codeium-csrf-token", csrf)
|
|
req.Header.Set("Content-Type", "application/connect+proto")
|
|
|
|
// The stream handler will block, so run in background and cancel after we confirm connection
|
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
|
defer cancel()
|
|
req = req.WithContext(ctx)
|
|
|
|
client := &http.Client{}
|
|
resp, err := client.Do(req)
|
|
if err == nil {
|
|
defer resp.Body.Close()
|
|
require.Equal(t, 200, resp.StatusCode)
|
|
require.Equal(t, "application/connect+proto", resp.Header.Get("Content-Type"))
|
|
|
|
// Read the first envelope frame (initial state)
|
|
header := make([]byte, 5)
|
|
n, readErr := resp.Body.Read(header)
|
|
if readErr == nil && n == 5 {
|
|
require.Equal(t, byte(0x00), header[0], "first byte should be 0x00 (data frame)")
|
|
t.Logf("Received initial state frame: flags=%d, payload_len=%d", header[0], header[1:5])
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestMockExtensionServerCSRF verifies CSRF token validation
|
|
func TestMockExtensionServerCSRF(t *testing.T) {
|
|
csrf := "correct-csrf"
|
|
srv, err := NewMockExtensionServer(csrf)
|
|
require.NoError(t, err)
|
|
defer srv.Close()
|
|
|
|
base := fmt.Sprintf("http://127.0.0.1:%d/exa.extension_server_pb.ExtensionServerService/Heartbeat", srv.Port())
|
|
|
|
// Wrong CSRF → 403
|
|
req, _ := http.NewRequest("POST", base, nil)
|
|
req.Header.Set("x-codeium-csrf-token", "wrong-csrf")
|
|
req.Header.Set("Content-Type", "application/proto")
|
|
resp, err := http.DefaultClient.Do(req)
|
|
require.NoError(t, err)
|
|
defer resp.Body.Close()
|
|
require.Equal(t, 403, resp.StatusCode)
|
|
|
|
// Correct CSRF → 200
|
|
req2, _ := http.NewRequest("POST", base, nil)
|
|
req2.Header.Set("x-codeium-csrf-token", csrf)
|
|
req2.Header.Set("Content-Type", "application/proto")
|
|
resp2, err := http.DefaultClient.Do(req2)
|
|
require.NoError(t, err)
|
|
defer resp2.Body.Close()
|
|
require.Equal(t, 200, resp2.StatusCode)
|
|
}
|
|
|
|
// TestMockExtensionServerGetSecretValue verifies the fallback token path
|
|
func TestMockExtensionServerGetSecretValue(t *testing.T) {
|
|
csrf := "test-csrf"
|
|
srv, err := NewMockExtensionServer(csrf)
|
|
require.NoError(t, err)
|
|
defer srv.Close()
|
|
|
|
srv.SetToken("acc", &TokenInfo{AccessToken: "ya29.secret-token"})
|
|
|
|
// GetSecretValue should return the token
|
|
req, _ := http.NewRequest("POST",
|
|
fmt.Sprintf("http://127.0.0.1:%d/exa.extension_server_pb.ExtensionServerService/GetSecretValue", srv.Port()),
|
|
nil)
|
|
req.Header.Set("x-codeium-csrf-token", csrf)
|
|
req.Header.Set("Content-Type", "application/proto")
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
require.NoError(t, err)
|
|
defer resp.Body.Close()
|
|
require.Equal(t, 200, resp.StatusCode)
|
|
}
|
|
|
|
// TestOAuthTokenInfoProto verifies the proto encoding matches real IDE format
|
|
func TestOAuthTokenInfoProto(t *testing.T) {
|
|
expiry := time.Date(2026, 3, 29, 19, 0, 0, 0, time.UTC)
|
|
bin := buildOAuthTokenInfoBinary("ya29.test", "1//refresh", expiry)
|
|
|
|
// Verify fields are present by checking proto wire format
|
|
require.True(t, len(bin) > 0, "proto should not be empty")
|
|
|
|
// Field 1 (access_token): tag=0x0a, value="ya29.test"
|
|
require.Contains(t, string(bin), "ya29.test")
|
|
// Field 2 (token_type): tag=0x12, value="Bearer"
|
|
require.Contains(t, string(bin), "Bearer")
|
|
// Field 3 (refresh_token): tag=0x1a, value="1//refresh"
|
|
require.Contains(t, string(bin), "1//refresh")
|
|
|
|
// Without refresh_token
|
|
binNoRefresh := buildOAuthTokenInfoBinary("ya29.test", "", expiry)
|
|
require.NotContains(t, string(binNoRefresh), "1//refresh")
|
|
}
|
|
|
|
// TestOAuthTokenInfoWithRealExpiry verifies expiry uses real time, not hardcoded
|
|
func TestOAuthTokenInfoWithRealExpiry(t *testing.T) {
|
|
future := time.Now().Add(2 * time.Hour)
|
|
bin := buildOAuthTokenInfoBinary("token", "refresh", future)
|
|
|
|
// Zero expiry should default to ~1h
|
|
binZero := buildOAuthTokenInfoBinary("token", "refresh", time.Time{})
|
|
|
|
// They should be different lengths or content (different expiry timestamps)
|
|
// Both should be valid (non-empty)
|
|
require.True(t, len(bin) > 0)
|
|
require.True(t, len(binZero) > 0)
|
|
}
|
|
|
|
// TestUSSTopicWithOAuth verifies the full USS topic proto structure
|
|
func TestUSSTopicWithOAuth(t *testing.T) {
|
|
expiry := time.Now().Add(1 * time.Hour)
|
|
topic := buildUSSTopicWithOAuth("ya29.access", "1//refresh", expiry)
|
|
|
|
require.True(t, len(topic) > 0)
|
|
// The topic should contain the sentinel key
|
|
require.Contains(t, string(topic), "oauthTokenInfoSentinelKey")
|
|
}
|
|
|
|
func TestUSSTopicWithModelCredits(t *testing.T) {
|
|
available := int32(123)
|
|
minimum := int32(50)
|
|
topic := buildUSSTopicWithModelCredits(&ModelCreditsInfo{
|
|
UseAICredits: true,
|
|
AvailableCredits: &available,
|
|
MinimumCreditAmountForUsage: &minimum,
|
|
})
|
|
|
|
require.True(t, len(topic) > 0)
|
|
require.Contains(t, string(topic), useAICreditsSentinelKey)
|
|
require.Contains(t, string(topic), availableCreditsSentinelKey)
|
|
require.Contains(t, string(topic), minimumCreditAmountForUsageKey)
|
|
|
|
rows := decodeTopicRows(topic)
|
|
requireBase64PrimitiveValue(t, rows[useAICreditsSentinelKey], buildPrimitiveBoolBinary(true))
|
|
requireBase64PrimitiveValue(t, rows[availableCreditsSentinelKey], buildPrimitiveInt32Binary(available))
|
|
requireBase64PrimitiveValue(t, rows[minimumCreditAmountForUsageKey], buildPrimitiveInt32Binary(minimum))
|
|
}
|
|
|
|
func TestMockExtensionServerModelCreditsDynamicUpdate(t *testing.T) {
|
|
csrf := "test-csrf-token"
|
|
srv, err := NewMockExtensionServer(csrf)
|
|
require.NoError(t, err)
|
|
defer srv.Close()
|
|
|
|
srv.SetModelCredits("account-1", &ModelCreditsInfo{})
|
|
|
|
req, _ := http.NewRequest("POST",
|
|
fmt.Sprintf("http://127.0.0.1:%d/exa.extension_server_pb.ExtensionServerService/SubscribeToUnifiedStateSyncTopic", srv.Port()),
|
|
bytes.NewReader(frameConnectMessage(encodeProtoString(1, "uss-modelCredits"))))
|
|
req.Header.Set("x-codeium-csrf-token", csrf)
|
|
req.Header.Set("Content-Type", "application/connect+proto")
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
|
defer cancel()
|
|
req = req.WithContext(ctx)
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
require.NoError(t, err)
|
|
defer resp.Body.Close()
|
|
require.Equal(t, 200, resp.StatusCode)
|
|
|
|
// Drain the initial_state frame first.
|
|
_, err = readConnectFrame(resp.Body)
|
|
require.NoError(t, err)
|
|
|
|
available := int32(77)
|
|
minimum := int32(25)
|
|
srv.SetModelCredits("account-1", &ModelCreditsInfo{
|
|
UseAICredits: true,
|
|
AvailableCredits: &available,
|
|
MinimumCreditAmountForUsage: &minimum,
|
|
})
|
|
|
|
values := make(map[string]string, 3)
|
|
for len(values) < 3 {
|
|
frame, readErr := readConnectFrame(resp.Body)
|
|
require.NoError(t, readErr)
|
|
applied := decodeProtoBytesField(frame, 2)
|
|
require.NotEmpty(t, applied)
|
|
key := decodeProtoString(applied, 1)
|
|
row := decodeProtoBytesField(applied, 2)
|
|
values[key] = decodeProtoString(row, 1)
|
|
}
|
|
|
|
require.Contains(t, values, useAICreditsSentinelKey)
|
|
require.Contains(t, values, availableCreditsSentinelKey)
|
|
require.Contains(t, values, minimumCreditAmountForUsageKey)
|
|
requireBase64PrimitiveValue(t, values[useAICreditsSentinelKey], buildPrimitiveBoolBinary(true))
|
|
requireBase64PrimitiveValue(t, values[availableCreditsSentinelKey], buildPrimitiveInt32Binary(available))
|
|
requireBase64PrimitiveValue(t, values[minimumCreditAmountForUsageKey], buildPrimitiveInt32Binary(minimum))
|
|
}
|
|
|
|
// TestBuildInitialStateUpdate verifies the USS update wrapper
|
|
func TestBuildInitialStateUpdate(t *testing.T) {
|
|
topicData := buildEmptyTopic()
|
|
update := buildInitialStateUpdate(topicData)
|
|
// Should be a valid proto bytes field (field 1 = initial_state)
|
|
require.True(t, len(update) >= 0) // empty topic is valid
|
|
|
|
topicData2 := buildUSSTopicWithOAuth("token", "refresh", time.Now().Add(1*time.Hour))
|
|
update2 := buildInitialStateUpdate(topicData2)
|
|
require.True(t, len(update2) > len(update), "non-empty topic should produce larger update")
|
|
}
|
|
|
|
// TestPoolSetAccountTokenComplete verifies pool accepts full credential set
|
|
func TestPoolSetAccountTokenComplete(t *testing.T) {
|
|
csrf := "pool-csrf"
|
|
srv, err := NewMockExtensionServer(csrf)
|
|
require.NoError(t, err)
|
|
defer srv.Close()
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
pool := &Pool{
|
|
config: DefaultConfig(),
|
|
instances: make(map[string][]*Instance),
|
|
extServer: srv,
|
|
ctx: ctx,
|
|
cancel: cancel,
|
|
}
|
|
|
|
expiry := time.Now().Add(1 * time.Hour)
|
|
pool.SetAccountToken("acc-1", "ya29.full-token", "1//full-refresh", expiry)
|
|
|
|
srv.mu.RLock()
|
|
info := srv.tokens["acc-1"]
|
|
srv.mu.RUnlock()
|
|
|
|
require.NotNil(t, info)
|
|
require.Equal(t, "ya29.full-token", info.AccessToken)
|
|
require.Equal(t, "1//full-refresh", info.RefreshToken)
|
|
require.False(t, info.ExpiresAt.IsZero())
|
|
require.WithinDuration(t, expiry, info.ExpiresAt, time.Second)
|
|
}
|
|
|
|
func TestPoolSetAccountModelCreditsComplete(t *testing.T) {
|
|
csrf := "pool-csrf"
|
|
srv, err := NewMockExtensionServer(csrf)
|
|
require.NoError(t, err)
|
|
defer srv.Close()
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
pool := &Pool{
|
|
config: DefaultConfig(),
|
|
instances: make(map[string][]*Instance),
|
|
extServer: srv,
|
|
ctx: ctx,
|
|
cancel: cancel,
|
|
}
|
|
|
|
available := int32(77)
|
|
minimum := int32(25)
|
|
pool.SetAccountModelCredits("acc-1", true, &available, &minimum)
|
|
|
|
srv.mu.RLock()
|
|
info := srv.credits["acc-1"]
|
|
srv.mu.RUnlock()
|
|
|
|
require.NotNil(t, info)
|
|
require.True(t, info.UseAICredits)
|
|
require.NotNil(t, info.AvailableCredits)
|
|
require.Equal(t, available, *info.AvailableCredits)
|
|
require.NotNil(t, info.MinimumCreditAmountForUsage)
|
|
require.Equal(t, minimum, *info.MinimumCreditAmountForUsage)
|
|
}
|
|
|
|
// TestUpstreamAdapterExtractsCredentials verifies internal LS headers are extracted and stripped.
|
|
func TestUpstreamAdapterExtractsCredentials(t *testing.T) {
|
|
// Create a mock upstream that records what it receives
|
|
var receivedHeaders http.Header
|
|
var mu sync.Mutex
|
|
fallback := &recordingUpstreamWithCallback{}
|
|
fallback.onDo = func(req *http.Request) {
|
|
mu.Lock()
|
|
receivedHeaders = req.Header.Clone()
|
|
mu.Unlock()
|
|
}
|
|
|
|
csrf := "test-csrf"
|
|
srv, err := NewMockExtensionServer(csrf)
|
|
require.NoError(t, err)
|
|
defer srv.Close()
|
|
|
|
pool := &Pool{
|
|
config: DefaultConfig(),
|
|
instances: make(map[string][]*Instance),
|
|
extServer: srv,
|
|
}
|
|
|
|
upstream := NewLSPoolUpstream(pool, fallback)
|
|
|
|
// Non-streamGenerateContent request → should pass through to fallback
|
|
req, _ := http.NewRequest("POST", "https://example.com/v1beta/models/gemini:generateContent", nil)
|
|
req.Header.Set("Authorization", "Bearer ya29.test")
|
|
req.Header.Set("X-Antigravity-Refresh-Token", "1//secret-refresh")
|
|
req.Header.Set("X-Antigravity-Token-Expiry", "2026-03-29T19:00:00Z")
|
|
req.Header.Set(useAICreditsHeader, "true")
|
|
req.Header.Set(availableCreditsHeader, "42")
|
|
req.Header.Set(minimumCreditAmountHeader, "50")
|
|
|
|
resp, err := upstream.Do(req, "", 1, 1)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, resp)
|
|
|
|
// Internal headers should never leak to the direct upstream.
|
|
mu.Lock()
|
|
require.Empty(t, receivedHeaders.Get("X-Antigravity-Refresh-Token"))
|
|
require.Empty(t, receivedHeaders.Get("X-Antigravity-Token-Expiry"))
|
|
require.Empty(t, receivedHeaders.Get(useAICreditsHeader))
|
|
require.Empty(t, receivedHeaders.Get(availableCreditsHeader))
|
|
require.Empty(t, receivedHeaders.Get(minimumCreditAmountHeader))
|
|
mu.Unlock()
|
|
|
|
srv.mu.RLock()
|
|
tokenInfo := srv.tokens["1"]
|
|
creditsInfo := srv.credits["1"]
|
|
srv.mu.RUnlock()
|
|
|
|
require.NotNil(t, tokenInfo)
|
|
require.Equal(t, "ya29.test", tokenInfo.AccessToken)
|
|
require.NotNil(t, creditsInfo)
|
|
require.True(t, creditsInfo.UseAICredits)
|
|
require.NotNil(t, creditsInfo.AvailableCredits)
|
|
require.Equal(t, int32(42), *creditsInfo.AvailableCredits)
|
|
require.NotNil(t, creditsInfo.MinimumCreditAmountForUsage)
|
|
require.Equal(t, int32(50), *creditsInfo.MinimumCreditAmountForUsage)
|
|
}
|
|
|
|
// TestExtractPromptAndModelMultiTurn verifies multi-turn prompt extraction
|
|
func TestExtractPromptAndModelMultiTurn(t *testing.T) {
|
|
body := `{
|
|
"model": "claude-sonnet-4-6",
|
|
"request": {
|
|
"systemInstruction": {"parts": [{"text": "You are helpful"}]},
|
|
"contents": [
|
|
{"role": "user", "parts": [{"text": "Hello"}]},
|
|
{"role": "model", "parts": [{"text": "Hi there!"}]},
|
|
{"role": "user", "parts": [{"text": "How are you?"}]}
|
|
]
|
|
}
|
|
}`
|
|
prompt, model := extractPromptAndModel([]byte(body))
|
|
require.Equal(t, "claude-sonnet-4-6", model)
|
|
require.Contains(t, prompt, "You are helpful")
|
|
require.Contains(t, prompt, "Hello")
|
|
require.Contains(t, prompt, "Hi there!")
|
|
require.Contains(t, prompt, "How are you?")
|
|
}
|
|
|
|
// TestExtractUsageFromTrajectory verifies token usage extraction
|
|
func TestExtractUsageFromTrajectory(t *testing.T) {
|
|
resp := `{
|
|
"trajectory": {
|
|
"steps": [{
|
|
"type": "CORTEX_STEP_TYPE_PLANNER_RESPONSE",
|
|
"status": "CORTEX_STEP_STATUS_DONE",
|
|
"plannerResponse": {"response": "OK"},
|
|
"metadata": {
|
|
"modelUsage": {
|
|
"inputTokens": "150",
|
|
"outputTokens": "5"
|
|
}
|
|
}
|
|
}]
|
|
}
|
|
}`
|
|
usage := extractUsageFromTrajectory([]byte(resp))
|
|
require.NotNil(t, usage)
|
|
require.Equal(t, 150, usage["promptTokenCount"])
|
|
require.Equal(t, 5, usage["candidatesTokenCount"])
|
|
require.Equal(t, 155, usage["totalTokenCount"])
|
|
}
|
|
|
|
// TestSSEChunkFormat verifies the Gemini SSE output format
|
|
func TestSSEChunkFormat(t *testing.T) {
|
|
chunk := buildGeminiSSEChunk("Hello world")
|
|
require.True(t, len(chunk) > 0)
|
|
require.Contains(t, chunk, "data: ")
|
|
require.Contains(t, chunk, `"text":"Hello world"`)
|
|
require.Contains(t, chunk, `"role":"model"`)
|
|
require.True(t, chunk[len(chunk)-2:] == "\n\n")
|
|
|
|
// Verify it's valid JSON after stripping "data: " prefix
|
|
jsonStr := chunk[len("data: ") : len(chunk)-2]
|
|
var parsed map[string]any
|
|
err := json.Unmarshal([]byte(jsonStr), &parsed)
|
|
require.NoError(t, err)
|
|
response := parsed["response"].(map[string]any)
|
|
candidates := response["candidates"].([]any)
|
|
require.Len(t, candidates, 1)
|
|
}
|
|
|
|
// TestSSEFinalChunkFormat verifies the final SSE chunk with usage
|
|
func TestSSEFinalChunkFormat(t *testing.T) {
|
|
usage := map[string]any{
|
|
"promptTokenCount": 100,
|
|
"candidatesTokenCount": 50,
|
|
"totalTokenCount": 150,
|
|
}
|
|
chunk := buildGeminiSSEFinalChunk(usage)
|
|
require.Contains(t, chunk, "data: ")
|
|
require.Contains(t, chunk, `"finishReason":"STOP"`)
|
|
require.Contains(t, chunk, `"usageMetadata"`)
|
|
}
|
|
|
|
func TestStreamCascadeResponsePollsImmediately(t *testing.T) {
|
|
var getCalls atomic.Int32
|
|
|
|
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
require.Equal(t, "test-csrf", r.Header.Get("x-codeium-csrf-token"))
|
|
|
|
if strings.HasSuffix(r.URL.Path, "/GetCascadeTrajectory") {
|
|
getCalls.Add(1)
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"status":"CASCADE_RUN_STATUS_IDLE","trajectory":{"steps":[{"type":"CORTEX_STEP_TYPE_PLANNER_RESPONSE","status":"CORTEX_STEP_STATUS_DONE","plannerResponse":{"response":"hello from ls"}}]}}`))
|
|
return
|
|
}
|
|
|
|
http.NotFound(w, r)
|
|
}))
|
|
defer server.Close()
|
|
|
|
inst := &Instance{
|
|
AccountID: "42",
|
|
CSRF: "test-csrf",
|
|
Address: strings.TrimPrefix(server.URL, "https://"),
|
|
client: server.Client(),
|
|
healthy: true,
|
|
lastUsed: time.Now(),
|
|
}
|
|
upstream := NewLSPoolUpstream(&Pool{}, &recordingUpstream{})
|
|
|
|
pr, pw := io.Pipe()
|
|
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
|
defer cancel()
|
|
|
|
done := make(chan struct{})
|
|
go func() {
|
|
upstream.streamCascadeResponse(ctx, inst, "cid-1", pw, nil, nil)
|
|
close(done)
|
|
}()
|
|
|
|
body, err := io.ReadAll(pr)
|
|
require.NoError(t, err)
|
|
<-done
|
|
|
|
require.GreaterOrEqual(t, getCalls.Load(), int32(1))
|
|
require.Contains(t, string(body), "hello from ls")
|
|
}
|
|
|
|
// TestRequestHasToolsEdgeCases verifies tool detection edge cases
|
|
func TestRequestHasToolsEdgeCases(t *testing.T) {
|
|
// null tools
|
|
require.False(t, requestHasTools([]byte(`{"contents":[],"tools":null}`)))
|
|
// tools with empty function declarations
|
|
require.True(t, requestHasTools([]byte(`{"contents":[],"tools":[{"functionDeclarations":[]}]}`)))
|
|
// deeply nested wrapped format
|
|
require.True(t, requestHasTools([]byte(`{"model":"m","project":"p","request":{"contents":[],"tools":[{"codeExecution":{}}]}}`)))
|
|
}
|
|
|
|
func TestJSParityRouteReusesCascadeSession(t *testing.T) {
|
|
t.Setenv("ANTIGRAVITY_LS_STRATEGY", LSStrategyJSParity)
|
|
|
|
var startCalls atomic.Int32
|
|
var sendCalls atomic.Int32
|
|
var getCalls atomic.Int32
|
|
var sendBodiesMu sync.Mutex
|
|
var sendBodies []map[string]any
|
|
|
|
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
require.Equal(t, "test-csrf", r.Header.Get("x-codeium-csrf-token"))
|
|
|
|
switch {
|
|
case strings.HasSuffix(r.URL.Path, "/StartCascade"):
|
|
startCalls.Add(1)
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"cascadeId":"cid-1"}`))
|
|
case strings.HasSuffix(r.URL.Path, "/SendUserCascadeMessage"):
|
|
sendCalls.Add(1)
|
|
var payload map[string]any
|
|
err := json.NewDecoder(r.Body).Decode(&payload)
|
|
require.NoError(t, err)
|
|
sendBodiesMu.Lock()
|
|
sendBodies = append(sendBodies, payload)
|
|
sendBodiesMu.Unlock()
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"queued":false}`))
|
|
case strings.HasSuffix(r.URL.Path, "/GetCascadeTrajectory"):
|
|
call := getCalls.Add(1)
|
|
w.Header().Set("Content-Type", "application/json")
|
|
text := "hello from ls"
|
|
if call > 1 {
|
|
text = "follow up from ls"
|
|
}
|
|
_, _ = w.Write([]byte(fmt.Sprintf(`{"status":"CASCADE_RUN_STATUS_IDLE","trajectory":{"steps":[{"type":"CORTEX_STEP_TYPE_PLANNER_RESPONSE","status":"CORTEX_STEP_STATUS_DONE","plannerResponse":{"response":"%s"}}]}}`, text)))
|
|
default:
|
|
http.NotFound(w, r)
|
|
}
|
|
}))
|
|
defer server.Close()
|
|
|
|
inst := &Instance{
|
|
AccountID: "42",
|
|
CSRF: "test-csrf",
|
|
Address: strings.TrimPrefix(server.URL, "https://"),
|
|
client: server.Client(),
|
|
healthy: true,
|
|
lastUsed: time.Now(),
|
|
}
|
|
inst.SetModelMappingReady(true)
|
|
pool := &Pool{
|
|
config: Config{ReplicasPerAccount: 1},
|
|
instances: map[string][]*Instance{"42": []*Instance{inst}},
|
|
}
|
|
upstream := NewLSPoolUpstream(pool, &recordingUpstream{})
|
|
|
|
req1Body := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"session-a","contents":[{"role":"user","parts":[{"text":"hello"}]}]}}`)
|
|
req1, err := http.NewRequest(http.MethodPost, "https://example.com/v1internal:streamGenerateContent?alt=sse", bytes.NewReader(req1Body))
|
|
require.NoError(t, err)
|
|
req1.Header.Set("Authorization", "Bearer downstream-a")
|
|
|
|
resp1, err := upstream.Do(req1, "", 42, 1)
|
|
require.NoError(t, err)
|
|
body1, err := io.ReadAll(resp1.Body)
|
|
require.NoError(t, err)
|
|
require.Contains(t, string(body1), `"text":"hello from ls"`)
|
|
|
|
req2Body := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"session-a","contents":[{"role":"user","parts":[{"text":"hello"}]},{"role":"model","parts":[{"text":"hello from ls"}]},{"role":"user","parts":[{"text":"follow up"}]}]}}`)
|
|
req2, err := http.NewRequest(http.MethodPost, "https://example.com/v1internal:streamGenerateContent?alt=sse", bytes.NewReader(req2Body))
|
|
require.NoError(t, err)
|
|
req2.Header.Set("Authorization", "Bearer downstream-a")
|
|
|
|
resp2, err := upstream.Do(req2, "", 42, 1)
|
|
require.NoError(t, err)
|
|
body2, err := io.ReadAll(resp2.Body)
|
|
require.NoError(t, err)
|
|
require.Contains(t, string(body2), `"text":"follow up from ls"`)
|
|
|
|
require.Equal(t, int32(1), startCalls.Load(), "cascade should be reused for append-only transcript")
|
|
require.Equal(t, int32(2), sendCalls.Load())
|
|
|
|
sendBodiesMu.Lock()
|
|
require.Len(t, sendBodies, 2)
|
|
firstSend := sendBodies[0]
|
|
sendBodiesMu.Unlock()
|
|
|
|
require.Equal(t, "cid-1", firstSend["cascadeId"])
|
|
require.Equal(t, false, firstSend["blocking"])
|
|
metadata, ok := firstSend["metadata"].(map[string]any)
|
|
require.True(t, ok)
|
|
require.Equal(t, "antigravity", metadata["ideName"])
|
|
require.Equal(t, "1.107.0", metadata["ideVersion"])
|
|
require.NotContains(t, firstSend, "clientType")
|
|
require.NotContains(t, firstSend, "messageOrigin")
|
|
cascadeConfig, ok := firstSend["cascadeConfig"].(map[string]any)
|
|
require.True(t, ok)
|
|
plannerConfig, ok := cascadeConfig["plannerConfig"].(map[string]any)
|
|
require.True(t, ok)
|
|
requestedModel, ok := plannerConfig["requestedModel"].(map[string]any)
|
|
require.True(t, ok)
|
|
require.NotEmpty(t, requestedModel["model"])
|
|
require.Len(t, plannerConfig, 1)
|
|
require.Len(t, cascadeConfig, 1)
|
|
}
|
|
|
|
func TestJSParityRouteFallsBackOnSystemInstructionDrift(t *testing.T) {
|
|
t.Setenv("ANTIGRAVITY_LS_STRATEGY", LSStrategyJSParity)
|
|
|
|
var startCalls atomic.Int32
|
|
var sendCalls atomic.Int32
|
|
|
|
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
require.Equal(t, "test-csrf", r.Header.Get("x-codeium-csrf-token"))
|
|
|
|
switch {
|
|
case strings.HasSuffix(r.URL.Path, "/StartCascade"):
|
|
startCalls.Add(1)
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"cascadeId":"cid-1"}`))
|
|
case strings.HasSuffix(r.URL.Path, "/SendUserCascadeMessage"):
|
|
sendCalls.Add(1)
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"queued":false}`))
|
|
case strings.HasSuffix(r.URL.Path, "/GetCascadeTrajectory"):
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"status":"CASCADE_RUN_STATUS_IDLE","trajectory":{"steps":[{"type":"CORTEX_STEP_TYPE_PLANNER_RESPONSE","status":"CORTEX_STEP_STATUS_DONE","plannerResponse":{"response":"hello from ls"}}]}}`))
|
|
default:
|
|
http.NotFound(w, r)
|
|
}
|
|
}))
|
|
defer server.Close()
|
|
|
|
inst := &Instance{
|
|
AccountID: "42",
|
|
CSRF: "test-csrf",
|
|
Address: strings.TrimPrefix(server.URL, "https://"),
|
|
client: server.Client(),
|
|
healthy: true,
|
|
lastUsed: time.Now(),
|
|
}
|
|
inst.SetModelMappingReady(true)
|
|
fallback := &recordingUpstream{}
|
|
pool := &Pool{
|
|
config: Config{ReplicasPerAccount: 1},
|
|
instances: map[string][]*Instance{"42": []*Instance{inst}},
|
|
}
|
|
upstream := NewLSPoolUpstream(pool, fallback)
|
|
|
|
req1Body := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"session-a","systemInstruction":{"parts":[{"text":"You are helpful"}]},"contents":[{"role":"user","parts":[{"text":"hello"}]}]}}`)
|
|
req1, err := http.NewRequest(http.MethodPost, "https://example.com/v1internal:streamGenerateContent?alt=sse", bytes.NewReader(req1Body))
|
|
require.NoError(t, err)
|
|
req1.Header.Set("Authorization", "Bearer downstream-a")
|
|
|
|
resp1, err := upstream.Do(req1, "", 42, 1)
|
|
require.NoError(t, err)
|
|
body1, err := io.ReadAll(resp1.Body)
|
|
require.NoError(t, err)
|
|
require.Contains(t, string(body1), `"text":"hello from ls"`)
|
|
|
|
req2Body := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"session-a","systemInstruction":{"parts":[{"text":"You are different"}]},"contents":[{"role":"user","parts":[{"text":"hello"}]},{"role":"model","parts":[{"text":"hello from ls"}]},{"role":"user","parts":[{"text":"follow up"}]}]}}`)
|
|
req2, err := http.NewRequest(http.MethodPost, "https://example.com/v1internal:streamGenerateContent?alt=sse", bytes.NewReader(req2Body))
|
|
require.NoError(t, err)
|
|
req2.Header.Set("Authorization", "Bearer downstream-a")
|
|
|
|
resp2, err := upstream.Do(req2, "", 42, 1)
|
|
require.NoError(t, err)
|
|
body2, err := io.ReadAll(resp2.Body)
|
|
require.NoError(t, err)
|
|
|
|
require.Equal(t, "ok", string(body2))
|
|
require.Equal(t, 1, fallback.doCalls)
|
|
require.Equal(t, int32(1), startCalls.Load())
|
|
require.Equal(t, int32(1), sendCalls.Load())
|
|
}
|
|
|
|
func TestJSParityRouteErrorsWhenModelMappingPending(t *testing.T) {
|
|
t.Setenv("ANTIGRAVITY_LS_STRATEGY", LSStrategyJSParity)
|
|
|
|
var startCalls atomic.Int32
|
|
var sendCalls atomic.Int32
|
|
|
|
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch {
|
|
case strings.HasSuffix(r.URL.Path, "/StartCascade"):
|
|
startCalls.Add(1)
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"cascadeId":"cid-1"}`))
|
|
case strings.HasSuffix(r.URL.Path, "/SendUserCascadeMessage"):
|
|
sendCalls.Add(1)
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"queued":false}`))
|
|
default:
|
|
http.NotFound(w, r)
|
|
}
|
|
}))
|
|
defer server.Close()
|
|
|
|
inst := &Instance{
|
|
AccountID: "42",
|
|
CSRF: "test-csrf",
|
|
Address: strings.TrimPrefix(server.URL, "https://"),
|
|
client: server.Client(),
|
|
healthy: true,
|
|
lastUsed: time.Now(),
|
|
}
|
|
|
|
fallback := &recordingUpstream{}
|
|
pool := &Pool{
|
|
config: Config{ReplicasPerAccount: 1},
|
|
instances: map[string][]*Instance{"42": []*Instance{inst}},
|
|
}
|
|
upstream := NewLSPoolUpstream(pool, fallback)
|
|
|
|
reqBody := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"session-a","contents":[{"role":"user","parts":[{"text":"hello"}]}]}}`)
|
|
req, err := http.NewRequest(http.MethodPost, "https://example.com/v1internal:streamGenerateContent?alt=sse", bytes.NewReader(reqBody))
|
|
require.NoError(t, err)
|
|
req.Header.Set("Authorization", "Bearer downstream-a")
|
|
|
|
resp, err := upstream.Do(req, "", 42, 1)
|
|
require.Nil(t, resp)
|
|
require.ErrorIs(t, err, errLSModelMapPending)
|
|
require.Equal(t, int32(0), startCalls.Load())
|
|
require.Equal(t, int32(0), sendCalls.Load())
|
|
require.Equal(t, 0, fallback.doCalls)
|
|
}
|
|
|
|
// recordingUpstreamWithCallback extends the base recordingUpstream with a callback
|
|
type recordingUpstreamWithCallback struct {
|
|
recordingUpstream
|
|
onDo func(req *http.Request)
|
|
}
|
|
|
|
func (r *recordingUpstreamWithCallback) Do(req *http.Request, proxyURL string, accountID int64, c int) (*http.Response, error) {
|
|
if r.onDo != nil {
|
|
r.onDo(req)
|
|
}
|
|
return r.recordingUpstream.Do(req, proxyURL, accountID, c)
|
|
}
|