diff --git a/Dockerfile b/Dockerfile index 3b18fcfd..195a8369 100644 --- a/Dockerfile +++ b/Dockerfile @@ -7,7 +7,7 @@ # ============================================================================= ARG NODE_IMAGE=node:24-alpine -ARG GOLANG_IMAGE=golang:1.26.1-alpine +ARG GOLANG_IMAGE=golang:1.25-alpine ARG ALPINE_IMAGE=alpine:3.21 ARG DEBIAN_IMAGE=debian:bookworm-slim ARG POSTGRES_IMAGE=postgres:18-alpine diff --git a/backend/Dockerfile b/backend/Dockerfile index aeb20fdb..4f4bf732 100644 --- a/backend/Dockerfile +++ b/backend/Dockerfile @@ -1,4 +1,4 @@ -FROM golang:1.25.7-alpine +FROM golang:1.25-alpine WORKDIR /app diff --git a/backend/cmd/lsworker/main.go b/backend/cmd/lsworker/main.go deleted file mode 100644 index deeb0649..00000000 --- a/backend/cmd/lsworker/main.go +++ /dev/null @@ -1,49 +0,0 @@ -package main - -import ( - "context" - "errors" - "log/slog" - "net/http" - "os" - "os/signal" - "syscall" - - "github.com/Wei-Shaw/sub2api/internal/pkg/lspool" -) - -func main() { - server, err := lspool.NewWorkerServerFromEnv() - if err != nil { - slog.Error("failed to initialize lsworker", "err", err) - os.Exit(1) - } - defer server.Close() - - httpServer := &http.Server{ - Addr: envOrDefault("LSWORKER_LISTEN_ADDR", "0.0.0.0:18081"), - Handler: server.Handler(), - ReadHeaderTimeout: 10 * 1e9, - } - - ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) - defer stop() - - go func() { - <-ctx.Done() - _ = httpServer.Shutdown(context.Background()) - }() - - slog.Info("lsworker listening", "addr", httpServer.Addr) - if err := httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { - slog.Error("lsworker exited with error", "err", err) - os.Exit(1) - } -} - -func envOrDefault(key, fallback string) string { - if value := os.Getenv(key); value != "" { - return value - } - return fallback -} diff --git a/backend/internal/pkg/antigravity/oauth.go b/backend/internal/pkg/antigravity/oauth.go index b64b633b..16fb6bd6 100644 --- a/backend/internal/pkg/antigravity/oauth.go +++ b/backend/internal/pkg/antigravity/oauth.go @@ -53,9 +53,8 @@ const ( // defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.107.0 var defaultUserAgentVersion = "1.107.0" - -// defaultClientSecret 必须通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置 -var defaultClientSecret string +// defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 覆盖 +var defaultClientSecret = "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" func init() { // 从环境变量读取版本号,未设置则使用默认值 diff --git a/backend/internal/pkg/claude/constants.go b/backend/internal/pkg/claude/constants.go index 58bc5889..081dece4 100644 --- a/backend/internal/pkg/claude/constants.go +++ b/backend/internal/pkg/claude/constants.go @@ -1,6 +1,8 @@ // Package claude provides constants and helpers for Claude API integration. package claude +import "strings" + // Claude Code 客户端相关常量 // DefaultCLIVersion 是当前模拟的 Claude CLI 版本 @@ -30,32 +32,64 @@ const ( // 这些 token 是客户端特有的,不应透传给上游 API。 var DroppedBetas = []string{} -// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header -const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming + "," + BetaContext1M + "," + BetaRedactThinking + "," + BetaContextManagement + "," + BetaPromptCachingScope + "," + BetaEffort +// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header(OAuth 账号,不含 context-1m) +// 使用 GetOAuthBetaHeader(modelID) 获取含 context-1m 的 model-aware 版本。 +const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming + "," + BetaRedactThinking + "," + BetaContextManagement + "," + BetaPromptCachingScope + "," + BetaEffort -// MessageBetaHeaderNoTools /v1/messages 在无工具时的 beta header +// MessageBetaHeaderNoTools /v1/messages 在无工具时的 beta header(OAuth,不含 context-1m) // // NOTE: Claude Code OAuth credentials are scoped to Claude Code. When we "mimic" // Claude Code for non-Claude-Code clients, we must include the claude-code beta // even if the request doesn't use tools, otherwise upstream may reject the // request as a non-Claude-Code API request. -const MessageBetaHeaderNoTools = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaContext1M + "," + BetaRedactThinking + "," + BetaContextManagement + "," + BetaPromptCachingScope + "," + BetaEffort +const MessageBetaHeaderNoTools = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaRedactThinking + "," + BetaContextManagement + "," + BetaPromptCachingScope + "," + BetaEffort -// MessageBetaHeaderWithTools /v1/messages 在有工具时的 beta header -const MessageBetaHeaderWithTools = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaContext1M + "," + BetaRedactThinking + "," + BetaContextManagement + "," + BetaPromptCachingScope + "," + BetaEffort +// MessageBetaHeaderWithTools /v1/messages 在有工具时的 beta header(OAuth,不含 context-1m) +const MessageBetaHeaderWithTools = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaRedactThinking + "," + BetaContextManagement + "," + BetaPromptCachingScope + "," + BetaEffort // CountTokensBetaHeader count_tokens 请求使用的 anthropic-beta header const CountTokensBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaTokenCounting + "," + BetaContextManagement -// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header(不需要 claude-code beta) +// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header(OAuth,不含 claude-code / context-1m) const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking + "," + BetaEffort -// APIKeyBetaHeader API-key 账号建议使用的 anthropic-beta header(不包含 oauth) -const APIKeyBetaHeader = BetaClaudeCode + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming + "," + BetaContext1M + "," + BetaEffort + "," + BetaPromptCachingScope +// APIKeyBetaHeader API-key 账号使用的 anthropic-beta header(不含 oauth / context-1m) +// 使用 GetAPIKeyBetaHeader(modelID) 获取含 context-1m 的 model-aware 版本。 +const APIKeyBetaHeader = BetaClaudeCode + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming + "," + BetaEffort + "," + BetaPromptCachingScope -// APIKeyHaikuBetaHeader Haiku 模型在 API-key 账号下使用的 anthropic-beta header(不包含 oauth / claude-code) +// APIKeyHaikuBetaHeader Haiku 模型在 API-key 账号下使用的 anthropic-beta header(不含 oauth / claude-code) const APIKeyHaikuBetaHeader = BetaInterleavedThinking + "," + BetaEffort +// ModelSupports1M 判断模型是否支持 1M context window。 +// 与 claude-code-2.1.88 bundle 中 modelSupports1M 逻辑保持一致: +// +// claude-sonnet-4 系列 和 claude-opus-4-6 支持 1M context。 +func ModelSupports1M(modelID string) bool { + lower := strings.ToLower(strings.TrimSpace(modelID)) + return strings.Contains(lower, "claude-sonnet-4") || strings.Contains(lower, "opus-4-6") +} + +// GetOAuthBetaHeader 返回 OAuth 账号的 beta header。 +// 仅当模型支持 1M context 时才包含 context-1m-2025-08-07。 +func GetOAuthBetaHeader(modelID string) string { + if ModelSupports1M(modelID) { + return BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming + "," + BetaContext1M + "," + BetaRedactThinking + "," + BetaContextManagement + "," + BetaPromptCachingScope + "," + BetaEffort + } + return DefaultBetaHeader +} + +// GetAPIKeyBetaHeader 返回 API-key 账号的 beta header。 +// 仅当模型支持 1M context 时才包含 context-1m-2025-08-07。 +func GetAPIKeyBetaHeader(modelID string) string { + if strings.Contains(strings.ToLower(modelID), "haiku") { + return APIKeyHaikuBetaHeader + } + if ModelSupports1M(modelID) { + return BetaClaudeCode + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming + "," + BetaContext1M + "," + BetaEffort + "," + BetaPromptCachingScope + } + return APIKeyBetaHeader +} + // DefaultHeaders 是 Claude Code 客户端默认请求头。 var DefaultHeaders = map[string]string{ // Keep these in sync with recent Claude CLI traffic to reduce the chance @@ -70,7 +104,7 @@ var DefaultHeaders = map[string]string{ "X-Stainless-Retry-Count": "0", "X-Stainless-Timeout": "600", "X-App": "cli", - "Anthropic-Dangerous-Direct-Browser-Access": "true", + "anthropic-version": "2023-06-01", } // ApplyFingerprintOverrides 用配置覆盖默认指纹值(每个实例可设不同值) diff --git a/backend/internal/pkg/lspool/backend.go b/backend/internal/pkg/lspool/backend.go deleted file mode 100644 index ba77e3d7..00000000 --- a/backend/internal/pkg/lspool/backend.go +++ /dev/null @@ -1,13 +0,0 @@ -package lspool - -import "time" - -// Backend is the control-plane abstraction used by the HTTP upstream wrapper. -// It may be backed by a local in-process Pool or by remote LS workers. -type Backend interface { - GetOrCreate(accountID, routingKey string, proxyURL ...string) (*Instance, error) - SetAccountToken(accountID, accessToken, refreshToken string, expiresAt time.Time) - SetAccountModelCredits(accountID string, useAICredits bool, availableCredits, minimumCreditAmountForUsage *int32) - Stats() map[string]any - Close() -} diff --git a/backend/internal/pkg/lspool/global.go b/backend/internal/pkg/lspool/global.go deleted file mode 100644 index 926f69d7..00000000 --- a/backend/internal/pkg/lspool/global.go +++ /dev/null @@ -1,94 +0,0 @@ -// Package lspool provides LS-mode integration for the antigravity gateway. -// -// When LS mode is enabled (via ANTIGRAVITY_LS_MODE=true), requests to -// streamGenerateContent are routed through a real Language Server instance -// instead of directly to cloudcode-pa. This provides: -// -// - Authentic TLS fingerprint (Google's own Go binary) -// - Real session management and Heartbeat -// - Indistinguishable from a real IDE instance -// -// To enable: set environment variable ANTIGRAVITY_LS_MODE=true -// To configure: set ANTIGRAVITY_APP_ROOT to the AntiGravity.app path -package lspool - -import ( - "log/slog" - "os" - "strings" - "sync" - - "github.com/Wei-Shaw/sub2api/internal/config" -) - -var ( - globalBackend Backend - globalPoolOnce sync.Once - lsModeEnabled bool -) - -func init() { - lsModeEnabled = os.Getenv("ANTIGRAVITY_LS_MODE") == "true" -} - -// IsLSModeEnabled returns whether LS mode is active -func IsLSModeEnabled() bool { - return lsModeEnabled -} - -const ( - LSStrategyDirect = "direct" - LSStrategyJSParity = "js-parity" -) - -// CurrentLSStrategy returns the active LS routing strategy. -// Unknown values are treated as "direct" for safety. -func CurrentLSStrategy() string { - switch strings.ToLower(strings.TrimSpace(os.Getenv("ANTIGRAVITY_LS_STRATEGY"))) { - case "", LSStrategyDirect: - return LSStrategyDirect - case LSStrategyJSParity: - return LSStrategyJSParity - default: - return LSStrategyDirect - } -} - -// GlobalPool returns the singleton LS pool instance -// Creates it on first call if LS mode is enabled -func GlobalPool(cfg *config.Config) Backend { - if !lsModeEnabled { - return nil - } - globalPoolOnce.Do(func() { - manager, err := NewWorkerManagerFromConfig(cfg) - if err != nil { - slog.Default().Error("failed to initialize LS worker manager", "err", err) - return - } - globalBackend = manager - }) - return globalBackend -} - -// Shutdown closes the global pool -func Shutdown() { - if globalBackend != nil { - globalBackend.Close() - } -} - -// StatusInfo returns the current LS pool status for diagnostics -func StatusInfo() map[string]any { - info := map[string]any{ - "ls_mode_enabled": lsModeEnabled, - "build": "enhanced", - "user_agent": "antigravity/1.107.0", - } - if lsModeEnabled && globalBackend != nil { - stats := globalBackend.Stats() - info["pool_total"] = stats["total"] - info["pool_active"] = stats["active"] - } - return info -} diff --git a/backend/internal/pkg/lspool/integration_test.go b/backend/internal/pkg/lspool/integration_test.go deleted file mode 100644 index 45812bc8..00000000 --- a/backend/internal/pkg/lspool/integration_test.go +++ /dev/null @@ -1,864 +0,0 @@ -package lspool - -import ( - "bytes" - "context" - "encoding/base64" - "encoding/binary" - "encoding/json" - "fmt" - "io" - "net/http" - "net/http/httptest" - "strings" - "sync" - "sync/atomic" - "testing" - "time" - - "github.com/stretchr/testify/require" -) - -func readConnectFrame(r io.Reader) ([]byte, error) { - header := make([]byte, 5) - if _, err := io.ReadFull(r, header); err != nil { - return nil, err - } - payloadLen := binary.BigEndian.Uint32(header[1:5]) - payload := make([]byte, payloadLen) - if _, err := io.ReadFull(r, payload); err != nil { - return nil, err - } - return payload, nil -} - -func decodeProtoBytesField(data []byte, targetField int) []byte { - i := 0 - for i < len(data) { - tag, n := binary.Uvarint(data[i:]) - if n <= 0 { - return nil - } - i += n - fieldNum := int(tag >> 3) - wireType := tag & 0x7 - switch wireType { - case 0: - _, n = binary.Uvarint(data[i:]) - if n <= 0 { - return nil - } - i += n - case 2: - length, n := binary.Uvarint(data[i:]) - if n <= 0 { - return nil - } - i += n - if i+int(length) > len(data) { - return nil - } - if fieldNum == targetField { - return data[i : i+int(length)] - } - i += int(length) - case 1: - i += 8 - case 5: - i += 4 - default: - return nil - } - } - return nil -} - -func decodeProtoBytesFields(data []byte, targetField int) [][]byte { - var values [][]byte - i := 0 - for i < len(data) { - tag, n := binary.Uvarint(data[i:]) - if n <= 0 { - return values - } - i += n - fieldNum := int(tag >> 3) - wireType := tag & 0x7 - switch wireType { - case 0: - _, n = binary.Uvarint(data[i:]) - if n <= 0 { - return values - } - i += n - case 2: - length, n := binary.Uvarint(data[i:]) - if n <= 0 { - return values - } - i += n - if i+int(length) > len(data) { - return values - } - if fieldNum == targetField { - values = append(values, append([]byte(nil), data[i:i+int(length)]...)) - } - i += int(length) - case 1: - i += 8 - case 5: - i += 4 - default: - return values - } - } - return values -} - -func decodeTopicRows(topic []byte) map[string]string { - rows := make(map[string]string) - for _, entry := range decodeProtoBytesFields(topic, 1) { - key := decodeProtoString(entry, 1) - row := decodeProtoBytesField(entry, 2) - rows[key] = decodeProtoString(row, 1) - } - return rows -} - -func requireBase64PrimitiveValue(t *testing.T, got string, want []byte) { - t.Helper() - decoded, err := base64.StdEncoding.DecodeString(got) - require.NoError(t, err) - require.Equal(t, want, decoded) -} - -// TestMockExtensionServerTokenInjection verifies the token injection flow: -// Extension → MockExtensionServer → LS subscribes uss-oauth → gets OAuthTokenInfo -func TestMockExtensionServerTokenInjection(t *testing.T) { - csrf := "test-csrf-token" - srv, err := NewMockExtensionServer(csrf) - require.NoError(t, err) - defer srv.Close() - - // 1. Set token for an account - srv.SetToken("account-1", &TokenInfo{ - AccessToken: "ya29.test-access-token", - RefreshToken: "1//test-refresh-token", - ExpiresAt: time.Now().Add(1 * time.Hour), - }) - - // 2. Verify token is stored - srv.mu.RLock() - info, ok := srv.tokens["account-1"] - srv.mu.RUnlock() - require.True(t, ok) - require.Equal(t, "ya29.test-access-token", info.AccessToken) - require.Equal(t, "1//test-refresh-token", info.RefreshToken) - require.False(t, info.ExpiresAt.IsZero()) - - // 3. Simulate LS subscribing to uss-oauth (HTTP request to mock server) - req, _ := http.NewRequest("POST", - fmt.Sprintf("http://127.0.0.1:%d/exa.extension_server_pb.ExtensionServerService/SubscribeToUnifiedStateSyncTopic", srv.Port()), - bytes.NewReader(frameConnectMessage(encodeProtoString(1, "uss-oauth")))) - req.Header.Set("x-codeium-csrf-token", csrf) - req.Header.Set("Content-Type", "application/connect+proto") - - // The stream handler will block, so run in background and cancel after we confirm connection - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - req = req.WithContext(ctx) - - client := &http.Client{} - resp, err := client.Do(req) - if err == nil { - defer resp.Body.Close() - require.Equal(t, 200, resp.StatusCode) - require.Equal(t, "application/connect+proto", resp.Header.Get("Content-Type")) - - // Read the first envelope frame (initial state) - header := make([]byte, 5) - n, readErr := resp.Body.Read(header) - if readErr == nil && n == 5 { - require.Equal(t, byte(0x00), header[0], "first byte should be 0x00 (data frame)") - t.Logf("Received initial state frame: flags=%d, payload_len=%d", header[0], header[1:5]) - } - } -} - -// TestMockExtensionServerCSRF verifies CSRF token validation -func TestMockExtensionServerCSRF(t *testing.T) { - csrf := "correct-csrf" - srv, err := NewMockExtensionServer(csrf) - require.NoError(t, err) - defer srv.Close() - - base := fmt.Sprintf("http://127.0.0.1:%d/exa.extension_server_pb.ExtensionServerService/Heartbeat", srv.Port()) - - // Wrong CSRF → 403 - req, _ := http.NewRequest("POST", base, nil) - req.Header.Set("x-codeium-csrf-token", "wrong-csrf") - req.Header.Set("Content-Type", "application/proto") - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) - defer resp.Body.Close() - require.Equal(t, 403, resp.StatusCode) - - // Correct CSRF → 200 - req2, _ := http.NewRequest("POST", base, nil) - req2.Header.Set("x-codeium-csrf-token", csrf) - req2.Header.Set("Content-Type", "application/proto") - resp2, err := http.DefaultClient.Do(req2) - require.NoError(t, err) - defer resp2.Body.Close() - require.Equal(t, 200, resp2.StatusCode) -} - -// TestMockExtensionServerGetSecretValue verifies the fallback token path -func TestMockExtensionServerGetSecretValue(t *testing.T) { - csrf := "test-csrf" - srv, err := NewMockExtensionServer(csrf) - require.NoError(t, err) - defer srv.Close() - - srv.SetToken("acc", &TokenInfo{AccessToken: "ya29.secret-token"}) - - // GetSecretValue should return the token - req, _ := http.NewRequest("POST", - fmt.Sprintf("http://127.0.0.1:%d/exa.extension_server_pb.ExtensionServerService/GetSecretValue", srv.Port()), - nil) - req.Header.Set("x-codeium-csrf-token", csrf) - req.Header.Set("Content-Type", "application/proto") - - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) - defer resp.Body.Close() - require.Equal(t, 200, resp.StatusCode) -} - -// TestOAuthTokenInfoProto verifies the proto encoding matches real IDE format -func TestOAuthTokenInfoProto(t *testing.T) { - expiry := time.Date(2026, 3, 29, 19, 0, 0, 0, time.UTC) - bin := buildOAuthTokenInfoBinary("ya29.test", "1//refresh", expiry) - - // Verify fields are present by checking proto wire format - require.True(t, len(bin) > 0, "proto should not be empty") - - // Field 1 (access_token): tag=0x0a, value="ya29.test" - require.Contains(t, string(bin), "ya29.test") - // Field 2 (token_type): tag=0x12, value="Bearer" - require.Contains(t, string(bin), "Bearer") - // Field 3 (refresh_token): tag=0x1a, value="1//refresh" - require.Contains(t, string(bin), "1//refresh") - - // Without refresh_token - binNoRefresh := buildOAuthTokenInfoBinary("ya29.test", "", expiry) - require.NotContains(t, string(binNoRefresh), "1//refresh") -} - -// TestOAuthTokenInfoWithRealExpiry verifies expiry uses real time, not hardcoded -func TestOAuthTokenInfoWithRealExpiry(t *testing.T) { - future := time.Now().Add(2 * time.Hour) - bin := buildOAuthTokenInfoBinary("token", "refresh", future) - - // Zero expiry should default to ~1h - binZero := buildOAuthTokenInfoBinary("token", "refresh", time.Time{}) - - // They should be different lengths or content (different expiry timestamps) - // Both should be valid (non-empty) - require.True(t, len(bin) > 0) - require.True(t, len(binZero) > 0) -} - -// TestUSSTopicWithOAuth verifies the full USS topic proto structure -func TestUSSTopicWithOAuth(t *testing.T) { - expiry := time.Now().Add(1 * time.Hour) - topic := buildUSSTopicWithOAuth("ya29.access", "1//refresh", expiry) - - require.True(t, len(topic) > 0) - // The topic should contain the sentinel key - require.Contains(t, string(topic), "oauthTokenInfoSentinelKey") -} - -func TestUSSTopicWithModelCredits(t *testing.T) { - available := int32(123) - minimum := int32(50) - topic := buildUSSTopicWithModelCredits(&ModelCreditsInfo{ - UseAICredits: true, - AvailableCredits: &available, - MinimumCreditAmountForUsage: &minimum, - }) - - require.True(t, len(topic) > 0) - require.Contains(t, string(topic), useAICreditsSentinelKey) - require.Contains(t, string(topic), availableCreditsSentinelKey) - require.Contains(t, string(topic), minimumCreditAmountForUsageKey) - - rows := decodeTopicRows(topic) - requireBase64PrimitiveValue(t, rows[useAICreditsSentinelKey], buildPrimitiveBoolBinary(true)) - requireBase64PrimitiveValue(t, rows[availableCreditsSentinelKey], buildPrimitiveInt32Binary(available)) - requireBase64PrimitiveValue(t, rows[minimumCreditAmountForUsageKey], buildPrimitiveInt32Binary(minimum)) -} - -func TestMockExtensionServerModelCreditsDynamicUpdate(t *testing.T) { - csrf := "test-csrf-token" - srv, err := NewMockExtensionServer(csrf) - require.NoError(t, err) - defer srv.Close() - - srv.SetModelCredits("account-1", &ModelCreditsInfo{}) - - req, _ := http.NewRequest("POST", - fmt.Sprintf("http://127.0.0.1:%d/exa.extension_server_pb.ExtensionServerService/SubscribeToUnifiedStateSyncTopic", srv.Port()), - bytes.NewReader(frameConnectMessage(encodeProtoString(1, "uss-modelCredits")))) - req.Header.Set("x-codeium-csrf-token", csrf) - req.Header.Set("Content-Type", "application/connect+proto") - - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - req = req.WithContext(ctx) - - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) - defer resp.Body.Close() - require.Equal(t, 200, resp.StatusCode) - - // Drain the initial_state frame first. - _, err = readConnectFrame(resp.Body) - require.NoError(t, err) - - available := int32(77) - minimum := int32(25) - srv.SetModelCredits("account-1", &ModelCreditsInfo{ - UseAICredits: true, - AvailableCredits: &available, - MinimumCreditAmountForUsage: &minimum, - }) - - values := make(map[string]string, 3) - for len(values) < 3 { - frame, readErr := readConnectFrame(resp.Body) - require.NoError(t, readErr) - applied := decodeProtoBytesField(frame, 2) - require.NotEmpty(t, applied) - key := decodeProtoString(applied, 1) - row := decodeProtoBytesField(applied, 2) - values[key] = decodeProtoString(row, 1) - } - - require.Contains(t, values, useAICreditsSentinelKey) - require.Contains(t, values, availableCreditsSentinelKey) - require.Contains(t, values, minimumCreditAmountForUsageKey) - requireBase64PrimitiveValue(t, values[useAICreditsSentinelKey], buildPrimitiveBoolBinary(true)) - requireBase64PrimitiveValue(t, values[availableCreditsSentinelKey], buildPrimitiveInt32Binary(available)) - requireBase64PrimitiveValue(t, values[minimumCreditAmountForUsageKey], buildPrimitiveInt32Binary(minimum)) -} - -// TestBuildInitialStateUpdate verifies the USS update wrapper -func TestBuildInitialStateUpdate(t *testing.T) { - topicData := buildEmptyTopic() - update := buildInitialStateUpdate(topicData) - // Should be a valid proto bytes field (field 1 = initial_state) - require.True(t, len(update) >= 0) // empty topic is valid - - topicData2 := buildUSSTopicWithOAuth("token", "refresh", time.Now().Add(1*time.Hour)) - update2 := buildInitialStateUpdate(topicData2) - require.True(t, len(update2) > len(update), "non-empty topic should produce larger update") -} - -// TestPoolSetAccountTokenComplete verifies pool accepts full credential set -func TestPoolSetAccountTokenComplete(t *testing.T) { - csrf := "pool-csrf" - srv, err := NewMockExtensionServer(csrf) - require.NoError(t, err) - defer srv.Close() - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - pool := &Pool{ - config: DefaultConfig(), - instances: make(map[string][]*Instance), - extServer: srv, - ctx: ctx, - cancel: cancel, - } - - expiry := time.Now().Add(1 * time.Hour) - pool.SetAccountToken("acc-1", "ya29.full-token", "1//full-refresh", expiry) - - srv.mu.RLock() - info := srv.tokens["acc-1"] - srv.mu.RUnlock() - - require.NotNil(t, info) - require.Equal(t, "ya29.full-token", info.AccessToken) - require.Equal(t, "1//full-refresh", info.RefreshToken) - require.False(t, info.ExpiresAt.IsZero()) - require.WithinDuration(t, expiry, info.ExpiresAt, time.Second) -} - -func TestPoolSetAccountModelCreditsComplete(t *testing.T) { - csrf := "pool-csrf" - srv, err := NewMockExtensionServer(csrf) - require.NoError(t, err) - defer srv.Close() - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - pool := &Pool{ - config: DefaultConfig(), - instances: make(map[string][]*Instance), - extServer: srv, - ctx: ctx, - cancel: cancel, - } - - available := int32(77) - minimum := int32(25) - pool.SetAccountModelCredits("acc-1", true, &available, &minimum) - - srv.mu.RLock() - info := srv.credits["acc-1"] - srv.mu.RUnlock() - - require.NotNil(t, info) - require.True(t, info.UseAICredits) - require.NotNil(t, info.AvailableCredits) - require.Equal(t, available, *info.AvailableCredits) - require.NotNil(t, info.MinimumCreditAmountForUsage) - require.Equal(t, minimum, *info.MinimumCreditAmountForUsage) -} - -// TestUpstreamAdapterExtractsCredentials verifies internal LS headers are extracted and stripped. -func TestUpstreamAdapterExtractsCredentials(t *testing.T) { - // Create a mock upstream that records what it receives - var receivedHeaders http.Header - var mu sync.Mutex - fallback := &recordingUpstreamWithCallback{} - fallback.onDo = func(req *http.Request) { - mu.Lock() - receivedHeaders = req.Header.Clone() - mu.Unlock() - } - - csrf := "test-csrf" - srv, err := NewMockExtensionServer(csrf) - require.NoError(t, err) - defer srv.Close() - - pool := &Pool{ - config: DefaultConfig(), - instances: make(map[string][]*Instance), - extServer: srv, - } - - upstream := NewLSPoolUpstream(pool, fallback) - - // Non-streamGenerateContent request → should pass through to fallback - req, _ := http.NewRequest("POST", "https://example.com/v1beta/models/gemini:generateContent", nil) - req.Header.Set("Authorization", "Bearer ya29.test") - req.Header.Set("X-Antigravity-Refresh-Token", "1//secret-refresh") - req.Header.Set("X-Antigravity-Token-Expiry", "2026-03-29T19:00:00Z") - req.Header.Set(useAICreditsHeader, "true") - req.Header.Set(availableCreditsHeader, "42") - req.Header.Set(minimumCreditAmountHeader, "50") - - resp, err := upstream.Do(req, "", 1, 1) - require.NoError(t, err) - require.NotNil(t, resp) - - // Internal headers should never leak to the direct upstream. - mu.Lock() - require.Empty(t, receivedHeaders.Get("X-Antigravity-Refresh-Token")) - require.Empty(t, receivedHeaders.Get("X-Antigravity-Token-Expiry")) - require.Empty(t, receivedHeaders.Get(useAICreditsHeader)) - require.Empty(t, receivedHeaders.Get(availableCreditsHeader)) - require.Empty(t, receivedHeaders.Get(minimumCreditAmountHeader)) - mu.Unlock() - - srv.mu.RLock() - tokenInfo := srv.tokens["1"] - creditsInfo := srv.credits["1"] - srv.mu.RUnlock() - - require.NotNil(t, tokenInfo) - require.Equal(t, "ya29.test", tokenInfo.AccessToken) - require.NotNil(t, creditsInfo) - require.True(t, creditsInfo.UseAICredits) - require.NotNil(t, creditsInfo.AvailableCredits) - require.Equal(t, int32(42), *creditsInfo.AvailableCredits) - require.NotNil(t, creditsInfo.MinimumCreditAmountForUsage) - require.Equal(t, int32(50), *creditsInfo.MinimumCreditAmountForUsage) -} - -// TestExtractPromptAndModelMultiTurn verifies multi-turn prompt extraction -func TestExtractPromptAndModelMultiTurn(t *testing.T) { - body := `{ - "model": "claude-sonnet-4-6", - "request": { - "systemInstruction": {"parts": [{"text": "You are helpful"}]}, - "contents": [ - {"role": "user", "parts": [{"text": "Hello"}]}, - {"role": "model", "parts": [{"text": "Hi there!"}]}, - {"role": "user", "parts": [{"text": "How are you?"}]} - ] - } - }` - prompt, model := extractPromptAndModel([]byte(body)) - require.Equal(t, "claude-sonnet-4-6", model) - require.Contains(t, prompt, "You are helpful") - require.Contains(t, prompt, "Hello") - require.Contains(t, prompt, "Hi there!") - require.Contains(t, prompt, "How are you?") -} - -// TestExtractUsageFromTrajectory verifies token usage extraction -func TestExtractUsageFromTrajectory(t *testing.T) { - resp := `{ - "trajectory": { - "steps": [{ - "type": "CORTEX_STEP_TYPE_PLANNER_RESPONSE", - "status": "CORTEX_STEP_STATUS_DONE", - "plannerResponse": {"response": "OK"}, - "metadata": { - "modelUsage": { - "inputTokens": "150", - "outputTokens": "5" - } - } - }] - } - }` - usage := extractUsageFromTrajectory([]byte(resp)) - require.NotNil(t, usage) - require.Equal(t, 150, usage["promptTokenCount"]) - require.Equal(t, 5, usage["candidatesTokenCount"]) - require.Equal(t, 155, usage["totalTokenCount"]) -} - -// TestSSEChunkFormat verifies the Gemini SSE output format -func TestSSEChunkFormat(t *testing.T) { - chunk := buildGeminiSSEChunk("Hello world") - require.True(t, len(chunk) > 0) - require.Contains(t, chunk, "data: ") - require.Contains(t, chunk, `"text":"Hello world"`) - require.Contains(t, chunk, `"role":"model"`) - require.True(t, chunk[len(chunk)-2:] == "\n\n") - - // Verify it's valid JSON after stripping "data: " prefix - jsonStr := chunk[len("data: ") : len(chunk)-2] - var parsed map[string]any - err := json.Unmarshal([]byte(jsonStr), &parsed) - require.NoError(t, err) - response := parsed["response"].(map[string]any) - candidates := response["candidates"].([]any) - require.Len(t, candidates, 1) -} - -// TestSSEFinalChunkFormat verifies the final SSE chunk with usage -func TestSSEFinalChunkFormat(t *testing.T) { - usage := map[string]any{ - "promptTokenCount": 100, - "candidatesTokenCount": 50, - "totalTokenCount": 150, - } - chunk := buildGeminiSSEFinalChunk(usage) - require.Contains(t, chunk, "data: ") - require.Contains(t, chunk, `"finishReason":"STOP"`) - require.Contains(t, chunk, `"usageMetadata"`) -} - -func TestStreamCascadeResponsePollsImmediately(t *testing.T) { - var getCalls atomic.Int32 - - server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - require.Equal(t, "test-csrf", r.Header.Get("x-codeium-csrf-token")) - - if strings.HasSuffix(r.URL.Path, "/GetCascadeTrajectory") { - getCalls.Add(1) - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"status":"CASCADE_RUN_STATUS_IDLE","trajectory":{"steps":[{"type":"CORTEX_STEP_TYPE_PLANNER_RESPONSE","status":"CORTEX_STEP_STATUS_DONE","plannerResponse":{"response":"hello from ls"}}]}}`)) - return - } - - http.NotFound(w, r) - })) - defer server.Close() - - inst := &Instance{ - AccountID: "42", - CSRF: "test-csrf", - Address: strings.TrimPrefix(server.URL, "https://"), - client: server.Client(), - healthy: true, - lastUsed: time.Now(), - } - upstream := NewLSPoolUpstream(&Pool{}, &recordingUpstream{}) - - pr, pw := io.Pipe() - ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) - defer cancel() - - done := make(chan struct{}) - go func() { - upstream.streamCascadeResponse(ctx, inst, "cid-1", pw, nil, nil) - close(done) - }() - - body, err := io.ReadAll(pr) - require.NoError(t, err) - <-done - - require.GreaterOrEqual(t, getCalls.Load(), int32(1)) - require.Contains(t, string(body), "hello from ls") -} - -// TestRequestHasToolsEdgeCases verifies tool detection edge cases -func TestRequestHasToolsEdgeCases(t *testing.T) { - // null tools - require.False(t, requestHasTools([]byte(`{"contents":[],"tools":null}`))) - // tools with empty function declarations - require.True(t, requestHasTools([]byte(`{"contents":[],"tools":[{"functionDeclarations":[]}]}`))) - // deeply nested wrapped format - require.True(t, requestHasTools([]byte(`{"model":"m","project":"p","request":{"contents":[],"tools":[{"codeExecution":{}}]}}`))) -} - -func TestJSParityRouteReusesCascadeSession(t *testing.T) { - t.Setenv("ANTIGRAVITY_LS_STRATEGY", LSStrategyJSParity) - - var startCalls atomic.Int32 - var sendCalls atomic.Int32 - var getCalls atomic.Int32 - var sendBodiesMu sync.Mutex - var sendBodies []map[string]any - - server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - require.Equal(t, "test-csrf", r.Header.Get("x-codeium-csrf-token")) - - switch { - case strings.HasSuffix(r.URL.Path, "/StartCascade"): - startCalls.Add(1) - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"cascadeId":"cid-1"}`)) - case strings.HasSuffix(r.URL.Path, "/SendUserCascadeMessage"): - sendCalls.Add(1) - var payload map[string]any - err := json.NewDecoder(r.Body).Decode(&payload) - require.NoError(t, err) - sendBodiesMu.Lock() - sendBodies = append(sendBodies, payload) - sendBodiesMu.Unlock() - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"queued":false}`)) - case strings.HasSuffix(r.URL.Path, "/GetCascadeTrajectory"): - call := getCalls.Add(1) - w.Header().Set("Content-Type", "application/json") - text := "hello from ls" - if call > 1 { - text = "follow up from ls" - } - _, _ = w.Write([]byte(fmt.Sprintf(`{"status":"CASCADE_RUN_STATUS_IDLE","trajectory":{"steps":[{"type":"CORTEX_STEP_TYPE_PLANNER_RESPONSE","status":"CORTEX_STEP_STATUS_DONE","plannerResponse":{"response":"%s"}}]}}`, text))) - default: - http.NotFound(w, r) - } - })) - defer server.Close() - - inst := &Instance{ - AccountID: "42", - CSRF: "test-csrf", - Address: strings.TrimPrefix(server.URL, "https://"), - client: server.Client(), - healthy: true, - lastUsed: time.Now(), - } - inst.SetModelMappingReady(true) - pool := &Pool{ - config: Config{ReplicasPerAccount: 1}, - instances: map[string][]*Instance{"42": []*Instance{inst}}, - } - upstream := NewLSPoolUpstream(pool, &recordingUpstream{}) - - req1Body := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"session-a","contents":[{"role":"user","parts":[{"text":"hello"}]}]}}`) - req1, err := http.NewRequest(http.MethodPost, "https://example.com/v1internal:streamGenerateContent?alt=sse", bytes.NewReader(req1Body)) - require.NoError(t, err) - req1.Header.Set("Authorization", "Bearer downstream-a") - - resp1, err := upstream.Do(req1, "", 42, 1) - require.NoError(t, err) - body1, err := io.ReadAll(resp1.Body) - require.NoError(t, err) - require.Contains(t, string(body1), `"text":"hello from ls"`) - - req2Body := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"session-a","contents":[{"role":"user","parts":[{"text":"hello"}]},{"role":"model","parts":[{"text":"hello from ls"}]},{"role":"user","parts":[{"text":"follow up"}]}]}}`) - req2, err := http.NewRequest(http.MethodPost, "https://example.com/v1internal:streamGenerateContent?alt=sse", bytes.NewReader(req2Body)) - require.NoError(t, err) - req2.Header.Set("Authorization", "Bearer downstream-a") - - resp2, err := upstream.Do(req2, "", 42, 1) - require.NoError(t, err) - body2, err := io.ReadAll(resp2.Body) - require.NoError(t, err) - require.Contains(t, string(body2), `"text":"follow up from ls"`) - - require.Equal(t, int32(1), startCalls.Load(), "cascade should be reused for append-only transcript") - require.Equal(t, int32(2), sendCalls.Load()) - - sendBodiesMu.Lock() - require.Len(t, sendBodies, 2) - firstSend := sendBodies[0] - sendBodiesMu.Unlock() - - require.Equal(t, "cid-1", firstSend["cascadeId"]) - require.Equal(t, false, firstSend["blocking"]) - metadata, ok := firstSend["metadata"].(map[string]any) - require.True(t, ok) - require.Equal(t, "antigravity", metadata["ideName"]) - require.Equal(t, "1.107.0", metadata["ideVersion"]) - require.NotContains(t, firstSend, "clientType") - require.NotContains(t, firstSend, "messageOrigin") - cascadeConfig, ok := firstSend["cascadeConfig"].(map[string]any) - require.True(t, ok) - plannerConfig, ok := cascadeConfig["plannerConfig"].(map[string]any) - require.True(t, ok) - requestedModel, ok := plannerConfig["requestedModel"].(map[string]any) - require.True(t, ok) - require.NotEmpty(t, requestedModel["model"]) - require.Len(t, plannerConfig, 1) - require.Len(t, cascadeConfig, 1) -} - -func TestJSParityRouteFallsBackOnSystemInstructionDrift(t *testing.T) { - t.Setenv("ANTIGRAVITY_LS_STRATEGY", LSStrategyJSParity) - - var startCalls atomic.Int32 - var sendCalls atomic.Int32 - - server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - require.Equal(t, "test-csrf", r.Header.Get("x-codeium-csrf-token")) - - switch { - case strings.HasSuffix(r.URL.Path, "/StartCascade"): - startCalls.Add(1) - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"cascadeId":"cid-1"}`)) - case strings.HasSuffix(r.URL.Path, "/SendUserCascadeMessage"): - sendCalls.Add(1) - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"queued":false}`)) - case strings.HasSuffix(r.URL.Path, "/GetCascadeTrajectory"): - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"status":"CASCADE_RUN_STATUS_IDLE","trajectory":{"steps":[{"type":"CORTEX_STEP_TYPE_PLANNER_RESPONSE","status":"CORTEX_STEP_STATUS_DONE","plannerResponse":{"response":"hello from ls"}}]}}`)) - default: - http.NotFound(w, r) - } - })) - defer server.Close() - - inst := &Instance{ - AccountID: "42", - CSRF: "test-csrf", - Address: strings.TrimPrefix(server.URL, "https://"), - client: server.Client(), - healthy: true, - lastUsed: time.Now(), - } - inst.SetModelMappingReady(true) - fallback := &recordingUpstream{} - pool := &Pool{ - config: Config{ReplicasPerAccount: 1}, - instances: map[string][]*Instance{"42": []*Instance{inst}}, - } - upstream := NewLSPoolUpstream(pool, fallback) - - req1Body := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"session-a","systemInstruction":{"parts":[{"text":"You are helpful"}]},"contents":[{"role":"user","parts":[{"text":"hello"}]}]}}`) - req1, err := http.NewRequest(http.MethodPost, "https://example.com/v1internal:streamGenerateContent?alt=sse", bytes.NewReader(req1Body)) - require.NoError(t, err) - req1.Header.Set("Authorization", "Bearer downstream-a") - - resp1, err := upstream.Do(req1, "", 42, 1) - require.NoError(t, err) - body1, err := io.ReadAll(resp1.Body) - require.NoError(t, err) - require.Contains(t, string(body1), `"text":"hello from ls"`) - - req2Body := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"session-a","systemInstruction":{"parts":[{"text":"You are different"}]},"contents":[{"role":"user","parts":[{"text":"hello"}]},{"role":"model","parts":[{"text":"hello from ls"}]},{"role":"user","parts":[{"text":"follow up"}]}]}}`) - req2, err := http.NewRequest(http.MethodPost, "https://example.com/v1internal:streamGenerateContent?alt=sse", bytes.NewReader(req2Body)) - require.NoError(t, err) - req2.Header.Set("Authorization", "Bearer downstream-a") - - resp2, err := upstream.Do(req2, "", 42, 1) - require.NoError(t, err) - body2, err := io.ReadAll(resp2.Body) - require.NoError(t, err) - - require.Equal(t, "ok", string(body2)) - require.Equal(t, 1, fallback.doCalls) - require.Equal(t, int32(1), startCalls.Load()) - require.Equal(t, int32(1), sendCalls.Load()) -} - -func TestJSParityRouteErrorsWhenModelMappingPending(t *testing.T) { - t.Setenv("ANTIGRAVITY_LS_STRATEGY", LSStrategyJSParity) - - var startCalls atomic.Int32 - var sendCalls atomic.Int32 - - server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - switch { - case strings.HasSuffix(r.URL.Path, "/StartCascade"): - startCalls.Add(1) - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"cascadeId":"cid-1"}`)) - case strings.HasSuffix(r.URL.Path, "/SendUserCascadeMessage"): - sendCalls.Add(1) - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"queued":false}`)) - default: - http.NotFound(w, r) - } - })) - defer server.Close() - - inst := &Instance{ - AccountID: "42", - CSRF: "test-csrf", - Address: strings.TrimPrefix(server.URL, "https://"), - client: server.Client(), - healthy: true, - lastUsed: time.Now(), - } - - fallback := &recordingUpstream{} - pool := &Pool{ - config: Config{ReplicasPerAccount: 1}, - instances: map[string][]*Instance{"42": []*Instance{inst}}, - } - upstream := NewLSPoolUpstream(pool, fallback) - - reqBody := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"session-a","contents":[{"role":"user","parts":[{"text":"hello"}]}]}}`) - req, err := http.NewRequest(http.MethodPost, "https://example.com/v1internal:streamGenerateContent?alt=sse", bytes.NewReader(reqBody)) - require.NoError(t, err) - req.Header.Set("Authorization", "Bearer downstream-a") - - resp, err := upstream.Do(req, "", 42, 1) - require.Nil(t, resp) - require.ErrorIs(t, err, errLSModelMapPending) - require.Equal(t, int32(0), startCalls.Load()) - require.Equal(t, int32(0), sendCalls.Load()) - require.Equal(t, 0, fallback.doCalls) -} - -// recordingUpstreamWithCallback extends the base recordingUpstream with a callback -type recordingUpstreamWithCallback struct { - recordingUpstream - onDo func(req *http.Request) -} - -func (r *recordingUpstreamWithCallback) Do(req *http.Request, proxyURL string, accountID int64, c int) (*http.Response, error) { - if r.onDo != nil { - r.onDo(req) - } - return r.recordingUpstream.Do(req, proxyURL, accountID, c) -} diff --git a/backend/internal/pkg/lspool/mock_extension_server.go b/backend/internal/pkg/lspool/mock_extension_server.go deleted file mode 100644 index e026be86..00000000 --- a/backend/internal/pkg/lspool/mock_extension_server.go +++ /dev/null @@ -1,920 +0,0 @@ -// Package lspool provides a mock Extension Server that the LS binary connects -// to at startup. The real IDE's extension.js runs a ConnectRPC HTTP/1.1 server -// using connectNodeAdapter. We replicate that protocol here. -// -// Protocol details (from extension.js source): -// - Transport: HTTP/1.1 on 127.0.0.1 (no TLS) -// - Auth: x-codeium-csrf-token header on every request -// - Unary request Content-Type: application/proto (binary protobuf, no envelope) -// OR application/connect+proto (with 5-byte envelope) -// - Unary response Content-Type: application/proto (raw binary protobuf, no envelope) -// - Stream request Content-Type: application/connect+proto (with 5-byte envelope) -// - Stream response Content-Type: application/connect+proto (envelope-framed messages) -// -// The LS sends requests with content-type "application/connect+proto" for BOTH -// unary and streaming RPCs. ConnectRPC's content-type regex: -// -// /^application\/(connect\+)?(?:(json)(?:; ?charset=utf-?8)?|(proto))$/i -// -// If "connect+" prefix is present → stream mode; otherwise → unary mode. -// However the LS Go client uses the Connect protocol client which always sends -// "application/proto" for unary and "application/connect+proto" for streaming. -// -// We detect the RPC kind from the URL path and respond accordingly. -package lspool - -import ( - "encoding/base64" - "encoding/binary" - "fmt" - "io" - "log/slog" - "net" - "net/http" - "strings" - "sync" - "time" - - "google.golang.org/protobuf/proto" - "google.golang.org/protobuf/types/known/timestamppb" -) - -// ============================================================ -// Proto helpers — hand-encode minimal proto messages so we don't -// need to import the full generated proto package. -// ============================================================ - -// encodeProtoString writes a proto string field (wire type 2) to a byte slice. -func encodeProtoString(fieldNum int, val string) []byte { - tag := encodeVarint(uint64(fieldNum<<3 | 2)) - length := encodeVarint(uint64(len(val))) - out := make([]byte, 0, len(tag)+len(length)+len(val)) - out = append(out, tag...) - out = append(out, length...) - out = append(out, []byte(val)...) - return out -} - -// encodeProtoBytes writes a proto bytes/message field (wire type 2). -func encodeProtoBytes(fieldNum int, val []byte) []byte { - tag := encodeVarint(uint64(fieldNum<<3 | 2)) - length := encodeVarint(uint64(len(val))) - out := make([]byte, 0, len(tag)+len(length)+len(val)) - out = append(out, tag...) - out = append(out, length...) - out = append(out, val...) - return out -} - -// encodeProtoVarint writes a proto varint field (wire type 0). -func encodeProtoVarint(fieldNum int, val uint64) []byte { - tag := encodeVarint(uint64(fieldNum<<3 | 0)) - v := encodeVarint(val) - out := make([]byte, 0, len(tag)+len(v)) - out = append(out, tag...) - out = append(out, v...) - return out -} - -// encodeProtoBool writes a proto bool field. -func encodeProtoBool(fieldNum int, val bool) []byte { - v := uint64(0) - if val { - v = 1 - } - return encodeProtoVarint(fieldNum, v) -} - -func encodeVarint(v uint64) []byte { - buf := make([]byte, binary.MaxVarintLen64) - n := binary.PutUvarint(buf, v) - return buf[:n] -} - -// decodeProtoString extracts a string field from raw proto bytes. -func decodeProtoString(data []byte, targetField int) string { - i := 0 - for i < len(data) { - if i >= len(data) { - break - } - tag, n := binary.Uvarint(data[i:]) - if n <= 0 { - break - } - i += n - fieldNum := int(tag >> 3) - wireType := tag & 0x7 - - switch wireType { - case 0: // varint - _, n = binary.Uvarint(data[i:]) - if n <= 0 { - return "" - } - i += n - case 2: // length-delimited - length, n := binary.Uvarint(data[i:]) - if n <= 0 { - return "" - } - i += n - if fieldNum == targetField { - end := i + int(length) - if end > len(data) { - return "" - } - return string(data[i:end]) - } - i += int(length) - case 1: // 64-bit - i += 8 - case 5: // 32-bit - i += 4 - default: - return "" - } - } - return "" -} - -// ============================================================ -// ConnectRPC envelope helpers -// ============================================================ - -// connectEnvelope wraps a proto payload in a ConnectRPC streaming envelope: -// 1 byte flags + 4 byte big-endian length + payload -func connectEnvelope(flags byte, payload []byte) []byte { - frame := make([]byte, 5+len(payload)) - frame[0] = flags - binary.BigEndian.PutUint32(frame[1:5], uint32(len(payload))) - copy(frame[5:], payload) - return frame -} - -// connectEndOfStream returns the end-of-stream trailer frame for ConnectRPC. -// flags=0x02 signals end of stream. The payload is a JSON object with empty metadata. -func connectEndOfStream() []byte { - trailer := []byte("{}") - return connectEnvelope(0x02, trailer) -} - -// unwrapConnectEnvelope strips the 5-byte envelope header from a ConnectRPC message. -// Returns the raw proto payload. If the input is shorter than 5 bytes, returns as-is. -func unwrapConnectEnvelope(body []byte) []byte { - if len(body) < 5 { - return body - } - // Check if it looks like an envelope: first byte should be 0x00 or 0x01 - if body[0] > 0x02 { - return body // Not envelope-framed, return raw - } - plen := binary.BigEndian.Uint32(body[1:5]) - if int(plen)+5 > len(body) { - return body // Length mismatch, return raw - } - return body[5 : 5+plen] -} - -// ============================================================ -// OAuthTokenInfo proto builder -// ============================================================ - -// buildOAuthTokenInfoBinary creates binary-encoded OAuthTokenInfo proto. -// -// message OAuthTokenInfo { -// string access_token = 1; -// string token_type = 2; -// string refresh_token = 3; -// google.protobuf.Timestamp expiry = 4; -// bool is_gcp_tos = 6; -// } -func buildOAuthTokenInfoBinary(accessToken, refreshToken string, expiresAt time.Time) []byte { - var buf []byte - buf = append(buf, encodeProtoString(1, accessToken)...) - buf = append(buf, encodeProtoString(2, "Bearer")...) - if refreshToken != "" { - buf = append(buf, encodeProtoString(3, refreshToken)...) - } - // Use real expiry if provided, otherwise default to 1 hour from now - expiry := expiresAt - if expiry.IsZero() { - expiry = time.Now().Add(1 * time.Hour) - } - ts := ×tamppb.Timestamp{ - Seconds: expiry.Unix(), - } - tsBytes, _ := proto.Marshal(ts) - buf = append(buf, encodeProtoBytes(4, tsBytes)...) - buf = append(buf, encodeProtoBool(6, true)...) - return buf -} - -// buildUSSTopicWithOAuth creates a USS Topic proto with the OAuth token. -// -// message Topic { map data = 1; } -// message Row { string value = 1; int64 e_tag = 2; } -// -// The key in the map is "oauthTokenInfoSentinelKey" and the Row.value is -// base64(toBinary(OAuthTokenInfo)). -func buildUSSTopicWithOAuth(accessToken, refreshToken string, expiresAt time.Time) []byte { - tokenBin := buildOAuthTokenInfoBinary(accessToken, refreshToken, expiresAt) - tokenB64 := base64.StdEncoding.EncodeToString(tokenBin) - - // Row: value=tokenB64 (field 1), e_tag=1 (field 2) - var row []byte - row = append(row, encodeProtoString(1, tokenB64)...) - row = append(row, encodeProtoVarint(2, 1)...) - - // Map entry: key="oauthTokenInfoSentinelKey" (field 1), value=row (field 2) - var entry []byte - entry = append(entry, encodeProtoString(1, "oauthTokenInfoSentinelKey")...) - entry = append(entry, encodeProtoBytes(2, row)...) - - // Topic: data map entries use field 1 - var topic []byte - topic = append(topic, encodeProtoBytes(1, entry)...) - - return topic -} - -func buildPrimitiveBoolBinary(val bool) []byte { - // Primitive.bool_value is field 13 in the proto definition - return encodeProtoBool(13, val) -} - -func buildPrimitiveInt32Binary(val int32) []byte { - // Primitive.int32_value is field 3 in the proto definition - return encodeProtoVarint(3, uint64(uint32(val))) -} - -func encodeUSSBinaryValue(value []byte) string { - return base64.StdEncoding.EncodeToString(value) -} - -func encodeUSSPrimitiveBoolValue(val bool) string { - return encodeUSSBinaryValue(buildPrimitiveBoolBinary(val)) -} - -func encodeUSSPrimitiveInt32Value(val int32) string { - return encodeUSSBinaryValue(buildPrimitiveInt32Binary(val)) -} - -func buildUSSTopicRow(key string, value string) []byte { - row := buildUSSRowBinary(value) - - var entry []byte - entry = append(entry, encodeProtoString(1, key)...) - entry = append(entry, encodeProtoBytes(2, row)...) - return entry -} - -func buildUSSRowBinary(value string) []byte { - var row []byte - row = append(row, encodeProtoString(1, value)...) - row = append(row, encodeProtoVarint(2, 1)...) - return row -} - -func buildUSSTopicWithModelCredits(info *ModelCreditsInfo) []byte { - if info == nil { - info = &ModelCreditsInfo{} - } - - minimum := defaultMinimumCreditAmountForUsage - if info.MinimumCreditAmountForUsage != nil { - minimum = *info.MinimumCreditAmountForUsage - } - - entries := make([][]byte, 0, 3) - entries = append(entries, buildUSSTopicRow( - useAICreditsSentinelKey, - encodeUSSPrimitiveBoolValue(info.UseAICredits), - )) - // JS protocol: useAICreditsSentinelKey carries the toggle state. - // availableCreditsSentinelKey is only present when credits are enabled. - if info.UseAICredits { - credits := int32(9999) - if info.AvailableCredits != nil { - credits = *info.AvailableCredits - } - entries = append(entries, buildUSSTopicRow(availableCreditsSentinelKey, encodeUSSPrimitiveInt32Value(credits))) - } - entries = append(entries, buildUSSTopicRow(minimumCreditAmountForUsageKey, encodeUSSPrimitiveInt32Value(minimum))) - - var topic []byte - for _, entry := range entries { - topic = append(topic, encodeProtoBytes(1, entry)...) - } - return topic -} - -// buildEmptyTopic returns an empty USS Topic proto (for non-oauth topics). -func buildEmptyTopic() []byte { - return []byte{} // Empty message = no map entries -} - -// ============================================================ -// UnifiedStateSyncUpdate builder -// ============================================================ - -// buildInitialStateUpdate creates a UnifiedStateSyncUpdate with initial_state set. -// -// message UnifiedStateSyncUpdate { -// oneof update_type { -// Topic initial_state = 1; -// AppliedUpdate applied_update = 2; -// } -// } -func buildInitialStateUpdate(topicData []byte) []byte { - return encodeProtoBytes(1, topicData) -} - -func buildAppliedUpdate(key string, row []byte) []byte { - var applied []byte - applied = append(applied, encodeProtoString(1, key)...) - if len(row) > 0 { - applied = append(applied, encodeProtoBytes(2, row)...) - } - return encodeProtoBytes(2, applied) -} - -// ============================================================ -// MockExtensionServer -// ============================================================ - -// MockExtensionServer provides a ConnectRPC-compatible HTTP server that the -// Language Server binary connects to. It implements just enough of the -// ExtensionServerService to keep the LS operational. -type MockExtensionServer struct { - listener net.Listener - server *http.Server - port int - csrf string - mu sync.RWMutex - tokens map[string]*TokenInfo // account_id -> token info - credits map[string]*ModelCreditsInfo // account_id -> model credits info - subscribers map[string]map[int]*stateSubscriber - nextSubID int - lastAccountID string - logger *slog.Logger - - // Trajectory callback — when LS pushes trajectory updates, we forward them - onTrajectoryUpdate func(topic, key string, data []byte) -} - -// TokenInfo holds OAuth token details for an account. -type TokenInfo struct { - AccessToken string - RefreshToken string - ExpiresAt time.Time // zero value means unknown; defaults to now+1h -} - -// ModelCreditsInfo mirrors the JS uss-modelCredits topic state. -type ModelCreditsInfo struct { - UseAICredits bool - AvailableCredits *int32 - MinimumCreditAmountForUsage *int32 -} - -type stateSubscriber struct { - id int - accountID string - topic string - updates chan []byte -} - -const ( - useAICreditsSentinelKey = "useAICreditsSentinelKey" - availableCreditsSentinelKey = "availableCreditsSentinelKey" - minimumCreditAmountForUsageKey = "minimumCreditAmountForUsageKey" - defaultMinimumCreditAmountForUsage = int32(50) -) - -// NewMockExtensionServer creates a mock extension server with proper ConnectRPC handling. -func NewMockExtensionServer(csrf string) (*MockExtensionServer, error) { - listener, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - return nil, fmt.Errorf("listen: %w", err) - } - - m := &MockExtensionServer{ - listener: listener, - port: listener.Addr().(*net.TCPAddr).Port, - csrf: csrf, - tokens: make(map[string]*TokenInfo), - credits: make(map[string]*ModelCreditsInfo), - subscribers: make(map[string]map[int]*stateSubscriber), - logger: slog.Default().With("component", "mock-ext-server"), - } - - mux := http.NewServeMux() - extService := "/exa.extension_server_pb.ExtensionServerService/" - - // Register all RPCs the LS calls on the Extension Server. - // Unary RPCs — return application/proto - mux.HandleFunc(extService+"LanguageServerStarted", m.handleUnary(m.onLanguageServerStarted)) - mux.HandleFunc(extService+"Heartbeat", m.handleUnary(m.onHeartbeat)) - mux.HandleFunc(extService+"GetSecretValue", m.handleUnary(m.onGetSecretValue)) - mux.HandleFunc(extService+"StoreSecretValue", m.handleUnary(m.onStoreSecretValue)) - mux.HandleFunc(extService+"IsAgentManagerEnabled", m.handleUnary(m.onIsAgentManagerEnabled)) - mux.HandleFunc(extService+"PushUnifiedStateSyncUpdate", m.handleUnary(m.onPushUnifiedStateSyncUpdate)) - mux.HandleFunc(extService+"RecordError", m.handleUnary(m.onRecordError)) - mux.HandleFunc(extService+"LogEvent", m.handleUnary(m.onLogEvent)) - mux.HandleFunc(extService+"UpdateCascadeTrajectorySummaries", m.handleUnary(m.onUpdateTrajectorySummaries)) - mux.HandleFunc(extService+"BroadcastConversationDeletion", m.handleUnary(m.onDefault)) - mux.HandleFunc(extService+"WriteCascadeEdit", m.handleUnary(m.onDefault)) - mux.HandleFunc(extService+"OpenDiffZones", m.handleUnary(m.onDefault)) - mux.HandleFunc(extService+"HandleAsyncPostMessage", m.handleUnary(m.onDefault)) - mux.HandleFunc(extService+"OpenFilePointer", m.handleUnary(m.onDefault)) - mux.HandleFunc(extService+"OpenVirtualFile", m.handleUnary(m.onDefault)) - mux.HandleFunc(extService+"SaveDocument", m.handleUnary(m.onDefault)) - mux.HandleFunc(extService+"RestartUserStatusUpdater", m.handleUnary(m.onDefault)) - mux.HandleFunc(extService+"FocusIDEWindow", m.handleUnary(m.onDefault)) - mux.HandleFunc(extService+"SmartFocusConversation", m.handleUnary(m.onDefault)) - mux.HandleFunc(extService+"RunExtensionCode", m.handleUnary(m.onDefault)) - mux.HandleFunc(extService+"UpdateDetailedViewWithCascadeInput", m.handleUnary(m.onDefault)) - mux.HandleFunc(extService+"FindAllReferences", m.handleUnary(m.onDefault)) - mux.HandleFunc(extService+"GetDefinition", m.handleUnary(m.onDefault)) - mux.HandleFunc(extService+"GetLintErrors", m.handleUnary(m.onDefault)) - - // Server-streaming RPCs — return application/connect+proto - mux.HandleFunc(extService+"SubscribeToUnifiedStateSyncTopic", m.handleStream(m.onSubscribeStateSyncTopic)) - mux.HandleFunc(extService+"ExecuteCommand", m.handleStream(m.onExecuteCommand)) - - // Catch-all for any unregistered RPCs - mux.HandleFunc("/", m.handleCatchAll) - - m.server = &http.Server{Handler: mux} - - go func() { - if err := m.server.Serve(listener); err != http.ErrServerClosed { - m.logger.Error("extension server error", "err", err) - } - }() - - m.logger.Info("mock extension server started", "port", m.port, "csrf_len", len(csrf)) - return m, nil -} - -// Port returns the listening port. -func (m *MockExtensionServer) Port() int { - return m.port -} - -// SetToken sets the OAuth token for an account. -func (m *MockExtensionServer) SetToken(accountID string, info *TokenInfo) { - m.mu.Lock() - m.tokens[accountID] = info - m.lastAccountID = accountID - subscribers := m.snapshotSubscribersLocked("uss-oauth", accountID) - m.mu.Unlock() - - if info == nil { - return - } - tokenBin := buildOAuthTokenInfoBinary(info.AccessToken, info.RefreshToken, info.ExpiresAt) - tokenB64 := base64.StdEncoding.EncodeToString(tokenBin) - m.publishTopicUpdate(subscribers, buildAppliedUpdate("oauthTokenInfoSentinelKey", buildUSSRowBinary(tokenB64))) -} - -// SetModelCredits sets the uss-modelCredits state for an account. -func (m *MockExtensionServer) SetModelCredits(accountID string, info *ModelCreditsInfo) { - if info == nil { - info = &ModelCreditsInfo{} - } - copyInfo := *info - m.mu.Lock() - m.credits[accountID] = ©Info - m.lastAccountID = accountID - subscribers := m.snapshotSubscribersLocked("uss-modelCredits", accountID) - m.mu.Unlock() - - m.publishTopicUpdate(subscribers, buildModelCreditsAppliedUpdates(©Info)...) -} - -// SetTrajectoryCallback registers a callback for when the LS pushes trajectory data. -func (m *MockExtensionServer) SetTrajectoryCallback(fn func(topic, key string, data []byte)) { - m.onTrajectoryUpdate = fn -} - -func (m *MockExtensionServer) currentTokenLocked() *TokenInfo { - if m.lastAccountID != "" { - if info := m.tokens[m.lastAccountID]; info != nil { - return info - } - } - for _, info := range m.tokens { - return info - } - return nil -} - -func (m *MockExtensionServer) currentModelCreditsLocked() *ModelCreditsInfo { - if m.lastAccountID != "" { - if info := m.credits[m.lastAccountID]; info != nil { - return info - } - } - for _, info := range m.credits { - return info - } - return nil -} - -func (m *MockExtensionServer) tokenForAccountLocked(accountID string) *TokenInfo { - if accountID != "" { - if info := m.tokens[accountID]; info != nil { - return info - } - } - return m.currentTokenLocked() -} - -func (m *MockExtensionServer) creditsForAccountLocked(accountID string) *ModelCreditsInfo { - if accountID != "" { - if info := m.credits[accountID]; info != nil { - return info - } - } - return m.currentModelCreditsLocked() -} - -func (m *MockExtensionServer) snapshotSubscribersLocked(topic, accountID string) []*stateSubscriber { - topicSubs := m.subscribers[topic] - if len(topicSubs) == 0 { - return nil - } - out := make([]*stateSubscriber, 0, len(topicSubs)) - for _, sub := range topicSubs { - if sub == nil { - continue - } - if accountID != "" && sub.accountID != "" && sub.accountID != accountID { - continue - } - out = append(out, sub) - } - return out -} - -func (m *MockExtensionServer) publishTopicUpdate(subscribers []*stateSubscriber, updates ...[]byte) { - for _, sub := range subscribers { - if sub == nil { - continue - } - for _, update := range updates { - if len(update) == 0 { - continue - } - payload := append([]byte(nil), update...) - select { - case sub.updates <- payload: - default: - m.logger.Warn("dropping USS update", "topic", sub.topic, "account", sub.accountID) - } - } - } -} - -func buildModelCreditsAppliedUpdates(info *ModelCreditsInfo) [][]byte { - if info == nil { - info = &ModelCreditsInfo{} - } - minimum := defaultMinimumCreditAmountForUsage - if info.MinimumCreditAmountForUsage != nil { - minimum = *info.MinimumCreditAmountForUsage - } - - updates := make([][]byte, 0, 3) - updates = append(updates, buildAppliedUpdate( - useAICreditsSentinelKey, - buildUSSRowBinary(encodeUSSPrimitiveBoolValue(info.UseAICredits)), - )) - - if info.UseAICredits { - credits := int32(9999) - if info.AvailableCredits != nil { - credits = *info.AvailableCredits - } - updates = append(updates, buildAppliedUpdate( - availableCreditsSentinelKey, - buildUSSRowBinary(encodeUSSPrimitiveInt32Value(credits)), - )) - } else { - updates = append(updates, buildAppliedUpdate(availableCreditsSentinelKey, nil)) - } - updates = append(updates, buildAppliedUpdate( - minimumCreditAmountForUsageKey, - buildUSSRowBinary(encodeUSSPrimitiveInt32Value(minimum)), - )) - - return updates -} - -// Close shuts down the server. -func (m *MockExtensionServer) Close() error { - return m.server.Close() -} - -// ============================================================ -// Middleware -// ============================================================ - -type unaryHandler func(body []byte) []byte -type streamHandler func(body []byte, w http.ResponseWriter, r *http.Request) - -// handleUnary wraps a unary RPC handler with CSRF check and proper content-type. -func (m *MockExtensionServer) handleUnary(handler unaryHandler) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - // CSRF check - if !m.checkCSRF(w, r) { - return - } - - body, err := io.ReadAll(r.Body) - if err != nil { - m.logger.Error("read body", "err", err, "path", r.URL.Path) - w.Header().Set("Content-Type", "application/proto") - w.WriteHeader(200) - return - } - - // The LS might send with envelope framing (application/connect+proto) - // or without (application/proto). Detect and unwrap. - ct := r.Header.Get("Content-Type") - protoBody := body - if strings.Contains(ct, "connect+proto") && len(body) >= 5 { - protoBody = unwrapConnectEnvelope(body) - } - - m.logger.Debug("unary RPC", "path", r.URL.Path, "body_len", len(protoBody), "content_type", ct) - - responseProto := handler(protoBody) - - // Respond with proper unary ConnectRPC content-type. - // If the request used "connect+proto", the response should be "application/proto" - // for unary RPCs (ConnectRPC spec: unary uses application/proto, not connect+proto). - w.Header().Set("Content-Type", "application/proto") - w.WriteHeader(200) - if len(responseProto) > 0 { - w.Write(responseProto) - } - } -} - -// handleStream wraps a server-streaming RPC handler with CSRF and content-type. -func (m *MockExtensionServer) handleStream(handler streamHandler) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - if !m.checkCSRF(w, r) { - return - } - - body, err := io.ReadAll(r.Body) - if err != nil { - m.logger.Error("read body", "err", err, "path", r.URL.Path) - return - } - - // Unwrap envelope from request - ct := r.Header.Get("Content-Type") - if strings.Contains(ct, "connect+proto") || strings.Contains(ct, "connect+json") { - body = unwrapConnectEnvelope(body) - } - - m.logger.Debug("stream RPC", "path", r.URL.Path, "body_len", len(body)) - - // Set streaming response content-type - w.Header().Set("Content-Type", "application/connect+proto") - w.WriteHeader(200) - - handler(body, w, r) - } -} - -func (m *MockExtensionServer) checkCSRF(w http.ResponseWriter, r *http.Request) bool { - token := r.Header.Get("x-codeium-csrf-token") - if m.csrf != "" && token != m.csrf { - m.logger.Warn("CSRF mismatch", "path", r.URL.Path, "got", token[:min(8, len(token))]) - w.Header().Set("Content-Type", "text/plain") - w.WriteHeader(403) - w.Write([]byte("Invalid CSRF token")) - return false - } - return true -} - -func min(a, b int) int { - if a < b { - return a - } - return b -} - -// ============================================================ -// Unary RPC Handlers — each receives raw proto request body, -// returns raw proto response body. -// ============================================================ - -func (m *MockExtensionServer) onLanguageServerStarted(body []byte) []byte { - // LanguageServerStartedRequest has: https_port(1), http_port(2), lsp_port(3), csrf_token(4) - // We just log the ports — they're informational. - m.logger.Info("LanguageServerStarted", - "body_len", len(body)) - // Return empty LanguageServerStartedResponse - return nil -} - -func (m *MockExtensionServer) onHeartbeat(body []byte) []byte { - // Return empty HeartbeatResponse - return nil -} - -func (m *MockExtensionServer) onGetSecretValue(body []byte) []byte { - // GetSecretValueRequest: key = field 1 - key := decodeProtoString(body, 1) - m.logger.Debug("GetSecretValue", "key", key) - - m.mu.RLock() - var token string - if info := m.currentTokenLocked(); info != nil { - token = info.AccessToken - } - m.mu.RUnlock() - - // GetSecretValueResponse: value = field 1 - if token != "" { - return encodeProtoString(1, token) - } - return nil -} - -func (m *MockExtensionServer) onStoreSecretValue(body []byte) []byte { - key := decodeProtoString(body, 1) - m.logger.Debug("StoreSecretValue", "key", key) - return nil -} - -func (m *MockExtensionServer) onIsAgentManagerEnabled(body []byte) []byte { - // IsAgentManagerEnabledResponse: enabled = field 1 (bool) - return encodeProtoBool(1, false) -} - -func (m *MockExtensionServer) onPushUnifiedStateSyncUpdate(body []byte) []byte { - // PushUnifiedStateSyncUpdateRequest: update = field 1 (UpdateRequest message) - // UpdateRequest: topic_name = field 1, applied_update = field 5, key = field 2 - m.logger.Debug("PushUnifiedStateSyncUpdate", "body_len", len(body)) - - // Extract topic name from the embedded UpdateRequest - // The body is PushUnifiedStateSyncUpdateRequest, field 1 is the UpdateRequest - // We need to dig into the nested message to get topic_name - if m.onTrajectoryUpdate != nil { - // For now, just notify that an update was pushed - m.onTrajectoryUpdate("", "", body) - } - - // Return empty PushUnifiedStateSyncUpdateResponse - return nil -} - -func (m *MockExtensionServer) onRecordError(body []byte) []byte { - m.logger.Debug("RecordError", "body_len", len(body)) - return nil -} - -func (m *MockExtensionServer) onLogEvent(body []byte) []byte { - return nil -} - -func (m *MockExtensionServer) onUpdateTrajectorySummaries(body []byte) []byte { - m.logger.Debug("UpdateCascadeTrajectorySummaries", "body_len", len(body)) - return nil -} - -func (m *MockExtensionServer) onDefault(body []byte) []byte { - return nil -} - -// ============================================================ -// Streaming RPC Handlers -// ============================================================ - -func (m *MockExtensionServer) onSubscribeStateSyncTopic(body []byte, w http.ResponseWriter, r *http.Request) { - // SubscribeToUnifiedStateSyncTopicRequest: topic = field 1 - topic := decodeProtoString(body, 1) - m.logger.Info("SubscribeToUnifiedStateSyncTopic", "topic", topic) - - flusher, ok := w.(http.Flusher) - if !ok { - m.logger.Error("ResponseWriter does not support Flush") - return - } - - m.mu.Lock() - accountID := m.lastAccountID - subID := m.nextSubID - m.nextSubID++ - sub := &stateSubscriber{ - id: subID, - accountID: accountID, - topic: topic, - updates: make(chan []byte, 16), - } - if m.subscribers[topic] == nil { - m.subscribers[topic] = make(map[int]*stateSubscriber) - } - m.subscribers[topic][subID] = sub - - // Build initial state based on topic - var topicData []byte - switch topic { - case "uss-oauth": - tokenInfo := m.tokenForAccountLocked(accountID) - if tokenInfo != nil { - topicData = buildUSSTopicWithOAuth(tokenInfo.AccessToken, tokenInfo.RefreshToken, tokenInfo.ExpiresAt) - } else { - topicData = buildEmptyTopic() - } - case "uss-modelCredits": - creditsInfo := m.creditsForAccountLocked(accountID) - if creditsInfo != nil { - topicData = buildUSSTopicWithModelCredits(creditsInfo) - } else { - topicData = buildEmptyTopic() - } - default: - // For all other topics (browserPreferences, enterprisePreferences, etc.), - // return empty topic data. - topicData = buildEmptyTopic() - } - m.mu.Unlock() - defer func() { - m.mu.Lock() - if topicSubs := m.subscribers[topic]; topicSubs != nil { - delete(topicSubs, subID) - if len(topicSubs) == 0 { - delete(m.subscribers, topic) - } - } - m.mu.Unlock() - }() - - // Send initial state as envelope-framed message - initialUpdate := buildInitialStateUpdate(topicData) - frame := connectEnvelope(0x00, initialUpdate) - w.Write(frame) - flusher.Flush() - - for { - select { - case <-r.Context().Done(): - m.logger.Debug("SubscribeToUnifiedStateSyncTopic stream closed", "topic", topic) - return - case update := <-sub.updates: - if len(update) == 0 { - continue - } - if _, err := w.Write(connectEnvelope(0x00, update)); err != nil { - m.logger.Debug("SubscribeToUnifiedStateSyncTopic write failed", "topic", topic, "err", err) - return - } - flusher.Flush() - } - } -} - -func (m *MockExtensionServer) onExecuteCommand(body []byte, w http.ResponseWriter, r *http.Request) { - m.logger.Debug("ExecuteCommand (mock)", "body_len", len(body)) - // Send end-of-stream immediately — we don't execute commands - flusher, ok := w.(http.Flusher) - if !ok { - return - } - w.Write(connectEndOfStream()) - flusher.Flush() -} - -// ============================================================ -// Catch-all handler -// ============================================================ - -func (m *MockExtensionServer) handleCatchAll(w http.ResponseWriter, r *http.Request) { - if !m.checkCSRF(w, r) { - return - } - m.logger.Debug("unhandled RPC (returning empty proto)", "path", r.URL.Path, "method", r.Method) - - // Drain request body - io.ReadAll(r.Body) - - // Determine if this is likely a unary or streaming request based on content-type. - ct := r.Header.Get("Content-Type") - if strings.Contains(ct, "connect+") { - // Could be streaming — respond with unary proto to be safe - // (unary Connect requests can also use connect+ prefix in some client impls) - w.Header().Set("Content-Type", "application/proto") - } else { - w.Header().Set("Content-Type", "application/proto") - } - w.WriteHeader(200) -} diff --git a/backend/internal/pkg/lspool/pool.go b/backend/internal/pkg/lspool/pool.go deleted file mode 100644 index 209a95c2..00000000 --- a/backend/internal/pkg/lspool/pool.go +++ /dev/null @@ -1,1186 +0,0 @@ -// Package lspool manages a pool of AntiGravity Language Server instances. -// -// Each Google account gets its own LS instance. The LS binary is Google's own -// compiled Go binary, so all upstream TLS fingerprints, session behavior, -// and protocol patterns are 100% authentic — indistinguishable from real IDE. -// -// Architecture: -// -// sub2API Gateway → LS Pool → LS Instance (per account) → cloudcode-pa -// -// Communication protocol (from JS source analysis): -// -// sub2API → LS: ConnectRPC over HTTPS/2, binary proto, x-codeium-csrf-token header -// LS → ExtServer: ConnectRPC over HTTP/1.1, binary proto, x-codeium-csrf-token header -// -// Unary calls: Content-Type: application/proto (no envelope framing) -// Stream calls: Content-Type: application/connect+proto (envelope-framed) -// Envelope = 1 byte flags + 4 byte BE length + payload -// flags=0x02 means end-of-stream trailer -package lspool - -import ( - "bufio" - "bytes" - "context" - crand "crypto/rand" - "crypto/sha256" - "crypto/tls" - "crypto/x509" - "encoding/binary" - "encoding/json" - "fmt" - "io" - "log/slog" - "net/http" - "os" - "os/exec" - "path/filepath" - "regexp" - "runtime" - "strconv" - "strings" - "sync" - "sync/atomic" - "time" - - "golang.org/x/net/http2" - - "github.com/Wei-Shaw/sub2api/internal/util/logredact" -) - -// ============================================================ -// Configuration -// ============================================================ - -// Config for the LS pool -type Config struct { - // AppRoot is the path to AntiGravity.app resources - // e.g., "/Applications/AntiGravity.app/Contents/Resources/app" - AppRoot string - - // CloudCodeEndpoint overrides the default Cloud Code endpoint - CloudCodeEndpoint string - - // MaxIdleTime before shutting down an idle LS instance - MaxIdleTime time.Duration - - // HealthCheckInterval between health checks - HealthCheckInterval time.Duration - - // ReplicasPerAccount controls how many LS processes a single account can use. - ReplicasPerAccount int -} - -// DefaultConfig returns production defaults -func DefaultConfig() Config { - return Config{ - AppRoot: findAppRoot(), - CloudCodeEndpoint: "https://cloudcode-pa.googleapis.com", - MaxIdleTime: 30 * time.Minute, - HealthCheckInterval: 30 * time.Second, - ReplicasPerAccount: parseLSReplicaCount(), - } -} - -func parseLSReplicaCount() int { - raw := strings.TrimSpace(os.Getenv("ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT")) - if raw == "" { - return 5 - } - val, err := strconv.Atoi(raw) - if err != nil || val < 1 { - return 5 - } - return val -} - -func findAppRoot() string { - candidates := []string{ - "/Applications/AntiGravity.app/Contents/Resources/app", - "/Applications/Antigravity.app/Contents/Resources/app", - filepath.Join(os.Getenv("HOME"), ".local/share/antigravity/app"), - } - for _, c := range candidates { - if _, err := os.Stat(filepath.Join(c, "extensions", "antigravity", "bin")); err == nil { - return c - } - } - return candidates[0] -} - -// ============================================================ -// LS Instance -// ============================================================ - -// maxConcurrencyPerInstance limits how many concurrent Cascade calls a single -// LS instance handles. LS is designed for a single IDE user; beyond this -// threshold requests are rejected so the caller can fall back to direct HTTP. -const ( - maxConcurrencyPerInstance = 5 - lsStartupReadyTimeout = 6 * time.Second - lsStartupProbeInterval = 100 * time.Millisecond - lsStartupHeartbeatTimeout = 1 * time.Second -) - -// Instance represents a single Language Server process bound to one Google account -type Instance struct { - AccountID string - Email string - CSRF string - Replica int - Address string // e.g., "127.0.0.1:52444" - cmd *exec.Cmd - cleanup func() - client *http.Client - mu sync.RWMutex - healthy bool - lastUsed time.Time - startedAt time.Time - inflight int64 // atomic: current number of concurrent cascade calls - modelMapReady int32 - modelMapHard int32 - remote bool - workerToken string - routingKey string - modelMapError string -} - -// AcquireConcurrency atomically increments the inflight counter. -// Returns false if the instance is already at max capacity. -func (i *Instance) AcquireConcurrency() bool { - for { - cur := atomic.LoadInt64(&i.inflight) - if cur >= int64(maxConcurrencyPerInstance) { - return false - } - if atomic.CompareAndSwapInt64(&i.inflight, cur, cur+1) { - return true - } - } -} - -// ReleaseConcurrency decrements the inflight counter. -func (i *Instance) ReleaseConcurrency() { - atomic.AddInt64(&i.inflight, -1) -} - -// ConcurrentCount returns the current number of in-flight cascade calls. -func (i *Instance) ConcurrentCount() int64 { - return atomic.LoadInt64(&i.inflight) -} - -// SetModelMappingReady records whether this LS instance has successfully loaded -// its model config from the upstream service. -func (i *Instance) SetModelMappingReady(ready bool) { - if ready { - atomic.StoreInt32(&i.modelMapReady, 1) - return - } - atomic.StoreInt32(&i.modelMapReady, 0) -} - -// SetModelMappingUnavailable marks the instance as unable to load model config -// with the current token/client combination. -func (i *Instance) SetModelMappingUnavailable(reason string) { - atomic.StoreInt32(&i.modelMapHard, 1) - i.mu.Lock() - i.modelMapError = strings.TrimSpace(reason) - i.mu.Unlock() -} - -// ClearModelMappingUnavailable resets any previously recorded permanent model -// mapping failure state. -func (i *Instance) ClearModelMappingUnavailable() { - atomic.StoreInt32(&i.modelMapHard, 0) - i.mu.Lock() - i.modelMapError = "" - i.mu.Unlock() -} - -// HasModelMappingUnavailable reports whether model config loading is currently -// known to be incompatible with the account/token. -func (i *Instance) HasModelMappingUnavailable() bool { - return atomic.LoadInt32(&i.modelMapHard) == 1 -} - -// ModelMappingUnavailableReason returns the last recorded permanent failure -// reason, if any. -func (i *Instance) ModelMappingUnavailableReason() string { - i.mu.RLock() - defer i.mu.RUnlock() - return strings.TrimSpace(i.modelMapError) -} - -// HasModelMappingReady reports whether this LS instance has completed model -// config loading successfully. -func (i *Instance) HasModelMappingReady() bool { - return atomic.LoadInt32(&i.modelMapReady) == 1 -} - -// IsHealthy returns whether the instance is healthy -func (i *Instance) IsHealthy() bool { - i.mu.RLock() - defer i.mu.RUnlock() - return i.healthy -} - -// Touch marks the instance as recently used -func (i *Instance) Touch() { - i.mu.Lock() - i.lastUsed = time.Now() - i.mu.Unlock() -} - -// ============================================================ -// RPC Methods — uses ConnectRPC binary proto -// ============================================================ - -const ( - LSService = "exa.language_server_pb.LanguageServerService" -) - -// CallUnaryJSON makes a ConnectRPC JSON unary call (for convenience/debugging). -func (i *Instance) CallUnaryJSON(ctx context.Context, service, method string, input any) ([]byte, error) { - i.Touch() - - if i.remote { - body, err := marshalWorkerJSONBody(input) - if err != nil { - return nil, fmt.Errorf("marshal input: %w", err) - } - return i.callWorkerUnary(ctx, service, method, "json", body) - } - - url := fmt.Sprintf("https://%s/%s/%s", i.Address, service, method) - - var body []byte - if input != nil { - var err error - body, err = json.Marshal(input) - if err != nil { - return nil, fmt.Errorf("marshal input: %w", err) - } - } else { - body = []byte("{}") - } - - req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body)) - if err != nil { - return nil, err - } - - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Connect-Protocol-Version", "1") - req.Header.Set("x-codeium-csrf-token", i.CSRF) - - resp, err := i.client.Do(req) - if err != nil { - return nil, fmt.Errorf("rpc %s/%s: %w", service, method, err) - } - defer resp.Body.Close() - - respBody, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("read response: %w", err) - } - - if resp.StatusCode != 200 { - return respBody, fmt.Errorf("rpc %s/%s HTTP %d: %s", - service, method, resp.StatusCode, truncate(string(respBody), 200)) - } - - return respBody, nil -} - -// CallRPC makes a ConnectRPC binary proto unary call to the LS. -// Uses Content-Type: application/proto (Connect protocol unary). -func (i *Instance) CallRPC(ctx context.Context, service, method string, protoBody []byte) ([]byte, error) { - i.Touch() - - if i.remote { - return i.callWorkerUnary(ctx, service, method, "proto", protoBody) - } - - url := fmt.Sprintf("https://%s/%s/%s", i.Address, service, method) - - if protoBody == nil { - protoBody = []byte{} - } - - req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(protoBody)) - if err != nil { - return nil, err - } - - req.Header.Set("Content-Type", "application/proto") - req.Header.Set("Connect-Protocol-Version", "1") - req.Header.Set("x-codeium-csrf-token", i.CSRF) - - resp, err := i.client.Do(req) - if err != nil { - return nil, fmt.Errorf("rpc %s/%s: %w", service, method, err) - } - defer resp.Body.Close() - - respBody, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("read response: %w", err) - } - - if resp.StatusCode != 200 { - return respBody, fmt.Errorf("rpc %s/%s HTTP %d: %s", - service, method, resp.StatusCode, truncate(string(respBody), 200)) - } - - return respBody, nil -} - -// StreamRPC makes a server-streaming ConnectRPC call, returning the raw response. -// Uses Content-Type: application/connect+proto with envelope framing. -func (i *Instance) StreamRPC(ctx context.Context, service, method string, protoBody []byte) (*http.Response, error) { - i.Touch() - - if i.remote { - return i.callWorkerStream(ctx, service, method, "proto", protoBody) - } - - url := fmt.Sprintf("https://%s/%s/%s", i.Address, service, method) - - if protoBody == nil { - protoBody = []byte{} - } - - // Wrap in Connect envelope for streaming request - framedBody := frameConnectMessage(protoBody) - - req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(framedBody)) - if err != nil { - return nil, err - } - - req.Header.Set("Content-Type", "application/connect+proto") - req.Header.Set("Connect-Protocol-Version", "1") - req.Header.Set("Connect-Content-Encoding", "identity") - req.Header.Set("x-codeium-csrf-token", i.CSRF) - - return i.client.Do(req) -} - -// StreamRPCJSON makes a server-streaming ConnectRPC JSON call (for debugging). -func (i *Instance) StreamRPCJSON(ctx context.Context, service, method string, input any) (*http.Response, error) { - i.Touch() - - if i.remote { - body, err := marshalWorkerJSONBody(input) - if err != nil { - return nil, fmt.Errorf("marshal: %w", err) - } - return i.callWorkerStream(ctx, service, method, "json", body) - } - - url := fmt.Sprintf("https://%s/%s/%s", i.Address, service, method) - - var body []byte - if input != nil { - var err error - body, err = json.Marshal(input) - if err != nil { - return nil, fmt.Errorf("marshal: %w", err) - } - } else { - body = []byte("{}") - } - - // Wrap in Connect envelope for streaming request - body = frameConnectMessage(body) - - req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body)) - if err != nil { - return nil, err - } - - req.Header.Set("Content-Type", "application/connect+json") - req.Header.Set("Connect-Protocol-Version", "1") - req.Header.Set("Connect-Content-Encoding", "identity") - req.Header.Set("x-codeium-csrf-token", i.CSRF) - - return i.client.Do(req) -} - -func frameConnectMessage(payload []byte) []byte { - framed := make([]byte, 5+len(payload)) - binary.BigEndian.PutUint32(framed[1:5], uint32(len(payload))) - copy(framed[5:], payload) - return framed -} - -// Heartbeat sends a heartbeat to the LS -func (i *Instance) Heartbeat(ctx context.Context) error { - return i.HeartbeatWithTimeout(ctx, 15*time.Second) -} - -// HeartbeatWithTimeout sends a heartbeat to the LS with an explicit deadline. -func (i *Instance) HeartbeatWithTimeout(ctx context.Context, timeout time.Duration) error { - if timeout <= 0 { - timeout = 15 * time.Second - } - ctx, cancel := context.WithTimeout(ctx, timeout) - defer cancel() - _, err := i.CallRPC(ctx, LSService, "Heartbeat", nil) - return err -} - -func waitForInstanceReady( - ctx context.Context, - probeInterval time.Duration, - heartbeat func(context.Context) error, -) (int, error) { - if probeInterval <= 0 { - probeInterval = lsStartupProbeInterval - } - if heartbeat == nil { - return 0, fmt.Errorf("heartbeat func is nil") - } - - timer := time.NewTimer(0) - defer timer.Stop() - - attempts := 0 - var lastErr error - - for { - select { - case <-ctx.Done(): - if lastErr == nil { - lastErr = ctx.Err() - } - return attempts, lastErr - case <-timer.C: - } - - attempts++ - if err := heartbeat(ctx); err == nil { - return attempts, nil - } else { - lastErr = err - } - - timer.Reset(probeInterval) - } -} - -// ============================================================ -// Pool Manager -// ============================================================ - -// Pool manages multiple LS instances, with sticky session routing per account. -type Pool struct { - config Config - instances map[string][]*Instance // accountID -> replica slot -> instance - extServer *MockExtensionServer // shared mock extension server - mu sync.RWMutex - ctx context.Context - cancel context.CancelFunc - logger *slog.Logger -} - -// NewPool creates a new LS pool with lifecycle management -func NewPool(config Config) *Pool { - ctx, cancel := context.WithCancel(context.Background()) - - // Generate a shared CSRF token for communication between LS and ext server - csrf := generateUUID() - - // Start the mock extension server - extServer, err := NewMockExtensionServer(csrf) - if err != nil { - slog.Error("failed to start mock extension server", "err", err) - } - - p := &Pool{ - config: config, - instances: make(map[string][]*Instance), - extServer: extServer, - ctx: ctx, - cancel: cancel, - logger: slog.Default().With("component", "lspool"), - } - go p.lifecycleLoop() - return p -} - -func (p *Pool) replicaCount() int { - if p.config.ReplicasPerAccount < 1 { - return 1 - } - return p.config.ReplicasPerAccount -} - -func (p *Pool) ensureLogger() { - if p.logger == nil { - p.logger = slog.Default().With("component", "lspool") - } -} - -func replicaSlotIndex(routingKey string, replicaCount int) int { - if replicaCount <= 1 || strings.TrimSpace(routingKey) == "" { - return 0 - } - sum := sha256.Sum256([]byte(routingKey)) - return int(binary.BigEndian.Uint32(sum[:4]) % uint32(replicaCount)) -} - -func chooseLeastBusyHealthy(instances []*Instance) *Instance { - var best *Instance - for _, inst := range instances { - if inst == nil || !inst.IsHealthy() { - continue - } - if best == nil || inst.ConcurrentCount() < best.ConcurrentCount() { - best = inst - } - } - return best -} - -func (p *Pool) ensureReplicaSlotsLocked(accountID string) []*Instance { - slots := p.instances[accountID] - required := p.replicaCount() - if len(slots) < required { - expanded := make([]*Instance, required) - copy(expanded, slots) - slots = expanded - p.instances[accountID] = slots - } - return slots -} - -// Get returns an existing healthy LS instance for the account and routing key, or nil. -func (p *Pool) Get(accountID, routingKey string) *Instance { - p.mu.RLock() - defer p.mu.RUnlock() - - instances := p.instances[accountID] - if len(instances) == 0 { - return nil - } - if strings.TrimSpace(routingKey) != "" { - slot := replicaSlotIndex(routingKey, p.replicaCount()) - if slot < len(instances) { - inst := instances[slot] - if inst != nil && inst.IsHealthy() { - inst.Touch() - return inst - } - } - return nil - } - if inst := chooseLeastBusyHealthy(instances); inst != nil { - inst.Touch() - return inst - } - return nil -} - -// GetOrCreate returns an existing LS or starts a new one. -// proxyURL is passed to the LS process as HTTPS_PROXY for Google API connectivity. -func (p *Pool) GetOrCreate(accountID, routingKey string, proxyURL ...string) (*Instance, error) { - p.ensureLogger() - if inst := p.Get(accountID, routingKey); inst != nil { - return inst, nil - } - - p.mu.Lock() - defer p.mu.Unlock() - - slots := p.ensureReplicaSlotsLocked(accountID) - slot := replicaSlotIndex(routingKey, p.replicaCount()) - - if strings.TrimSpace(routingKey) == "" { - if inst := chooseLeastBusyHealthy(slots); inst != nil { - inst.Touch() - return inst, nil - } - for idx, inst := range slots { - if inst == nil { - slot = idx - break - } - } - } - - if slot < len(slots) { - if inst := slots[slot]; inst != nil && inst.IsHealthy() { - inst.Touch() - return inst, nil - } - } - - proxy := "" - if len(proxyURL) > 0 { - proxy = proxyURL[0] - } - - if slot < len(slots) && slots[slot] != nil { - p.stopInstance(slots[slot]) - slots[slot] = nil - } - - inst, err := p.startInstance(accountID, proxy, slot) - if err != nil { - return nil, err - } - - slots[slot] = inst - p.logger.Info("LS instance created", - "account", shortAccountID(accountID), - "replica", slot, - "address", inst.Address, - "pid", inst.cmd.Process.Pid) - return inst, nil -} - -// Remove stops and removes all LS instances for an account. -func (p *Pool) Remove(accountID string) { - p.mu.Lock() - defer p.mu.Unlock() - - if slots, ok := p.instances[accountID]; ok { - for _, inst := range slots { - if inst == nil { - continue - } - p.stopInstance(inst) - } - delete(p.instances, accountID) - } -} - -// SetAccountToken updates the OAuth token for an account in the mock extension server -func (p *Pool) SetAccountToken(accountID, accessToken, refreshToken string, expiresAt time.Time) { - if p.extServer != nil { - p.extServer.SetToken(accountID, &TokenInfo{ - AccessToken: accessToken, - RefreshToken: refreshToken, - ExpiresAt: expiresAt, - }) - } - p.mu.RLock() - slots := append([]*Instance(nil), p.instances[accountID]...) - p.mu.RUnlock() - for _, inst := range slots { - if inst == nil { - continue - } - inst.SetModelMappingReady(false) - inst.ClearModelMappingUnavailable() - } -} - -// SetAccountModelCredits updates the JS-parity uss-modelCredits state for an account. -func (p *Pool) SetAccountModelCredits(accountID string, useAICredits bool, availableCredits, minimumCreditAmountForUsage *int32) { - if p.extServer != nil { - p.extServer.SetModelCredits(accountID, &ModelCreditsInfo{ - UseAICredits: useAICredits, - AvailableCredits: availableCredits, - MinimumCreditAmountForUsage: minimumCreditAmountForUsage, - }) - } -} - -// Stats returns pool statistics -func (p *Pool) Stats() map[string]any { - p.mu.RLock() - defer p.mu.RUnlock() - - active := 0 - total := 0 - for _, slots := range p.instances { - for _, inst := range slots { - if inst == nil { - continue - } - total++ - if inst.IsHealthy() { - active++ - } - } - } - return map[string]any{ - "accounts": len(p.instances), - "total": total, - "active": active, - } -} - -// Close shuts down all instances and the extension server -func (p *Pool) Close() { - p.ensureLogger() - p.cancel() - p.mu.Lock() - defer p.mu.Unlock() - - for id, slots := range p.instances { - for _, inst := range slots { - if inst == nil { - continue - } - p.logger.Info("shutting down LS", "account", shortAccountID(id), "replica", inst.Replica) - p.stopInstance(inst) - } - } - p.instances = make(map[string][]*Instance) - - if p.extServer != nil { - p.extServer.Close() - } -} - -// ============================================================ -// Instance Lifecycle -// ============================================================ - -var portRe = regexp.MustCompile(`at (\d+) for HTTPS`) - -func (p *Pool) startInstance(accountID string, proxyURL string, replica int) (*Instance, error) { - binPath := filepath.Join(p.config.AppRoot, "extensions", "antigravity", "bin", lsBinaryName()) - if _, err := os.Stat(binPath); err != nil { - return nil, fmt.Errorf("LS binary not found: %s", binPath) - } - - // Each LS instance gets its own CSRF token (like the real IDE) - csrf := generateUUID() - appDataDir := fmt.Sprintf("antigravity-pool-%s-r%d", shortAccountID(accountID), replica) - - args := []string{ - "--csrf_token", csrf, - "--app_data_dir", appDataDir, - "--https_server_port", "0", - } - if p.config.CloudCodeEndpoint != "" { - args = append(args, "--cloud_code_endpoint", p.config.CloudCodeEndpoint) - } - // Connect LS to our mock extension server for token injection. - // The extension server uses a shared CSRF token (set at server creation). - if p.extServer != nil { - args = append(args, - "--extension_server_port", fmt.Sprintf("%d", p.extServer.Port()), - "--extension_server_csrf_token", p.extServer.csrf, - ) - } - - rawProxyURL := resolveLSProxy(proxyURL) - launchPlan, err := prepareLSLaunchPlan(binPath, args, rawProxyURL) - if err != nil { - return nil, fmt.Errorf("prepare LS launch: %w", err) - } - - cmd := launchPlan.cmd - cmd.Env = buildLSEnv(os.Environ(), p.config.AppRoot, launchPlan.effectiveProxyURL) - p.logger.Info("LS starting", - "account", shortAccountID(accountID), - "replica", replica, - "proxy_source", logredact.RedactProxyURL(rawProxyURL), - "proxy_mode", launchPlan.proxyMode, - "effective_proxy", logredact.RedactProxyURL(launchPlan.effectiveProxyURL)) - - stdin, err := cmd.StdinPipe() - if err != nil { - return nil, fmt.Errorf("stdin pipe: %w", err) - } - stderr, err := cmd.StderrPipe() - if err != nil { - return nil, fmt.Errorf("stderr pipe: %w", err) - } - - if err := cmd.Start(); err != nil { - if launchPlan.cleanup != nil { - launchPlan.cleanup() - } - return nil, fmt.Errorf("start LS: %w", err) - } - - // Write metadata proto to stdin (with access token if available) - accessToken := "" - if p.extServer != nil { - p.extServer.mu.RLock() - if info := p.extServer.tokens[accountID]; info != nil { - accessToken = info.AccessToken - } - p.extServer.mu.RUnlock() - } - metaBytes := buildMetadataBytes(accessToken) - p.logger.Info("writing metadata to LS stdin", - "account", shortAccountID(accountID), - "replica", replica, - "meta_len", len(metaBytes), - "has_token", accessToken != "", - "hex_prefix", fmt.Sprintf("%x", metaBytes[:min(40, len(metaBytes))])) - stdin.Write(metaBytes) - stdin.Close() - - // Parse HTTPS port from stderr AND log all LS output - portCh := make(chan string, 1) - go func() { - scanner := bufio.NewScanner(stderr) - for scanner.Scan() { - line := scanner.Text() - p.logger.Warn("LS stderr", "account", shortAccountID(accountID), "replica", replica, "line", line) - if matches := portRe.FindStringSubmatch(line); len(matches) > 1 { - portCh <- matches[1] - } - } - p.logger.Warn("LS stderr EOF", "account", shortAccountID(accountID), "replica", replica) - }() - - var address string - select { - case port := <-portCh: - address = "127.0.0.1:" + port - case <-time.After(15 * time.Second): - cmd.Process.Kill() - return nil, fmt.Errorf("timeout: LS did not report HTTPS port within 15s") - } - - // Create HTTPS client with LS self-signed cert - httpClient, err := createHTTPClient(p.config.AppRoot, csrf) - if err != nil { - cmd.Process.Kill() - return nil, fmt.Errorf("create http client: %w", err) - } - - inst := &Instance{ - AccountID: accountID, - CSRF: csrf, - Replica: replica, - Address: address, - cmd: cmd, - cleanup: launchPlan.cleanup, - client: httpClient, - healthy: false, - lastUsed: time.Now(), - startedAt: time.Now(), - } - - // Real IDE waits for LanguageServerStarted callback from ExtServer (timeout 60s). - // Our MockExtServer already received this callback during port detection - // (LS calls LanguageServerStarted right after binding HTTPS port). - // Probe readiness immediately and keep retrying with a short interval instead - // of sleeping a fixed 3-6 seconds on every cold start. - p.logger.Info("waiting for LS readiness", "account", shortAccountID(accountID), "replica", replica, "address", address) - readyStartedAt := time.Now() - readyCtx, cancel := context.WithTimeout(context.Background(), lsStartupReadyTimeout) - attempts, err := waitForInstanceReady(readyCtx, lsStartupProbeInterval, func(callCtx context.Context) error { - return inst.HeartbeatWithTimeout(callCtx, lsStartupHeartbeatTimeout) - }) - cancel() - if err != nil { - cmd.Process.Kill() - return nil, fmt.Errorf("LS not ready after startup: %w", err) - } - inst.mu.Lock() - inst.healthy = true - inst.mu.Unlock() - p.logger.Info("LS ready", - "account", shortAccountID(accountID), - "replica", replica, - "attempts", attempts, - "waited", time.Since(readyStartedAt).Truncate(time.Millisecond)) - - // Refresh model mapping from LS (async with retries — don't block startup) - go func() { - for attempt := 1; attempt <= 5; attempt++ { - if RefreshModelMapping(inst) { - p.logger.Info("model mapping loaded", "account", shortAccountID(accountID), "replica", replica, "attempt", attempt) - return - } - if inst.HasModelMappingUnavailable() { - p.logger.Warn("model mapping unavailable", - "account", shortAccountID(accountID), - "replica", replica, - "attempt", attempt, - "reason", truncate(inst.ModelMappingUnavailableReason(), 200)) - return - } - p.logger.Warn("model mapping not loaded, retrying", "account", shortAccountID(accountID), "replica", replica, "attempt", attempt) - time.Sleep(time.Duration(attempt*10) * time.Second) - } - }() - - return inst, nil -} - -func (p *Pool) stopInstance(inst *Instance) { - if inst.cmd != nil && inst.cmd.Process != nil { - inst.cmd.Process.Kill() - inst.cmd.Wait() - } - if inst.cleanup != nil { - inst.cleanup() - } - inst.mu.Lock() - inst.healthy = false - inst.mu.Unlock() -} - -func (p *Pool) lifecycleLoop() { - ticker := time.NewTicker(p.config.HealthCheckInterval) - defer ticker.Stop() - - for { - select { - case <-p.ctx.Done(): - return - case <-ticker.C: - p.doHealthCheck() - } - } -} - -func (p *Pool) doHealthCheck() { - p.mu.Lock() - defer p.mu.Unlock() - - for id, slots := range p.instances { - for replica, inst := range slots { - if inst == nil { - continue - } - - // Check process alive - if inst.cmd.ProcessState != nil { - p.logger.Warn("LS process exited", "account", shortAccountID(id), "replica", replica) - slots[replica] = nil - continue - } - - // Check idle timeout - inst.mu.RLock() - idle := time.Since(inst.lastUsed) - inst.mu.RUnlock() - - if idle > p.config.MaxIdleTime { - p.logger.Info("LS idle timeout", "account", shortAccountID(id), "replica", replica, "idle", idle) - p.stopInstance(inst) - slots[replica] = nil - continue - } - - // Heartbeat check - if err := inst.Heartbeat(p.ctx); err != nil { - p.logger.Warn("heartbeat failed", "account", shortAccountID(id), "replica", replica, "err", err) - inst.mu.Lock() - inst.healthy = false - inst.mu.Unlock() - } else { - inst.mu.Lock() - inst.healthy = true - inst.mu.Unlock() - } - } - - allNil := true - for _, inst := range slots { - if inst != nil { - allNil = false - break - } - } - if allNil { - delete(p.instances, id) - } - } -} - -// ============================================================ -// Helpers -// ============================================================ - -var ( - defaultLSCertFileCandidates = []string{ - "/etc/ssl/certs/ca-certificates.crt", - "/etc/pki/tls/certs/ca-bundle.crt", - "/etc/ssl/cert.pem", - } - defaultLSCertDirCandidates = []string{ - "/etc/ssl/certs", - "/etc/pki/tls/certs", - } -) - -func buildLSEnv(baseEnv []string, appRoot string, proxyURL string) []string { - env := append([]string(nil), baseEnv...) - env = setEnvValue(env, "ANTIGRAVITY_EDITOR_APP_ROOT", appRoot) - - // Set proxy for LS to reach Google APIs. - // MUST always override inherited container proxy (which may be Anthropic-only). - // proxyURL is already fully resolved by the caller. - // Always set — even empty string clears inherited container values - env = setEnvValue(env, "HTTPS_PROXY", proxyURL) - env = setEnvValue(env, "HTTP_PROXY", proxyURL) - env = setEnvValue(env, "ALL_PROXY", proxyURL) - env = setEnvValue(env, "https_proxy", proxyURL) - env = setEnvValue(env, "http_proxy", proxyURL) - env = setEnvValue(env, "all_proxy", proxyURL) - - if !hasEnvKey(env, "SSL_CERT_FILE") { - if certFile := firstExistingPath(defaultLSCertFileCandidates); certFile != "" { - env = setEnvValue(env, "SSL_CERT_FILE", certFile) - } - } - if !hasEnvKey(env, "SSL_CERT_DIR") { - if certDir := firstExistingPath(defaultLSCertDirCandidates); certDir != "" { - env = setEnvValue(env, "SSL_CERT_DIR", certDir) - } - } - - return env -} - -func resolveLSProxy(proxyURL string) string { - if strings.TrimSpace(proxyURL) != "" { - return strings.TrimSpace(proxyURL) - } - return strings.TrimSpace(os.Getenv("ANTIGRAVITY_LS_PROXY")) -} - -func firstExistingPath(candidates []string) string { - for _, candidate := range candidates { - if candidate == "" { - continue - } - if _, err := os.Stat(candidate); err == nil { - return candidate - } - } - return "" -} - -func hasEnvKey(env []string, key string) bool { - prefix := key + "=" - for _, entry := range env { - if strings.HasPrefix(entry, prefix) { - return true - } - } - return false -} - -func setEnvValue(env []string, key, value string) []string { - prefix := key + "=" - for i, entry := range env { - if strings.HasPrefix(entry, prefix) { - env[i] = prefix + value - return env - } - } - return append(env, prefix+value) -} - -func createHTTPClient(appRoot, csrf string) (*http.Client, error) { - certPath := filepath.Join(appRoot, "extensions", "antigravity", "dist", "languageServer", "cert.pem") - caCert, err := os.ReadFile(certPath) - if err != nil { - return nil, fmt.Errorf("read cert %s: %w", certPath, err) - } - certPool := x509.NewCertPool() - if !certPool.AppendCertsFromPEM(caCert) { - return nil, fmt.Errorf("failed to parse cert") - } - - tlsCfg := &tls.Config{ - RootCAs: certPool, - InsecureSkipVerify: true, // LS uses self-signed cert; trust it unconditionally - } - - return &http.Client{ - Transport: &csrfTransport{ - base: &http2.Transport{ - TLSClientConfig: tlsCfg, - ReadIdleTimeout: 30 * time.Second, - }, - csrf: csrf, - }, - Timeout: 5 * time.Minute, - }, nil -} - -type csrfTransport struct { - base http.RoundTripper - csrf string -} - -func (t *csrfTransport) RoundTrip(req *http.Request) (*http.Response, error) { - req.Header.Set("x-codeium-csrf-token", t.csrf) - return t.base.RoundTrip(req) -} - -// writeMetadata writes the Metadata proto to LS stdin. -// This matches what the real IDE does (extension.js line 42520-42522): -// -// toBinary(MetadataSchema, create(MetadataSchema, { -// ideName, ideVersion, extensionName, extensionPath, -// locale, deviceFingerprint, apiKey, disableTelemetry, userTierId -// })) -func buildMetadataBytes(accessToken string) []byte { - var buf bytes.Buffer - writeProtoStringField(&buf, 1, "antigravity") // ide_name - writeProtoStringField(&buf, 7, "1.107.0") // ide_version - writeProtoStringField(&buf, 12, "antigravity") // extension_name - writeProtoStringField(&buf, 4, "en") // locale - if accessToken != "" { - writeProtoStringField(&buf, 3, accessToken) // api_key = access_token - } - // disable_telemetry = true (field 6, varint, value 1) - buf.Write([]byte{0x30, 0x01}) - return buf.Bytes() -} - -func writeProtoStringField(buf *bytes.Buffer, fieldNum int, val string) { - writeVarInt(buf, uint64(fieldNum<<3|2)) - writeVarInt(buf, uint64(len(val))) - buf.WriteString(val) -} - -func writeVarInt(buf *bytes.Buffer, v uint64) { - b := make([]byte, binary.MaxVarintLen64) - n := binary.PutUvarint(b, v) - buf.Write(b[:n]) -} - -func lsBinaryName() string { - m := map[string]string{ - "darwin/arm64": "language_server_macos_arm", - "darwin/amd64": "language_server_macos_x64", - "linux/amd64": "language_server_linux_x64", - "linux/arm64": "language_server_linux_arm", - "windows/amd64": "language_server_windows_x64.exe", - } - if name, ok := m[runtime.GOOS+"/"+runtime.GOARCH]; ok { - return name - } - return "language_server_linux_x64" // fallback -} - -func generateUUID() string { - b := make([]byte, 16) - crand.Read(b) - b[6] = (b[6] & 0x0f) | 0x40 - b[8] = (b[8] & 0x3f) | 0x80 - return fmt.Sprintf("%08x-%04x-%04x-%04x-%012x", b[0:4], b[4:6], b[6:8], b[8:10], b[10:]) -} - -func truncate(s string, max int) string { - if len(s) <= max { - return s - } - return s[:max] + "..." -} - -func shortAccountID(accountID string) string { - if len(accountID) <= 8 { - return accountID - } - return accountID[:8] -} diff --git a/backend/internal/pkg/lspool/pool_test.go b/backend/internal/pkg/lspool/pool_test.go deleted file mode 100644 index 1e1bbf90..00000000 --- a/backend/internal/pkg/lspool/pool_test.go +++ /dev/null @@ -1,376 +0,0 @@ -package lspool - -import ( - "bytes" - "context" - "encoding/binary" - "errors" - "fmt" - "io" - "net/http" - "strings" - "sync/atomic" - "testing" - "time" - - "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint" - "github.com/stretchr/testify/require" -) - -func TestBuildLSEnvKeepsExistingSSLValues(t *testing.T) { - env := buildLSEnv([]string{ - "SSL_CERT_FILE=/custom/ca.pem", - "SSL_CERT_DIR=/custom/certs", - }, "/opt/antigravity", "") - require.Contains(t, env, "ANTIGRAVITY_EDITOR_APP_ROOT=/opt/antigravity") - require.Contains(t, env, "SSL_CERT_FILE=/custom/ca.pem") - require.Contains(t, env, "SSL_CERT_DIR=/custom/certs") -} - -func TestBuildLSEnvClearsInheritedProxyWhenUnset(t *testing.T) { - env := buildLSEnv([]string{ - "HTTPS_PROXY=http://old-proxy:8080", - "HTTP_PROXY=http://old-proxy:8080", - "ALL_PROXY=socks5://old-proxy:1080", - "https_proxy=http://old-proxy:8080", - "http_proxy=http://old-proxy:8080", - "all_proxy=socks5://old-proxy:1080", - }, "/opt/antigravity", "") - - require.Contains(t, env, "HTTPS_PROXY=") - require.Contains(t, env, "HTTP_PROXY=") - require.Contains(t, env, "ALL_PROXY=") - require.Contains(t, env, "https_proxy=") - require.Contains(t, env, "http_proxy=") - require.Contains(t, env, "all_proxy=") -} - -func TestShortAccountID(t *testing.T) { - require.Equal(t, "9", shortAccountID("9")) - require.Equal(t, "12345678", shortAccountID("12345678")) - require.Equal(t, "12345678", shortAccountID("123456789")) -} - -func TestFrameConnectMessage(t *testing.T) { - framed := frameConnectMessage([]byte(`{"x":1}`)) - require.Len(t, framed, 5+len(`{"x":1}`)) - require.Equal(t, byte(0), framed[0]) - require.Equal(t, uint32(len(`{"x":1}`)), binary.BigEndian.Uint32(framed[1:5])) - require.Equal(t, `{"x":1}`, string(framed[5:])) -} - -func TestConnectEnvelope(t *testing.T) { - payload := []byte("hello") - env := connectEnvelope(0x00, payload) - require.Len(t, env, 5+len(payload)) - require.Equal(t, byte(0x00), env[0]) - require.Equal(t, uint32(5), binary.BigEndian.Uint32(env[1:5])) - require.Equal(t, "hello", string(env[5:])) -} - -func TestUnwrapConnectEnvelope(t *testing.T) { - payload := []byte("test data") - env := connectEnvelope(0x00, payload) - unwrapped := unwrapConnectEnvelope(env) - require.Equal(t, payload, unwrapped) - short := []byte{1, 2} - require.Equal(t, short, unwrapConnectEnvelope(short)) -} - -func TestExtractPromptAndModel(t *testing.T) { - body := `{"model":"gemini-2.5-pro","project":"p","request":{"contents":[{"role":"user","parts":[{"text":"hello world"}]}]}}` - prompt, model := extractPromptAndModel([]byte(body)) - require.Equal(t, "hello world", prompt) - require.Equal(t, "gemini-2.5-pro", model) - - body2 := `{"contents":[{"role":"user","parts":[{"text":"test prompt"}]}]}` - prompt2, _ := extractPromptAndModel([]byte(body2)) - require.Equal(t, "test prompt", prompt2) -} - -func TestResolveModelEnum(t *testing.T) { - // Without dynamic mapping loaded, should return fallback (312 = gemini-2.5-flash) - require.True(t, resolveModelEnum("gemini-2.5-flash") > 0) - require.True(t, resolveModelEnum("models/gemini-2.5-flash") > 0) - require.True(t, resolveModelEnum("claude-sonnet-4-6") > 0) - require.True(t, resolveModelEnum("unknown-model") > 0) -} - -func TestBuildCascadeConfigIncludesRequestedModel(t *testing.T) { - cfg := buildCascadeConfig("models/gemini-2.5-flash") - require.NotNil(t, cfg) - - plannerConfig, ok := cfg["plannerConfig"].(map[string]any) - require.True(t, ok) - requestedModel, ok := plannerConfig["requestedModel"].(map[string]any) - require.True(t, ok) - require.NotEmpty(t, requestedModel["model"]) - require.Len(t, plannerConfig, 1) -} - -func TestBuildCascadeConfigClaudeIncludesRequestedModel(t *testing.T) { - cfg := buildCascadeConfig("claude-sonnet-4-6") - require.NotNil(t, cfg) - - plannerConfig, ok := cfg["plannerConfig"].(map[string]any) - require.True(t, ok) - requestedModel, ok := plannerConfig["requestedModel"].(map[string]any) - require.True(t, ok) - require.NotEmpty(t, requestedModel["model"]) - require.Len(t, plannerConfig, 1) -} - -func TestDoNonStreamGeneratePassesThrough(t *testing.T) { - fallback := &recordingUpstream{} - upstream := NewLSPoolUpstream(&Pool{}, fallback) - req, _ := http.NewRequest("POST", "https://example.com/v1beta/models/gemini:generateContent", bytes.NewReader([]byte(`{}`))) - resp, err := upstream.Do(req, "", 1, 1) - require.NoError(t, err) - require.NotNil(t, resp) - require.Equal(t, 1, fallback.doCalls) -} - -func TestExtractPlannerResponseText(t *testing.T) { - resp := `{"status":"CASCADE_RUN_STATUS_IDLE","trajectory":{"steps":[ - {"type":"CORTEX_STEP_TYPE_USER_INPUT","status":"CORTEX_STEP_STATUS_DONE"}, - {"type":"CORTEX_STEP_TYPE_PLANNER_RESPONSE","status":"CORTEX_STEP_STATUS_DONE", - "plannerResponse":{"response":"Hello world"}} - ]}}` - text, generating, status := extractPlannerResponseText([]byte(resp)) - require.Equal(t, "Hello world", text) - require.False(t, generating) - require.Equal(t, "CASCADE_RUN_STATUS_IDLE", status) -} - -func TestExtractPlannerResponseState_ErrorDetails(t *testing.T) { - resp := `{ - "status":"CASCADE_RUN_STATUS_IDLE", - "trajectory":{ - "steps":[ - {"type":"CORTEX_STEP_TYPE_USER_INPUT","status":"CORTEX_STEP_STATUS_DONE"} - ], - "executorMetadata":{ - "terminationReason":"ERROR", - "errorDetails":{ - "errorCode":429, - "shortError":"Model quota reached", - "details":"You have exhausted your capacity on this model. Your quota will reset after 1h59m40s." - } - } - } - }` - - state := extractPlannerResponseState([]byte(resp)) - require.Equal(t, "CASCADE_RUN_STATUS_IDLE", state.Status) - require.False(t, state.Generating) - require.Empty(t, state.Text) - require.Contains(t, state.ErrorMessage, "Model quota reached") - require.Contains(t, state.ErrorMessage, "quota will reset after") -} - -func TestBuildGeminiSSEChunk(t *testing.T) { - sse := buildGeminiSSEChunk("hello") - require.Contains(t, sse, "data: ") - require.Contains(t, sse, `"text":"hello"`) - require.Contains(t, sse, `"role":"model"`) - require.True(t, strings.HasSuffix(sse, "\n\n")) -} - -func TestRequestHasTools(t *testing.T) { - // Wrapped format with tools - require.True(t, requestHasTools([]byte(`{"model":"m","project":"p","request":{"contents":[],"tools":[{"functionDeclarations":[{"name":"get_weather"}]}]}}`))) - - // Direct format with tools - require.True(t, requestHasTools([]byte(`{"contents":[],"tools":[{"functionDeclarations":[{"name":"f"}]}]}`))) - - // No tools - require.False(t, requestHasTools([]byte(`{"model":"m","project":"p","request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`))) - - // Empty tools array - require.False(t, requestHasTools([]byte(`{"contents":[],"tools":[]}`))) -} - -func TestCurrentLSStrategy(t *testing.T) { - t.Setenv("ANTIGRAVITY_LS_STRATEGY", "js-parity") - require.Equal(t, LSStrategyJSParity, CurrentLSStrategy()) - - t.Setenv("ANTIGRAVITY_LS_STRATEGY", "unknown") - require.Equal(t, LSStrategyDirect, CurrentLSStrategy()) -} - -func TestIsPermanentModelMappingError(t *testing.T) { - require.True(t, isPermanentModelMappingError(errors.New(`oauth2: "unauthorized_client" "Unauthorized"`))) - require.False(t, isPermanentModelMappingError(errors.New("context deadline exceeded"))) -} - -func TestPoolSetAccountTokenClearsModelMappingUnavailable(t *testing.T) { - pool := &Pool{ - instances: map[string][]*Instance{ - "9": { - {AccountID: "9", Replica: 0}, - }, - }, - } - inst := pool.instances["9"][0] - inst.SetModelMappingReady(true) - inst.SetModelMappingUnavailable(`oauth2: "unauthorized_client" "Unauthorized"`) - - pool.SetAccountToken("9", "ya29.new", "refresh", time.Now().Add(time.Hour)) - - require.False(t, inst.HasModelMappingReady()) - require.False(t, inst.HasModelMappingUnavailable()) - require.Empty(t, inst.ModelMappingUnavailableReason()) -} - -func TestShouldFallbackDirectForModelMappingUnavailable(t *testing.T) { - require.True(t, shouldFallbackDirect(fmt.Errorf("%w: oauth2 unauthorized_client", errLSModelMapDenied))) - require.False(t, shouldFallbackDirect(errLSModelMapPending)) -} - -func TestParseLSReplicaCountDefaultAndEnv(t *testing.T) { - t.Setenv("ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT", "") - require.Equal(t, 5, parseLSReplicaCount()) - - t.Setenv("ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT", "3") - require.Equal(t, 3, parseLSReplicaCount()) - - t.Setenv("ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT", "0") - require.Equal(t, 5, parseLSReplicaCount()) -} - -func TestPoolGetUsesStickyReplicaSlot(t *testing.T) { - pool := &Pool{ - config: Config{ReplicasPerAccount: 5}, - instances: map[string][]*Instance{ - "acc-1": { - {AccountID: "acc-1", Replica: 0, healthy: true}, - {AccountID: "acc-1", Replica: 1, healthy: true}, - {AccountID: "acc-1", Replica: 2, healthy: true}, - {AccountID: "acc-1", Replica: 3, healthy: true}, - {AccountID: "acc-1", Replica: 4, healthy: true}, - }, - }, - } - - routingKey := "acc-1:user-a:session-1" - slot := replicaSlotIndex(routingKey, pool.replicaCount()) - inst := pool.Get("acc-1", routingKey) - require.NotNil(t, inst) - require.Equal(t, slot, inst.Replica) -} - -func TestPoolGetWithoutRoutingKeyPrefersLeastBusyReplica(t *testing.T) { - busy := &Instance{AccountID: "acc-1", Replica: 0, healthy: true} - atomic.StoreInt64(&busy.inflight, 4) - idle := &Instance{AccountID: "acc-1", Replica: 1, healthy: true} - atomic.StoreInt64(&idle.inflight, 1) - - pool := &Pool{ - config: Config{ReplicasPerAccount: 5}, - instances: map[string][]*Instance{ - "acc-1": {busy, idle}, - }, - } - - inst := pool.Get("acc-1", "") - require.NotNil(t, inst) - require.Equal(t, 1, inst.Replica) -} - -func TestWaitForInstanceReadyProbesImmediately(t *testing.T) { - startedAt := time.Now() - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - - attempts, err := waitForInstanceReady(ctx, 200*time.Millisecond, func(context.Context) error { - return nil - }) - require.NoError(t, err) - require.Equal(t, 1, attempts) - require.Less(t, time.Since(startedAt), 100*time.Millisecond) -} - -func TestWaitForInstanceReadyRetriesUntilSuccess(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - - calls := 0 - attempts, err := waitForInstanceReady(ctx, 10*time.Millisecond, func(context.Context) error { - calls++ - if calls < 3 { - return errors.New("not ready") - } - return nil - }) - require.NoError(t, err) - require.Equal(t, 3, attempts) - require.Equal(t, 3, calls) -} - -func TestDecideJSParityRoute(t *testing.T) { - body := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"s1","contents":[{"role":"user","parts":[{"text":"hello"}]}]}}`) - parsed, err := parseGeminiRequest(body) - require.NoError(t, err) - decision := decideJSParityRoute(parsed, body) - require.True(t, decision.UseLS) - - imageBody := []byte(`{"model":"gemini-2.5-flash-image","request":{"sessionId":"s1","contents":[{"role":"user","parts":[{"text":"draw"}]}],"generationConfig":{"responseModalities":["TEXT","IMAGE"]}}}`) - parsedImage, err := parseGeminiRequest(imageBody) - require.NoError(t, err) - decisionImage := decideJSParityRoute(parsedImage, imageBody) - require.False(t, decisionImage.UseLS) - require.Contains(t, strings.ToLower(decisionImage.Reason), "image") - - noSessionBody := []byte(`{"model":"gemini-2.5-flash","request":{"contents":[{"role":"user","parts":[{"text":"hello"}]}]}}`) - parsedNoSession, err := parseGeminiRequest(noSessionBody) - require.NoError(t, err) - require.False(t, decideJSParityRoute(parsedNoSession, noSessionBody).UseLS) -} - -func TestUserNamespacePrefersExplicitHeader(t *testing.T) { - req, err := http.NewRequest(http.MethodPost, "https://example.com", nil) - require.NoError(t, err) - req.Header.Set(userNamespaceHeader, "tenant-a") - req.Header.Set("Authorization", "Bearer oauth-token") - - nsWithExplicit := userNamespace(req) - require.NotEqual(t, "anon", nsWithExplicit) - - req.Header.Del(userNamespaceHeader) - nsWithAuth := userNamespace(req) - require.NotEqual(t, "anon", nsWithAuth) - require.NotEqual(t, nsWithExplicit, nsWithAuth) -} - -func TestConversationPrefixEqual(t *testing.T) { - prefix := []geminiConversationTurn{ - {Role: "user", Parts: []geminiConversationPart{{Kind: "text", Text: "hello"}}}, - {Role: "model", Parts: []geminiConversationPart{{Kind: "text", Text: "world"}}}, - } - full := append(cloneConversationTurns(prefix), geminiConversationTurn{ - Role: "user", - Parts: []geminiConversationPart{{Kind: "text", Text: "follow up"}}, - }) - require.True(t, conversationPrefixEqual(full, prefix)) - require.False(t, conversationPrefixEqual(prefix, full)) -} - -func TestSystemTextCompatible(t *testing.T) { - require.True(t, systemTextCompatible("You are helpful", "")) - require.True(t, systemTextCompatible("You are helpful", "You are helpful")) - require.False(t, systemTextCompatible("", "You are helpful")) - require.False(t, systemTextCompatible("You are helpful", "You are different")) -} - -type recordingUpstream struct { - doCalls int -} - -func (r *recordingUpstream) Do(req *http.Request, _ string, _ int64, _ int) (*http.Response, error) { - r.doCalls++ - return &http.Response{StatusCode: 200, Body: io.NopCloser(bytes.NewBufferString("ok")), Header: make(http.Header), Request: req}, nil -} - -func (r *recordingUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, c int, _ *tlsfingerprint.Profile) (*http.Response, error) { - return r.Do(req, proxyURL, accountID, c) -} diff --git a/backend/internal/pkg/lspool/proxy_bridge.go b/backend/internal/pkg/lspool/proxy_bridge.go deleted file mode 100644 index 36f5348b..00000000 --- a/backend/internal/pkg/lspool/proxy_bridge.go +++ /dev/null @@ -1,268 +0,0 @@ -package lspool - -import ( - "bufio" - "context" - "fmt" - "io" - "log/slog" - "net" - "net/http" - "net/url" - "strings" - "sync" - "time" - - "github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl" - "golang.org/x/net/proxy" -) - -type lsProxyBridge struct { - listener net.Listener - server *http.Server - url string - upstream string -} - -type lsProxyBridgeManager struct { - mu sync.Mutex - bridges map[string]*lsProxyBridge - logger *slog.Logger -} - -var globalLSProxyBridgeManager = &lsProxyBridgeManager{ - bridges: make(map[string]*lsProxyBridge), - logger: slog.Default().With("component", "lspool-proxy-bridge"), -} - -var ( - lsProxyBridgeDialTimeout = 10 * time.Second - lsProxyBridgeProbeTargets = []string{ - "cloudcode-pa.googleapis.com:443", - "oauthaccountmanager.googleapis.com:443", - } -) - -func prepareLSProxyURL(raw string) (string, error) { - normalized, parsed, err := proxyurl.Parse(raw) - if err != nil { - return "", err - } - if parsed == nil { - return "", nil - } - - switch strings.ToLower(parsed.Scheme) { - case "http", "https": - return normalized, nil - case "socks5", "socks5h": - return globalLSProxyBridgeManager.ensure(normalized, parsed) - default: - return "", fmt.Errorf("unsupported LS proxy scheme: %s", parsed.Scheme) - } -} - -func (m *lsProxyBridgeManager) ensure(key string, upstream *url.URL) (string, error) { - m.mu.Lock() - defer m.mu.Unlock() - - if bridge := m.bridges[key]; bridge != nil { - return bridge.url, nil - } - - bridge, err := newLSProxyBridge(upstream, m.logger) - if err != nil { - return "", err - } - m.bridges[key] = bridge - return bridge.url, nil -} - -func (m *lsProxyBridgeManager) closeAll() { - m.mu.Lock() - defer m.mu.Unlock() - - for key, bridge := range m.bridges { - if bridge != nil { - _ = bridge.server.Close() - _ = bridge.listener.Close() - } - delete(m.bridges, key) - } -} - -func closeAllLSProxyBridgesForTest() { - globalLSProxyBridgeManager.closeAll() -} - -func newLSProxyBridge(upstream *url.URL, logger *slog.Logger) (*lsProxyBridge, error) { - dialer, err := proxy.FromURL(upstream, proxy.Direct) - if err != nil { - return nil, fmt.Errorf("create SOCKS dialer: %w", err) - } - - listener, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - return nil, fmt.Errorf("listen LS proxy bridge: %w", err) - } - - bridge := &lsProxyBridge{ - listener: listener, - url: "http://" + listener.Addr().String(), - upstream: upstream.Redacted(), - } - - server := &http.Server{ - Handler: http.HandlerFunc(bridge.connectHandler(dialer, logger)), - ReadHeaderTimeout: 10 * time.Second, - IdleTimeout: 2 * time.Minute, - } - bridge.server = server - - go func() { - if err := server.Serve(listener); err != nil && err != http.ErrServerClosed { - logger.Error("LS proxy bridge serve failed", "upstream", bridge.upstream, "err", err) - } - }() - - logger.Info("LS proxy bridge started", "upstream", bridge.upstream, "listen", bridge.url) - go bridge.probeConnectivity(dialer, logger) - return bridge, nil -} - -func (b *lsProxyBridge) connectHandler(dialer proxy.Dialer, logger *slog.Logger) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodConnect { - http.Error(w, "CONNECT only", http.StatusMethodNotAllowed) - return - } - - targetAddr := strings.TrimSpace(r.Host) - if targetAddr == "" { - targetAddr = strings.TrimSpace(r.URL.Host) - } - if targetAddr == "" { - http.Error(w, "missing target host", http.StatusBadRequest) - return - } - if _, _, err := net.SplitHostPort(targetAddr); err != nil { - targetAddr = net.JoinHostPort(targetAddr, "443") - } - - startedAt := time.Now() - logger.Info("LS proxy bridge CONNECT", "upstream", b.upstream, "target", targetAddr) - - dialCtx, cancel := context.WithTimeout(r.Context(), lsProxyBridgeDialTimeout) - defer cancel() - - targetConn, err := dialViaProxy(dialCtx, dialer, targetAddr) - if err != nil { - logger.Warn("LS proxy bridge dial failed", - "upstream", b.upstream, - "target", targetAddr, - "elapsed", time.Since(startedAt).Truncate(time.Millisecond), - "err", err) - http.Error(w, "proxy dial failed", http.StatusBadGateway) - return - } - logger.Info("LS proxy bridge CONNECT established", - "upstream", b.upstream, - "target", targetAddr, - "elapsed", time.Since(startedAt).Truncate(time.Millisecond)) - - hijacker, ok := w.(http.Hijacker) - if !ok { - _ = targetConn.Close() - http.Error(w, "hijack unsupported", http.StatusInternalServerError) - return - } - - clientConn, rw, err := hijacker.Hijack() - if err != nil { - _ = targetConn.Close() - return - } - - if _, err := clientConn.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n")); err != nil { - _ = targetConn.Close() - _ = clientConn.Close() - return - } - - if rw != nil && rw.Reader.Buffered() > 0 { - if _, err := io.CopyN(targetConn, rw, int64(rw.Reader.Buffered())); err != nil { - _ = targetConn.Close() - _ = clientConn.Close() - return - } - } - - tunnelConns(clientConn, targetConn) - } -} - -func dialViaProxy(ctx context.Context, dialer proxy.Dialer, targetAddr string) (net.Conn, error) { - if contextDialer, ok := dialer.(proxy.ContextDialer); ok { - return contextDialer.DialContext(ctx, "tcp", targetAddr) - } - - type dialResult struct { - conn net.Conn - err error - } - resultCh := make(chan dialResult, 1) - go func() { - conn, err := dialer.Dial("tcp", targetAddr) - resultCh <- dialResult{conn: conn, err: err} - }() - - select { - case <-ctx.Done(): - return nil, ctx.Err() - case result := <-resultCh: - return result.conn, result.err - } -} - -func (b *lsProxyBridge) probeConnectivity(dialer proxy.Dialer, logger *slog.Logger) { - for _, targetAddr := range lsProxyBridgeProbeTargets { - startedAt := time.Now() - ctx, cancel := context.WithTimeout(context.Background(), lsProxyBridgeDialTimeout) - conn, err := dialViaProxy(ctx, dialer, targetAddr) - cancel() - if err != nil { - logger.Warn("LS proxy bridge probe failed", - "upstream", b.upstream, - "target", targetAddr, - "elapsed", time.Since(startedAt).Truncate(time.Millisecond), - "err", err) - continue - } - _ = conn.Close() - logger.Info("LS proxy bridge probe succeeded", - "upstream", b.upstream, - "target", targetAddr, - "elapsed", time.Since(startedAt).Truncate(time.Millisecond)) - } -} - -func tunnelConns(clientConn net.Conn, targetConn net.Conn) { - var once sync.Once - closeBoth := func() { - _ = clientConn.Close() - _ = targetConn.Close() - } - - go func() { - _, _ = io.Copy(targetConn, clientConn) - once.Do(closeBoth) - }() - go func() { - _, _ = io.Copy(clientConn, targetConn) - once.Do(closeBoth) - }() -} - -func readConnectResponse(br *bufio.Reader) (*http.Response, error) { - return http.ReadResponse(br, &http.Request{Method: http.MethodConnect}) -} diff --git a/backend/internal/pkg/lspool/proxy_bridge_test.go b/backend/internal/pkg/lspool/proxy_bridge_test.go deleted file mode 100644 index 69960e3a..00000000 --- a/backend/internal/pkg/lspool/proxy_bridge_test.go +++ /dev/null @@ -1,193 +0,0 @@ -package lspool - -import ( - "bufio" - "encoding/binary" - "fmt" - "io" - "net" - "strings" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestPrepareLSProxyURLPassesThroughHTTPProxy(t *testing.T) { - t.Cleanup(closeAllLSProxyBridgesForTest) - - got, err := prepareLSProxyURL("http://proxy.example.com:8080") - require.NoError(t, err) - require.Equal(t, "http://proxy.example.com:8080", got) -} - -func TestPrepareLSProxyURLBridgesSOCKS5ForLS(t *testing.T) { - t.Cleanup(closeAllLSProxyBridgesForTest) - - targetAddr, closeTarget := startBridgeEchoServer(t) - defer closeTarget() - - socksURL, closeSOCKS := startBridgeSOCKS5Server(t) - defer closeSOCKS() - - bridgeURL, err := prepareLSProxyURL(socksURL) - require.NoError(t, err) - require.True(t, strings.HasPrefix(bridgeURL, "http://127.0.0.1:")) - - // Same SOCKS upstream should reuse the same local bridge. - reusedURL, err := prepareLSProxyURL(socksURL) - require.NoError(t, err) - require.Equal(t, bridgeURL, reusedURL) - - bridgeAddr := strings.TrimPrefix(bridgeURL, "http://") - conn, err := net.Dial("tcp", bridgeAddr) - require.NoError(t, err) - defer conn.Close() - - _, err = fmt.Fprintf(conn, "CONNECT %s HTTP/1.1\r\nHost: %s\r\n\r\n", targetAddr, targetAddr) - require.NoError(t, err) - - reader := bufio.NewReader(conn) - resp, err := readConnectResponse(reader) - require.NoError(t, err) - require.Equal(t, 200, resp.StatusCode) - - _, err = conn.Write([]byte("ping")) - require.NoError(t, err) - - reply := make([]byte, 4) - _, err = io.ReadFull(reader, reply) - require.NoError(t, err) - require.Equal(t, "pong", string(reply)) -} - -func startBridgeEchoServer(t *testing.T) (string, func()) { - t.Helper() - - ln, err := net.Listen("tcp", "127.0.0.1:0") - require.NoError(t, err) - - done := make(chan struct{}) - go func() { - defer close(done) - for { - conn, err := ln.Accept() - if err != nil { - return - } - go func(c net.Conn) { - defer c.Close() - buf := make([]byte, 4) - if _, err := io.ReadFull(c, buf); err != nil { - return - } - if string(buf) == "ping" { - _, _ = c.Write([]byte("pong")) - } - }(conn) - } - }() - - return ln.Addr().String(), func() { - _ = ln.Close() - <-done - } -} - -func startBridgeSOCKS5Server(t *testing.T) (string, func()) { - t.Helper() - - ln, err := net.Listen("tcp", "127.0.0.1:0") - require.NoError(t, err) - - done := make(chan struct{}) - go func() { - defer close(done) - for { - conn, err := ln.Accept() - if err != nil { - return - } - go handleBridgeSOCKS5Conn(conn) - } - }() - - return "socks5://" + ln.Addr().String(), func() { - _ = ln.Close() - <-done - } -} - -func handleBridgeSOCKS5Conn(conn net.Conn) { - header := make([]byte, 2) - if _, err := io.ReadFull(conn, header); err != nil { - _ = conn.Close() - return - } - methods := make([]byte, int(header[1])) - if _, err := io.ReadFull(conn, methods); err != nil { - _ = conn.Close() - return - } - _, _ = conn.Write([]byte{0x05, 0x00}) - - reqHeader := make([]byte, 4) - if _, err := io.ReadFull(conn, reqHeader); err != nil { - _ = conn.Close() - return - } - if reqHeader[0] != 0x05 || reqHeader[1] != 0x01 { - _ = conn.Close() - return - } - - targetHost, ok := readSOCKS5Addr(conn, reqHeader[3]) - if !ok { - _ = conn.Close() - return - } - portBuf := make([]byte, 2) - if _, err := io.ReadFull(conn, portBuf); err != nil { - _ = conn.Close() - return - } - targetAddr := fmt.Sprintf("%s:%d", targetHost, binary.BigEndian.Uint16(portBuf)) - - targetConn, err := net.Dial("tcp", targetAddr) - if err != nil { - _, _ = conn.Write([]byte{0x05, 0x01, 0x00, 0x01, 0, 0, 0, 0, 0, 0}) - _ = conn.Close() - return - } - - _, _ = conn.Write([]byte{0x05, 0x00, 0x00, 0x01, 0, 0, 0, 0, 0, 0}) - tunnelConns(conn, targetConn) -} - -func readSOCKS5Addr(conn net.Conn, atyp byte) (string, bool) { - switch atyp { - case 0x01: - buf := make([]byte, 4) - if _, err := io.ReadFull(conn, buf); err != nil { - return "", false - } - return net.IP(buf).String(), true - case 0x03: - lenBuf := make([]byte, 1) - if _, err := io.ReadFull(conn, lenBuf); err != nil { - return "", false - } - buf := make([]byte, int(lenBuf[0])) - if _, err := io.ReadFull(conn, buf); err != nil { - return "", false - } - return string(buf), true - case 0x04: - buf := make([]byte, 16) - if _, err := io.ReadFull(conn, buf); err != nil { - return "", false - } - return net.IP(buf).String(), true - default: - return "", false - } -} diff --git a/backend/internal/pkg/lspool/proxy_exec.go b/backend/internal/pkg/lspool/proxy_exec.go deleted file mode 100644 index 38a678df..00000000 --- a/backend/internal/pkg/lspool/proxy_exec.go +++ /dev/null @@ -1,138 +0,0 @@ -package lspool - -import ( - "fmt" - "net/url" - "os" - "os/exec" - "strings" - - "github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl" -) - -type lsLaunchPlan struct { - cmd *exec.Cmd - effectiveProxyURL string - proxyMode string - cleanup func() -} - -func prepareLSLaunchPlan(binPath string, args []string, rawProxyURL string) (*lsLaunchPlan, error) { - normalized, parsed, err := proxyurl.Parse(rawProxyURL) - if err != nil { - return nil, err - } - - plan := &lsLaunchPlan{ - cmd: exec.Command(binPath, args...), - proxyMode: "direct", - } - - if parsed == nil { - return plan, nil - } - - switch strings.ToLower(parsed.Scheme) { - case "http", "https": - plan.effectiveProxyURL = normalized - plan.proxyMode = "env-http-proxy" - return plan, nil - - case "socks5", "socks5h": - if proxychainsPath, err := exec.LookPath("proxychains4"); err == nil { - cfgPath, err := writeProxychainsConfig(parsed) - if err != nil { - return nil, err - } - plan.cmd = exec.Command(proxychainsPath, append([]string{"-f", cfgPath, binPath}, args...)...) - plan.proxyMode = "proxychains4" - plan.cleanup = func() { - _ = os.Remove(cfgPath) - } - return plan, nil - } - - effectiveProxyURL, err := prepareLSProxyURL(normalized) - if err != nil { - return nil, err - } - plan.effectiveProxyURL = effectiveProxyURL - plan.proxyMode = "http-connect-bridge" - return plan, nil - - default: - return nil, fmt.Errorf("unsupported LS proxy scheme: %s", parsed.Scheme) - } -} - -func writeProxychainsConfig(proxyURL *url.URL) (string, error) { - content, err := buildProxychainsConfig(proxyURL) - if err != nil { - return "", err - } - - file, err := os.CreateTemp("", "sub2api-proxychains-*.conf") - if err != nil { - return "", fmt.Errorf("create proxychains config: %w", err) - } - - if _, err := file.WriteString(content); err != nil { - _ = file.Close() - _ = os.Remove(file.Name()) - return "", fmt.Errorf("write proxychains config: %w", err) - } - if err := file.Close(); err != nil { - _ = os.Remove(file.Name()) - return "", fmt.Errorf("close proxychains config: %w", err) - } - - return file.Name(), nil -} - -func buildProxychainsConfig(proxyURL *url.URL) (string, error) { - if proxyURL == nil { - return "", fmt.Errorf("proxy url is nil") - } - if scheme := strings.ToLower(proxyURL.Scheme); scheme != "socks5" && scheme != "socks5h" { - return "", fmt.Errorf("proxychains only supports socks5/socks5h, got %s", proxyURL.Scheme) - } - - host := strings.TrimSpace(proxyURL.Hostname()) - port := strings.TrimSpace(proxyURL.Port()) - if host == "" { - return "", fmt.Errorf("proxy host is empty") - } - if port == "" { - port = "1080" - } - - username := proxyURL.User.Username() - password, _ := proxyURL.User.Password() - if strings.ContainsAny(username, " \t\r\n") || strings.ContainsAny(password, " \t\r\n") { - return "", fmt.Errorf("proxychains credentials cannot contain whitespace") - } - - var builder strings.Builder - builder.WriteString("strict_chain\n") - builder.WriteString("proxy_dns\n") - builder.WriteString("remote_dns_subnet 224\n") - builder.WriteString("tcp_connect_time_out 8000\n") - builder.WriteString("tcp_read_time_out 15000\n") - builder.WriteString("localnet 127.0.0.0/255.0.0.0\n") - builder.WriteString("localnet ::1/128\n") - builder.WriteString("[ProxyList]\n") - builder.WriteString("socks5 ") - builder.WriteString(host) - builder.WriteString(" ") - builder.WriteString(port) - if username != "" { - builder.WriteString(" ") - builder.WriteString(username) - if password != "" { - builder.WriteString(" ") - builder.WriteString(password) - } - } - builder.WriteString("\n") - return builder.String(), nil -} diff --git a/backend/internal/pkg/lspool/proxy_exec_test.go b/backend/internal/pkg/lspool/proxy_exec_test.go deleted file mode 100644 index 44728b03..00000000 --- a/backend/internal/pkg/lspool/proxy_exec_test.go +++ /dev/null @@ -1,31 +0,0 @@ -package lspool - -import ( - "net/url" - "strings" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestBuildProxychainsConfigIncludesAuthAndLocalBypass(t *testing.T) { - proxyURL, err := url.Parse("socks5h://testuser:testpass@192.0.2.1:1080") - require.NoError(t, err) - - cfg, err := buildProxychainsConfig(proxyURL) - require.NoError(t, err) - require.Contains(t, cfg, "proxy_dns\n") - require.Contains(t, cfg, "localnet 127.0.0.0/255.0.0.0\n") - require.Contains(t, cfg, "localnet ::1/128\n") - require.Contains(t, cfg, "[ProxyList]\n") - require.Contains(t, cfg, "socks5 192.0.2.1 1080 testuser testpass\n") -} - -func TestBuildProxychainsConfigRejectsWhitespaceCredentials(t *testing.T) { - proxyURL, err := url.Parse("socks5h://user:bad%20pass@127.0.0.1:1080") - require.NoError(t, err) - - _, err = buildProxychainsConfig(proxyURL) - require.Error(t, err) - require.True(t, strings.Contains(err.Error(), "whitespace")) -} diff --git a/backend/internal/pkg/lspool/remote_instance.go b/backend/internal/pkg/lspool/remote_instance.go deleted file mode 100644 index 5de4bde2..00000000 --- a/backend/internal/pkg/lspool/remote_instance.go +++ /dev/null @@ -1,99 +0,0 @@ -package lspool - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "net/url" -) - -func (i *Instance) callWorkerUnary(ctx context.Context, service, method, mode string, body []byte) ([]byte, error) { - endpoint, err := i.workerEndpoint("/rpc/unary", service, method, mode) - if err != nil { - return nil, err - } - req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body)) - if err != nil { - return nil, err - } - req.Header.Set("X-Worker-Token", i.workerToken) - if mode == "json" { - req.Header.Set("Content-Type", "application/json") - } else { - req.Header.Set("Content-Type", "application/octet-stream") - } - - resp, err := i.client.Do(req) - if err != nil { - return nil, fmt.Errorf("worker rpc %s/%s: %w", service, method, err) - } - defer resp.Body.Close() - - respBody, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("worker rpc read response: %w", err) - } - if resp.StatusCode != http.StatusOK { - return respBody, fmt.Errorf("worker rpc %s/%s HTTP %d: %s", service, method, resp.StatusCode, truncate(string(respBody), 200)) - } - return respBody, nil -} - -func (i *Instance) callWorkerStream(ctx context.Context, service, method, mode string, body []byte) (*http.Response, error) { - endpoint, err := i.workerEndpoint("/rpc/stream", service, method, mode) - if err != nil { - return nil, err - } - req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body)) - if err != nil { - return nil, err - } - req.Header.Set("X-Worker-Token", i.workerToken) - if mode == "json" { - req.Header.Set("Content-Type", "application/json") - } else { - req.Header.Set("Content-Type", "application/octet-stream") - } - - resp, err := i.client.Do(req) - if err != nil { - return nil, fmt.Errorf("worker stream rpc %s/%s: %w", service, method, err) - } - if resp.StatusCode != http.StatusOK { - defer resp.Body.Close() - body, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("worker stream rpc %s/%s HTTP %d: %s", service, method, resp.StatusCode, truncate(string(body), 200)) - } - return resp, nil -} - -func (i *Instance) workerEndpoint(path, service, method, mode string) (string, error) { - base := url.URL{ - Scheme: "http", - Host: i.Address, - Path: path, - } - values := url.Values{} - values.Set("service", service) - values.Set("method", method) - values.Set("mode", mode) - if i.routingKey != "" { - values.Set("routing_key", i.routingKey) - } - base.RawQuery = values.Encode() - return base.String(), nil -} - -func marshalWorkerJSONBody(input any) ([]byte, error) { - if input == nil { - return []byte("{}"), nil - } - body, err := json.Marshal(input) - if err != nil { - return nil, err - } - return body, nil -} diff --git a/backend/internal/pkg/lspool/upstream_adapter.go b/backend/internal/pkg/lspool/upstream_adapter.go deleted file mode 100644 index 30a505c3..00000000 --- a/backend/internal/pkg/lspool/upstream_adapter.go +++ /dev/null @@ -1,1682 +0,0 @@ -// Package lspool provides an HTTPUpstream adapter that routes -// streamGenerateContent requests through real Language Server instances. -// -// Flow: -// -// sub2api → LSPoolUpstream.Do() → StartCascade → SendUserCascadeMessage -// → LS internally calls cloudcode-pa (with authentic TLS fingerprint) -// → Poll GetCascadeTrajectory for incremental text -// → Format as SSE and stream back to sub2api service layer -// -// The model is extracted from the original request body, not hardcoded. -package lspool - -import ( - "bytes" - "context" - "crypto/sha256" - "encoding/base64" - "encoding/json" - "errors" - "fmt" - "io" - "log/slog" - "net/http" - "strconv" - "strings" - "sync" - "time" - - "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint" -) - -// Upstream is the interface matching service.HTTPUpstream -type Upstream interface { - Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) - DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*http.Response, error) -} - -// LSPoolUpstream wraps an existing HTTPUpstream and intercepts -// streamGenerateContent requests to route them through the LS pool. -type LSPoolUpstream struct { - pool Backend - fallback Upstream - logger *slog.Logger - sessionMu sync.Mutex - sessions map[string]*cascadeSessionState -} - -// NewLSPoolUpstream creates an LS pool upstream wrapper. -func NewLSPoolUpstream(pool Backend, fallback Upstream) *LSPoolUpstream { - return &LSPoolUpstream{ - pool: pool, - fallback: fallback, - logger: slog.Default().With("component", "lspool-upstream"), - sessions: make(map[string]*cascadeSessionState), - } -} - -const ( - userNamespaceHeader = "X-Sub2API-User-Key" - useAICreditsHeader = "X-Antigravity-Use-AI-Credits" - availableCreditsHeader = "X-Antigravity-Available-Credits" - minimumCreditAmountHeader = "X-Antigravity-Minimum-Credit-Amount" - sessionStateTTL = 30 * time.Minute - lsSendMessageTimeout = 20 * time.Second - lsModelConfigTimeout = 20 * time.Second -) - -var ( - errLSRouteDirect = errors.New("request should use direct upstream") - errLSTranscriptDrift = errors.New("request transcript diverged from cached cascade session") - errLSQuotaExhausted = errors.New("ls cascade returned quota exhausted") - errLSModelMapPending = errors.New("model mapping not ready") - errLSModelMapDenied = errors.New("model mapping unavailable") -) - -// IsLSQuotaExhaustedError reports whether err originated from an LS cascade -// quota/capacity exhaustion signal. -func IsLSQuotaExhaustedError(err error) bool { - return errors.Is(err, errLSQuotaExhausted) -} - -// LSQuotaExhaustedMessage extracts the original LS error message, if present. -func LSQuotaExhaustedMessage(err error) string { - if err == nil { - return "" - } - msg := strings.TrimSpace(err.Error()) - if msg == "" { - return "" - } - prefix := errLSQuotaExhausted.Error() - if msg == prefix { - return "" - } - if strings.HasPrefix(msg, prefix+":") { - return strings.TrimSpace(strings.TrimPrefix(msg, prefix+":")) - } - return msg -} - -func isPermanentModelMappingError(err error) bool { - if err == nil { - return false - } - return strings.Contains(strings.ToLower(err.Error()), "unauthorized_client") -} - -func modelMappingDeniedReason(err error) string { - if err == nil { - return "" - } - return truncate(strings.TrimSpace(err.Error()), 200) -} - -type cascadeSessionState struct { - CascadeID string - SystemText string - History []geminiConversationTurn - UpdatedAt time.Time -} - -type geminiEnvelope struct { - Model string `json:"model"` - Request json.RawMessage `json:"request"` -} - -type geminiRequestPayload struct { - Contents []geminiWireContent `json:"contents"` - SystemInstruction *geminiWireContent `json:"systemInstruction,omitempty"` - GenerationConfig *geminiWireGenerationConfig `json:"generationConfig,omitempty"` - SessionID string `json:"sessionId,omitempty"` -} - -type geminiWireGenerationConfig struct { - ResponseModalities []string `json:"responseModalities,omitempty"` - ImageConfig json.RawMessage `json:"imageConfig,omitempty"` -} - -type geminiWireContent struct { - Role string `json:"role"` - Parts []geminiWirePart `json:"parts"` -} - -type geminiWirePart struct { - Text string `json:"text,omitempty"` - Thought bool `json:"thought,omitempty"` - ThoughtSignature string `json:"thoughtSignature,omitempty"` - InlineData *geminiWireInlineData `json:"inlineData,omitempty"` - FunctionCall map[string]any `json:"functionCall,omitempty"` - FunctionResponse map[string]any `json:"functionResponse,omitempty"` -} - -type geminiWireInlineData struct { - MimeType string `json:"mimeType"` - Data string `json:"data"` -} - -type geminiParsedRequest struct { - Model string - SessionID string - SystemText string - Turns []geminiConversationTurn - ResponseModalities []string - HasImageConfig bool - HasUnsupported bool -} - -type geminiConversationTurn struct { - Role string - Parts []geminiConversationPart -} - -type geminiConversationPart struct { - Kind string - Text string - MimeType string - Data string -} - -type lsRouteDecision struct { - UseLS bool - Reason string -} - -type lsRequestTrace struct { - StartedAt time.Time - AccountID int64 - Model string - SessionIDHash string - Replica int - CascadeID string - NewSession bool - InflightAtAcquire int64 - TurnCount int - GetOrCreateDuration time.Duration - StartCascadeDuration time.Duration - BuildInputDuration time.Duration - SendMessageDuration time.Duration - FirstPollLatency time.Duration - FirstTextLatency time.Duration - PollCount int -} - -// Do routes streamGenerateContent through LS, everything else through fallback. -func (u *LSPoolUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) { - u.extractAndStripInternalHeaders(req, strconv.FormatInt(accountID, 10)) - - if !isStreamGenerate(req.URL.Path) { - return u.fallback.Do(req, proxyURL, accountID, accountConcurrency) - } - body, err := snapshotRequestBody(req) - if err != nil { - return nil, fmt.Errorf("snapshot request body: %w", err) - } - if len(bytes.TrimSpace(body)) == 0 { - return u.fallback.Do(req, proxyURL, accountID, accountConcurrency) - } - - resp, err := u.doViaLS(req, body, accountID, proxyURL) - if err != nil { - if shouldFallbackDirect(err) { - u.logger.Warn("[LS-POOL] LS fell back to direct", "account", accountID, "err", err) - req.Body = io.NopCloser(bytes.NewReader(body)) - return u.fallback.Do(req, proxyURL, accountID, accountConcurrency) - } - return nil, err - } - return resp, nil -} - -// DoWithTLS — LS handles its own TLS, so profile is ignored for LS requests. -func (u *LSPoolUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*http.Response, error) { - u.extractAndStripInternalHeaders(req, strconv.FormatInt(accountID, 10)) - - if !isStreamGenerate(req.URL.Path) { - return u.fallback.DoWithTLS(req, proxyURL, accountID, accountConcurrency, profile) - } - body, err := snapshotRequestBody(req) - if err != nil { - return nil, fmt.Errorf("snapshot request body: %w", err) - } - if len(bytes.TrimSpace(body)) == 0 { - return u.fallback.DoWithTLS(req, proxyURL, accountID, accountConcurrency, profile) - } - - resp, err := u.doViaLS(req, body, accountID, proxyURL) - if err != nil { - if shouldFallbackDirect(err) { - u.logger.Warn("[LS-POOL] LS fell back to direct+TLS", "account", accountID, "err", err) - req.Body = io.NopCloser(bytes.NewReader(body)) - return u.fallback.DoWithTLS(req, proxyURL, accountID, accountConcurrency, profile) - } - return nil, err - } - return resp, nil -} - -func (u *LSPoolUpstream) doViaLS(req *http.Request, body []byte, accountID int64, proxyURL string) (*http.Response, error) { - accountKey := strconv.FormatInt(accountID, 10) - - if CurrentLSStrategy() != LSStrategyJSParity { - return u.forwardDirectWithKeepalive(req, body, accountKey, accountID, proxyURL) - } - - parsed, err := parseGeminiRequest(body) - if err != nil { - return u.forwardDirect(req, body, proxyURL, accountID, "parse request failed") - } - - decision := decideJSParityRoute(parsed, body) - if !decision.UseLS { - return u.forwardDirect(req, body, proxyURL, accountID, decision.Reason) - } - - resp, err := u.forwardChatViaLS(req, body, parsed, accountKey, accountID, proxyURL) - if err != nil { - if shouldFallbackDirect(err) { - return u.forwardDirect(req, body, proxyURL, accountID, err.Error()) - } - return nil, err - } - return resp, nil -} - -func shouldFallbackDirect(err error) bool { - return errors.Is(err, errLSRouteDirect) || - errors.Is(err, errLSTranscriptDrift) || - errors.Is(err, errLSModelMapDenied) -} - -func (u *LSPoolUpstream) forwardDirectWithKeepalive(req *http.Request, body []byte, accountKey string, accountID int64, proxyURL string) (*http.Response, error) { - // Start/reuse LS instance — keeps heartbeat alive, authenticates with - // cloudcode-pa, and refreshes model mapping. The LS process itself is NOT - // used as a proxy; we forward the original HTTP request directly to - // cloudcode-pa, bypassing Cascade entirely. This avoids the IDE agent - // system prompt that Cascade injects. - _, err := u.pool.GetOrCreate(accountKey, "", proxyURL) - if err != nil { - return nil, fmt.Errorf("get LS instance: %w", err) - } - - u.logger.Info("[LS-POOL] Forwarding via direct HTTP (LS keepalive active)", - "account", accountID, "path", req.URL.Path) - - return u.forwardDirect(req, body, proxyURL, accountID, "strategy=direct") -} - -func (u *LSPoolUpstream) forwardDirect(req *http.Request, body []byte, proxyURL string, accountID int64, reason string) (*http.Response, error) { - u.logger.Info("[LS-POOL] Forwarding via direct HTTP", - "account", accountID, - "path", req.URL.Path, - "reason", reason) - req.Header.Del(userNamespaceHeader) - req.Body = io.NopCloser(bytes.NewReader(body)) - return u.fallback.Do(req, proxyURL, accountID, 1) -} - -func (u *LSPoolUpstream) extractAndStripInternalHeaders(req *http.Request, accountKey string) { - if auth := req.Header.Get("Authorization"); strings.HasPrefix(auth, "Bearer ") { - accessToken := strings.TrimPrefix(auth, "Bearer ") - refreshToken := req.Header.Get("X-Antigravity-Refresh-Token") - var expiresAt time.Time - if raw := req.Header.Get("X-Antigravity-Token-Expiry"); raw != "" { - if parsed, err := time.Parse(time.RFC3339, raw); err == nil { - expiresAt = parsed - } - } - u.pool.SetAccountToken(accountKey, accessToken, refreshToken, expiresAt) - } - - useAICredits, hasUseAICredits := parseBoolHeader(req.Header.Get(useAICreditsHeader)) - availableCredits, hasAvailableCredits := parseOptionalInt32Header(req.Header.Get(availableCreditsHeader)) - minimumCreditAmount, hasMinimumCreditAmount := parseOptionalInt32Header(req.Header.Get(minimumCreditAmountHeader)) - if hasUseAICredits || hasAvailableCredits || hasMinimumCreditAmount { - u.pool.SetAccountModelCredits(accountKey, useAICredits, availableCredits, minimumCreditAmount) - } - - req.Header.Del("X-Antigravity-Refresh-Token") - req.Header.Del("X-Antigravity-Token-Expiry") - req.Header.Del(useAICreditsHeader) - req.Header.Del(availableCreditsHeader) - req.Header.Del(minimumCreditAmountHeader) -} - -func parseBoolHeader(raw string) (bool, bool) { - raw = strings.TrimSpace(raw) - if raw == "" { - return false, false - } - val, err := strconv.ParseBool(raw) - if err != nil { - return false, false - } - return val, true -} - -func parseOptionalInt32Header(raw string) (*int32, bool) { - raw = strings.TrimSpace(raw) - if raw == "" { - return nil, false - } - val, err := strconv.ParseInt(raw, 10, 32) - if err != nil { - return nil, false - } - parsed := int32(val) - return &parsed, true -} - -func shortTraceID(raw string) string { - raw = strings.TrimSpace(raw) - if raw == "" { - return "none" - } - sum := sha256.Sum256([]byte(raw)) - return fmt.Sprintf("%x", sum[:4]) -} - -func durationMS(d time.Duration) int64 { - if d <= 0 { - return 0 - } - return d.Milliseconds() -} - -func (u *LSPoolUpstream) logTraceSummary(level slog.Level, msg string, trace *lsRequestTrace, extra ...any) { - if trace == nil { - u.logger.Log(context.Background(), level, msg, extra...) - return - } - args := []any{ - "account", trace.AccountID, - "model", trace.Model, - "session", trace.SessionIDHash, - "replica", trace.Replica, - "cascade", shortTraceID(trace.CascadeID), - "new_session", trace.NewSession, - "turns", trace.TurnCount, - "inflight", trace.InflightAtAcquire, - "get_or_create_ms", durationMS(trace.GetOrCreateDuration), - "start_cascade_ms", durationMS(trace.StartCascadeDuration), - "build_input_ms", durationMS(trace.BuildInputDuration), - "send_message_ms", durationMS(trace.SendMessageDuration), - "first_poll_ms", durationMS(trace.FirstPollLatency), - "first_token_ms", durationMS(trace.FirstTextLatency), - "polls", trace.PollCount, - "total_ms", durationMS(time.Since(trace.StartedAt)), - } - args = append(args, extra...) - u.logger.Log(context.Background(), level, msg, args...) -} - -func (u *LSPoolUpstream) forwardChatViaLS(req *http.Request, body []byte, parsed *geminiParsedRequest, accountKey string, accountID int64, proxyURL string) (*http.Response, error) { - trace := &lsRequestTrace{ - StartedAt: time.Now(), - AccountID: accountID, - Model: parsed.Model, - SessionIDHash: shortTraceID(parsed.SessionID), - TurnCount: len(parsed.Turns), - } - - getOrCreateStartedAt := time.Now() - sessionKey := buildSessionCacheKey(accountID, userNamespace(req), parsed.SessionID) - inst, err := u.pool.GetOrCreate(accountKey, sessionKey, proxyURL) - if err != nil { - trace.GetOrCreateDuration = time.Since(getOrCreateStartedAt) - u.logTraceSummary(slog.LevelWarn, "[LS-POOL] get instance failed", trace, "err", err) - return nil, fmt.Errorf("get LS instance: %w", err) - } - trace.GetOrCreateDuration = time.Since(getOrCreateStartedAt) - trace.Replica = inst.Replica - if inst.HasModelMappingUnavailable() { - reason := inst.ModelMappingUnavailableReason() - u.logTraceSummary(slog.LevelInfo, "[LS-POOL] model mapping unavailable, routing direct", trace, "reason", reason) - return nil, fmt.Errorf("%w: %s", errLSModelMapDenied, reason) - } - if !inst.HasModelMappingReady() { - u.logTraceSummary(slog.LevelInfo, "[LS-POOL] model mapping pending, routing direct", trace) - return nil, errLSModelMapPending - } - if !inst.AcquireConcurrency() { - u.logTraceSummary(slog.LevelWarn, "[LS-POOL] instance busy", trace, - "err", fmt.Sprintf("ls instance busy for account %d", accountID), - "current_inflight", inst.ConcurrentCount(), - "max_inflight", maxConcurrencyPerInstance) - return nil, fmt.Errorf("ls instance busy for account %d", accountID) - } - trace.InflightAtAcquire = inst.ConcurrentCount() - - state := u.getSessionState(sessionKey) - if state != nil && !systemTextCompatible(state.SystemText, parsed.SystemText) { - inst.ReleaseConcurrency() - u.logTraceSummary(slog.LevelInfo, "[LS-POOL] transcript drift, routing direct", trace) - return nil, errLSTranscriptDrift - } - - cascadeID := "" - newSession := false - sendTurn := geminiConversationTurn{} - contextPrefix := "" - - switch { - case state == nil: - if len(parsed.Turns) == 0 { - inst.ReleaseConcurrency() - return nil, errLSRouteDirect - } - lastTurn := parsed.Turns[len(parsed.Turns)-1] - if lastTurn.Role != "user" { - inst.ReleaseConcurrency() - u.logTraceSummary(slog.LevelInfo, "[LS-POOL] invalid first turn for LS, routing direct", trace) - return nil, errLSRouteDirect - } - sendTurn = lastTurn - contextPrefix = renderConversationContext(parsed.SystemText, parsed.Turns[:len(parsed.Turns)-1]) - startCascadeStartedAt := time.Now() - cascadeID, err = u.startCascade(inst) - trace.StartCascadeDuration = time.Since(startCascadeStartedAt) - if err != nil { - inst.ReleaseConcurrency() - u.logTraceSummary(slog.LevelWarn, "[LS-POOL] start cascade failed", trace, "err", err) - return nil, err - } - newSession = true - case !conversationPrefixEqual(parsed.Turns, state.History): - inst.ReleaseConcurrency() - u.logTraceSummary(slog.LevelInfo, "[LS-POOL] transcript prefix mismatch, routing direct", trace) - return nil, errLSTranscriptDrift - default: - delta := parsed.Turns[len(state.History):] - if len(delta) != 1 || delta[0].Role != "user" { - inst.ReleaseConcurrency() - u.logTraceSummary(slog.LevelInfo, "[LS-POOL] unsupported transcript delta, routing direct", trace) - return nil, errLSRouteDirect - } - sendTurn = delta[0] - cascadeID = state.CascadeID - } - trace.NewSession = newSession - trace.CascadeID = cascadeID - - buildInputStartedAt := time.Now() - items, media, err := buildLSInputFromTurn(sendTurn, contextPrefix) - trace.BuildInputDuration = time.Since(buildInputStartedAt) - if err != nil { - inst.ReleaseConcurrency() - u.logTraceSummary(slog.LevelWarn, "[LS-POOL] build input failed", trace, "err", err) - return nil, fmt.Errorf("build ls input: %w", err) - } - if len(items) == 0 && len(media) == 0 { - inst.ReleaseConcurrency() - u.logTraceSummary(slog.LevelInfo, "[LS-POOL] empty LS input, routing direct", trace) - return nil, errLSRouteDirect - } - sendReq := map[string]any{ - "metadata": buildLSRequestMetadata(), - "cascadeId": cascadeID, - "items": items, - "blocking": false, - } - if len(media) > 0 { - sendReq["media"] = media - } - if cfg := buildCascadeConfig(parsed.Model); cfg != nil { - sendReq["cascadeConfig"] = cfg - } - sendStartedAt := time.Now() - sendCtx, sendCancel := context.WithTimeout(req.Context(), lsSendMessageTimeout) - defer sendCancel() - if _, err := inst.CallUnaryJSON(sendCtx, LSService, "SendUserCascadeMessage", sendReq); err != nil { - trace.SendMessageDuration = time.Since(sendStartedAt) - if newSession { - u.cancelCascade(inst, cascadeID) - } - inst.ReleaseConcurrency() - u.logTraceSummary(slog.LevelWarn, "[LS-POOL] send user message failed", trace, "err", err) - return nil, fmt.Errorf("send user cascade message: %w", err) - } - trace.SendMessageDuration = time.Since(sendStartedAt) - - pr, pw := io.Pipe() - resp := &http.Response{ - StatusCode: http.StatusOK, - Header: http.Header{ - "Content-Type": []string{"text/event-stream"}, - "Cache-Control": []string{"no-cache"}, - "X-Accel-Buffering": []string{"no"}, - }, - Body: pr, - Request: req, - } - - go func() { - defer inst.ReleaseConcurrency() - ctx, cancel := context.WithCancel(req.Context()) - defer cancel() - - u.streamCascadeResponse(ctx, inst, cascadeID, pw, trace, func(finalText string) { - u.putSessionState(sessionKey, &cascadeSessionState{ - CascadeID: cascadeID, - SystemText: parsed.SystemText, - History: appendModelTurn(cloneConversationTurns(parsed.Turns), finalText), - UpdatedAt: time.Now(), - }) - }) - }() - - return resp, nil -} - -func (u *LSPoolUpstream) startCascade(inst *Instance) (string, error) { - resp, err := inst.CallUnaryJSON(context.Background(), LSService, "StartCascade", map[string]any{ - "metadata": buildLSRequestMetadata(), - }) - if err != nil { - return "", fmt.Errorf("start cascade: %w", err) - } - var decoded struct { - CascadeID string `json:"cascadeId"` - } - if err := json.Unmarshal(resp, &decoded); err != nil { - return "", fmt.Errorf("decode start cascade: %w", err) - } - if decoded.CascadeID == "" { - return "", errors.New("start cascade returned empty cascadeId") - } - return decoded.CascadeID, nil -} - -// cancelCascade tells the LS to stop processing a cascade invocation. -// Uses a short timeout — best-effort, don't block shutdown. -func (u *LSPoolUpstream) cancelCascade(inst *Instance, cascadeID string) { - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - _, err := inst.CallUnaryJSON(ctx, LSService, "CancelCascadeInvocation", map[string]any{ - "cascadeId": cascadeID, - }) - if err != nil { - // Try force stop as fallback - _, _ = inst.CallUnaryJSON(ctx, LSService, "ForceStopCascadeTree", map[string]any{ - "cascadeId": cascadeID, - }) - } -} - -// streamCascadeResponse polls GetCascadeTrajectory with adaptive interval. -// Fast (50ms) when model is generating, slow (150ms) when waiting. -// We also issue an immediate first poll so the first token is not delayed by -// the initial ticker interval. -func (u *LSPoolUpstream) streamCascadeResponse(ctx context.Context, inst *Instance, cascadeID string, w *io.PipeWriter, trace *lsRequestTrace, onDone func(string)) { - const ( - fastInterval = 50 * time.Millisecond - slowInterval = 150 * time.Millisecond - maxDuration = 5 * time.Minute - maxIdleTimeout = 30 * time.Second - ) - - ticker := time.NewTicker(slowInterval) - defer ticker.Stop() - - timeout := time.After(maxDuration) - lastText := "" - generating := false - lastProgressAt := time.Time{} - - pollOnce := func() bool { - if trace != nil { - trace.PollCount++ - } - trajResp, err := inst.CallUnaryJSON(ctx, LSService, "GetCascadeTrajectory", map[string]any{ - "cascadeId": cascadeID, - }) - if err != nil { - if ctx.Err() != nil { - u.logTraceSummary(slog.LevelInfo, "[LS-POOL] request context canceled", trace) - _ = w.Close() - return true - } - return false - } - if trace != nil && trace.FirstPollLatency == 0 { - trace.FirstPollLatency = time.Since(trace.StartedAt) - } - - state := extractPlannerResponseState(trajResp) - text, isGenerating, status := state.Text, state.Generating, state.Status - if state.ErrorMessage != "" { - u.logTraceSummary(slog.LevelWarn, "[LS-POOL] Cascade terminated with error", trace, "error", state.ErrorMessage) - if isQuotaExhaustedError(state.ErrorMessage) { - _ = w.CloseWithError(fmt.Errorf("%w: %s", errLSQuotaExhausted, state.ErrorMessage)) - } else { - _ = w.CloseWithError(errors.New(state.ErrorMessage)) - } - return true - } - - // Adaptive interval: fast when generating, slow when idle. - if isGenerating && !generating { - ticker.Reset(fastInterval) - generating = true - } else if !isGenerating && generating { - ticker.Reset(slowInterval) - generating = false - } - - // Emit new text as SSE. - if text != lastText && len(text) > len(lastText) { - newPart := text[len(lastText):] - sseEvent := buildGeminiSSEChunk(newPart) - if _, err := w.Write([]byte(sseEvent)); err != nil { - u.logTraceSummary(slog.LevelWarn, "[LS-POOL] write SSE failed", trace, "err", err) - _ = w.CloseWithError(err) - return true - } - lastText = text - lastProgressAt = time.Now() - if trace != nil && trace.FirstTextLatency == 0 { - trace.FirstTextLatency = time.Since(trace.StartedAt) - } - } - - // Check if done. - if status == "CASCADE_RUN_STATUS_IDLE" && text != "" && !isGenerating { - usage := extractUsageFromTrajectory(trajResp) - if usage != nil { - finalEvent := buildGeminiSSEFinalChunk(usage) - if _, err := w.Write([]byte(finalEvent)); err != nil { - u.logTraceSummary(slog.LevelWarn, "[LS-POOL] write final SSE failed", trace, "err", err) - _ = w.CloseWithError(err) - return true - } - } - if onDone != nil { - onDone(lastText) - } - u.logTraceSummary(slog.LevelInfo, "[LS-POOL] request completed", trace) - _ = w.Close() - return true - } - - if !lastProgressAt.IsZero() && time.Since(lastProgressAt) > maxIdleTimeout { - u.logTraceSummary(slog.LevelWarn, "[LS-POOL] No progress, stopping", trace) - _ = w.Close() - return true - } - - return false - } - - if pollOnce() { - return - } - - for { - select { - case <-ctx.Done(): - u.logTraceSummary(slog.LevelInfo, "[LS-POOL] request context canceled", trace) - _ = w.Close() - return - case <-timeout: - u.logTraceSummary(slog.LevelWarn, "[LS-POOL] Cascade timeout", trace) - _ = w.Close() - return - case <-ticker.C: - if pollOnce() { - return - } - } - } -} - -// ============================================================ -// SSE builders — match Gemini v1internal:streamGenerateContent?alt=sse format -// ============================================================ - -func buildGeminiSSEChunk(text string) string { - // cloudcode-pa v1internal 格式: {"response": {"candidates": [...]}} - chunk := map[string]any{ - "response": map[string]any{ - "candidates": []map[string]any{ - { - "content": map[string]any{ - "parts": []map[string]string{{"text": text}}, - "role": "model", - }, - }, - }, - }, - } - data, _ := json.Marshal(chunk) - return "data: " + string(data) + "\n\n" -} - -func buildGeminiSSEFinalChunk(usage map[string]any) string { - chunk := map[string]any{ - "response": map[string]any{ - "candidates": []map[string]any{ - { - "content": map[string]any{ - "parts": []map[string]string{{"text": ""}}, - "role": "model", - }, - "finishReason": "STOP", - }, - }, - "usageMetadata": usage, - }, - } - data, _ := json.Marshal(chunk) - return "data: " + string(data) + "\n\n" -} - -// ============================================================ -// Trajectory parsing -// ============================================================ - -type cascadePlannerState struct { - Text string - Generating bool - Status string - ErrorMessage string -} - -func extractPlannerResponseState(trajResp []byte) cascadePlannerState { - var raw map[string]any - if err := json.Unmarshal(trajResp, &raw); err != nil { - return cascadePlannerState{} - } - state := cascadePlannerState{} - state.Status, _ = raw["status"].(string) - state.ErrorMessage = findCascadeErrorMessage(raw) - - traj, ok := raw["trajectory"].(map[string]any) - if !ok { - return state - } - steps, ok := traj["steps"].([]any) - if !ok { - return state - } - for _, s := range steps { - sm, ok := s.(map[string]any) - if !ok { - continue - } - if sm["type"] != "CORTEX_STEP_TYPE_PLANNER_RESPONSE" { - continue - } - if sm["status"] == "CORTEX_STEP_STATUS_GENERATING" { - state.Generating = true - } - if pr, ok := sm["plannerResponse"].(map[string]any); ok { - if r, ok := pr["response"].(string); ok { - state.Text = r - } - } - } - return state -} - -func extractPlannerResponseText(trajResp []byte) (text string, generating bool, status string) { - state := extractPlannerResponseState(trajResp) - return state.Text, state.Generating, state.Status -} - -func findCascadeErrorMessage(value any) string { - switch v := value.(type) { - case map[string]any: - if msg := summarizeCascadeErrorMap(v); msg != "" { - return msg - } - for _, child := range v { - if msg := findCascadeErrorMessage(child); msg != "" { - return msg - } - } - case []any: - for _, child := range v { - if msg := findCascadeErrorMessage(child); msg != "" { - return msg - } - } - } - return "" -} - -func summarizeCascadeErrorMap(m map[string]any) string { - if full := cascadeStringField(m, "fullError"); full != "" { - return full - } - if user := cascadeStringField(m, "userErrorMessage"); user != "" { - return user - } - - _, hasErrorCode := m["errorCode"] - short := cascadeStringField(m, "shortError") - details := cascadeStringField(m, "details") - message := cascadeStringField(m, "message") - - if hasErrorCode { - parts := make([]string, 0, 3) - if short != "" { - parts = append(parts, short) - } - if message != "" && message != short { - parts = append(parts, message) - } - if details != "" && details != short && details != message { - parts = append(parts, details) - } - if len(parts) > 0 { - return strings.Join(parts, ": ") - } - return fmt.Sprintf("cascade error: %v", m["errorCode"]) - } - - if reason := cascadeStringField(m, "terminationReason"); strings.Contains(strings.ToUpper(reason), "ERROR") { - if message != "" { - return message - } - if short != "" { - return short - } - } - - return "" -} - -func cascadeStringField(m map[string]any, key string) string { - raw, ok := m[key] - if !ok { - return "" - } - str, ok := raw.(string) - if !ok { - return "" - } - return strings.TrimSpace(str) -} - -func extractUsageFromTrajectory(trajResp []byte) map[string]any { - var raw map[string]any - if err := json.Unmarshal(trajResp, &raw); err != nil { - return nil - } - traj, ok := raw["trajectory"].(map[string]any) - if !ok { - return nil - } - steps, ok := traj["steps"].([]any) - if !ok { - return nil - } - for _, s := range steps { - sm, ok := s.(map[string]any) - if !ok { - continue - } - if sm["type"] != "CORTEX_STEP_TYPE_PLANNER_RESPONSE" { - continue - } - meta, ok := sm["metadata"].(map[string]any) - if !ok { - continue - } - mu, ok := meta["modelUsage"].(map[string]any) - if !ok { - continue - } - input, _ := mu["inputTokens"].(string) - output, _ := mu["outputTokens"].(string) - inputN, _ := strconv.Atoi(input) - outputN, _ := strconv.Atoi(output) - if inputN > 0 || outputN > 0 { - return map[string]any{ - "promptTokenCount": inputN, - "candidatesTokenCount": outputN, - "totalTokenCount": inputN + outputN, - } - } - } - return nil -} - -// ============================================================ -// Request parsing — dynamic model, no hardcoding -// ============================================================ - -func parseGeminiRequest(body []byte) (*geminiParsedRequest, error) { - var envelope geminiEnvelope - if err := json.Unmarshal(body, &envelope); err != nil { - return nil, err - } - - reqBody := body - if len(envelope.Request) > 0 { - reqBody = envelope.Request - } - - var payload geminiRequestPayload - if err := json.Unmarshal(reqBody, &payload); err != nil { - return nil, err - } - - parsed := &geminiParsedRequest{ - Model: envelope.Model, - SessionID: payload.SessionID, - ResponseModalities: append([]string(nil), payload.GenerationConfig.GetResponseModalities()...), - HasImageConfig: payload.GenerationConfig != nil && len(bytes.TrimSpace(payload.GenerationConfig.ImageConfig)) > 0 && string(bytes.TrimSpace(payload.GenerationConfig.ImageConfig)) != "null", - } - if parsed.Model == "" { - var top map[string]json.RawMessage - if err := json.Unmarshal(body, &top); err == nil { - _ = json.Unmarshal(top["model"], &parsed.Model) - } - } - if payload.SystemInstruction != nil { - parsed.SystemText = collectTextParts(payload.SystemInstruction.Parts) - } - for _, content := range payload.Contents { - turn := geminiConversationTurn{Role: normalizeTurnRole(content.Role)} - for _, part := range content.Parts { - switch { - case part.Thought || part.ThoughtSignature != "": - parsed.HasUnsupported = true - case len(part.FunctionCall) > 0 || len(part.FunctionResponse) > 0: - parsed.HasUnsupported = true - case part.InlineData != nil: - turn.Parts = append(turn.Parts, geminiConversationPart{ - Kind: "media", - MimeType: part.InlineData.MimeType, - Data: part.InlineData.Data, - }) - case part.Text != "": - turn.Parts = append(turn.Parts, geminiConversationPart{ - Kind: "text", - Text: part.Text, - }) - } - } - if len(turn.Parts) > 0 { - parsed.Turns = append(parsed.Turns, turn) - } - } - - return parsed, nil -} - -func collectTextParts(parts []geminiWirePart) string { - var texts []string - for _, part := range parts { - if part.Text != "" { - texts = append(texts, part.Text) - } - } - return strings.Join(texts, "\n") -} - -func (g *geminiWireGenerationConfig) GetResponseModalities() []string { - if g == nil { - return nil - } - return g.ResponseModalities -} - -func normalizeTurnRole(role string) string { - if strings.EqualFold(strings.TrimSpace(role), "model") { - return "model" - } - return "user" -} - -func decideJSParityRoute(parsed *geminiParsedRequest, body []byte) lsRouteDecision { - if parsed == nil { - return lsRouteDecision{Reason: "nil parsed request"} - } - if requestHasTools(body) { - return lsRouteDecision{Reason: "tools are not supported through cascade"} - } - if parsed.SessionID == "" { - return lsRouteDecision{Reason: "missing sessionId"} - } - if parsed.HasUnsupported { - return lsRouteDecision{Reason: "request contains unsupported Gemini parts"} - } - if isImageGenerationModelName(parsed.Model) { - return lsRouteDecision{Reason: "image generation model"} - } - if parsed.HasImageConfig { - return lsRouteDecision{Reason: "request has imageConfig"} - } - for _, modality := range parsed.ResponseModalities { - if strings.EqualFold(strings.TrimSpace(modality), "IMAGE") { - return lsRouteDecision{Reason: "responseModalities contains IMAGE"} - } - } - if len(parsed.Turns) == 0 { - return lsRouteDecision{Reason: "empty conversation"} - } - return lsRouteDecision{UseLS: true, Reason: "js-parity cascade chat"} -} - -func extractPromptAndModel(body []byte) (string, string) { - var outer map[string]json.RawMessage - if err := json.Unmarshal(body, &outer); err != nil { - return "", "" - } - var model string - if m, ok := outer["model"]; ok { - json.Unmarshal(m, &model) - } - if reqRaw, ok := outer["request"]; ok { - return extractPromptFromGeminiRequest(reqRaw), model - } - return extractPromptFromGeminiRequest(body), model -} - -func extractPromptFromGeminiRequest(data []byte) string { - var req struct { - Contents []struct { - Parts []struct { - Text string `json:"text"` - } `json:"parts"` - Role string `json:"role"` - } `json:"contents"` - SystemInstruction *struct { - Parts []struct { - Text string `json:"text"` - } `json:"parts"` - } `json:"systemInstruction"` - } - if err := json.Unmarshal(data, &req); err != nil { - return "" - } - - var parts []string - - // Include system instruction if present - if req.SystemInstruction != nil { - for _, p := range req.SystemInstruction.Parts { - if p.Text != "" { - parts = append(parts, "[System]\n"+p.Text) - } - } - } - - // Include full conversation history - for _, c := range req.Contents { - role := c.Role - if role == "" { - role = "user" - } - for _, p := range c.Parts { - if p.Text != "" { - if role == "model" { - parts = append(parts, "[Assistant]\n"+p.Text) - } else { - parts = append(parts, "[User]\n"+p.Text) - } - } - } - } - - if len(parts) == 0 { - return "" - } - - // If only one part and no system instruction, return raw text (simple case) - if len(parts) == 1 && req.SystemInstruction == nil { - text := parts[0] - // Strip the [User]\n prefix for simple single-message case - if strings.HasPrefix(text, "[User]\n") { - return strings.TrimPrefix(text, "[User]\n") - } - return text - } - - return strings.Join(parts, "\n\n") -} - -func buildLSInputFromTurn(turn geminiConversationTurn, contextPrefix string) ([]map[string]any, []map[string]any, error) { - items := make([]map[string]any, 0, len(turn.Parts)+1) - media := make([]map[string]any, 0) - if strings.TrimSpace(contextPrefix) != "" { - items = append(items, map[string]any{"text": contextPrefix}) - } - for _, part := range turn.Parts { - switch part.Kind { - case "text": - if part.Text != "" { - items = append(items, map[string]any{"text": part.Text}) - } - case "media": - decoded, err := base64.StdEncoding.DecodeString(part.Data) - if err != nil { - return nil, nil, fmt.Errorf("decode inlineData: %w", err) - } - media = append(media, map[string]any{ - "mimeType": part.MimeType, - "inlineData": decoded, - }) - } - } - return items, media, nil -} - -func renderConversationContext(systemText string, turns []geminiConversationTurn) string { - var parts []string - if strings.TrimSpace(systemText) != "" { - parts = append(parts, "[System]\n"+strings.TrimSpace(systemText)) - } - for _, turn := range turns { - var rendered []string - for _, part := range turn.Parts { - switch part.Kind { - case "text": - if strings.TrimSpace(part.Text) != "" { - rendered = append(rendered, part.Text) - } - case "media": - label := "attachment" - switch { - case strings.HasPrefix(part.MimeType, "image/"): - label = "image attachment" - case strings.HasPrefix(part.MimeType, "video/"): - label = "video attachment" - case strings.HasPrefix(part.MimeType, "audio/"): - label = "audio attachment" - } - rendered = append(rendered, fmt.Sprintf("[%s: %s]", label, part.MimeType)) - } - } - if len(rendered) == 0 { - continue - } - roleLabel := "User" - if turn.Role == "model" { - roleLabel = "Assistant" - } - parts = append(parts, fmt.Sprintf("[%s]\n%s", roleLabel, strings.Join(rendered, "\n"))) - } - return strings.Join(parts, "\n\n") -} - -func buildCascadeConfig(model string) map[string]any { - normalizedModel := normalizeRequestedModelName(model) - if normalizedModel == "" { - return nil - } - modelEnum := resolveModelEnum(normalizedModel) - - return map[string]any{ - "plannerConfig": map[string]any{ - "requestedModel": map[string]any{ - "model": modelEnum, - }, - }, - } -} - -func buildLSRequestMetadata() map[string]any { - return map[string]any{ - "ideName": "antigravity", - "ideVersion": "1.107.0", - } -} - -func appendModelTurn(turns []geminiConversationTurn, modelText string) []geminiConversationTurn { - if strings.TrimSpace(modelText) == "" { - return turns - } - return append(turns, geminiConversationTurn{ - Role: "model", - Parts: []geminiConversationPart{{ - Kind: "text", - Text: modelText, - }}, - }) -} - -func cloneConversationTurns(src []geminiConversationTurn) []geminiConversationTurn { - out := make([]geminiConversationTurn, 0, len(src)) - for _, turn := range src { - copied := geminiConversationTurn{ - Role: turn.Role, - Parts: append([]geminiConversationPart(nil), turn.Parts...), - } - out = append(out, copied) - } - return out -} - -func conversationPrefixEqual(full, prefix []geminiConversationTurn) bool { - if len(prefix) > len(full) { - return false - } - for i := range prefix { - if prefix[i].Role != full[i].Role { - return false - } - if len(prefix[i].Parts) != len(full[i].Parts) { - return false - } - for j := range prefix[i].Parts { - if prefix[i].Parts[j] != full[i].Parts[j] { - return false - } - } - } - return true -} - -// ResolveModelEnumPublic is the exported version of resolveModelEnum for testing. -func ResolveModelEnumPublic(model string) int { - return resolveModelEnum(model) -} - -// resolveModelEnum maps a Gemini/Claude model name to its proto enum number. -// Priority: dynamic mapping (from LS) > static fallback. -// The LS uses MODEL_PLACEHOLDER_Mn enum values (1000+n) that are dynamically -// assigned by the server — only these are guaranteed to work. -func resolveModelEnum(model string) int { - model = normalizeRequestedModelName(model) - - // 1. Try dynamic mapping first (populated by RefreshModelMapping from LS) - dynamicModelMapMu.RLock() - // Exact match - if v, ok := dynamicModelMap[model]; ok { - dynamicModelMapMu.RUnlock() - return v - } - // Fuzzy match: normalized label vs model name - for label, v := range dynamicModelMap { - if labelMatchesModel(label, model) { - dynamicModelMapMu.RUnlock() - return v - } - } - // Prefix match in dynamic map - for label, v := range dynamicModelMap { - normalized := normalizeLabel(label) - if strings.HasPrefix(model, normalized) || strings.HasPrefix(normalized, model) { - dynamicModelMapMu.RUnlock() - return v - } - } - dynamicModelMapMu.RUnlock() - - // 2. Known working placeholders (verified on Mac with LS v1.107.0) - // These map display labels to MODEL_PLACEHOLDER_Mn enum values - knownPlaceholders := map[string]int{ - "gemini-3-flash": 1047, - "gemini-3.1-pro-high": 1037, - "gemini-3.1-pro-low": 1036, - "claude-sonnet-4-6-thinking": 1035, - "claude-opus-4-6-thinking": 1026, - "gpt-oss-120b-medium": 342, - } - if v, ok := knownPlaceholders[model]; ok { - return v - } - // Fuzzy match known placeholders - modelLower := strings.ToLower(model) - for k, v := range knownPlaceholders { - if strings.Contains(modelLower, strings.ToLower(k)) || strings.Contains(strings.ToLower(k), modelLower) { - return v - } - } - - // 3. Family-based fallback from known placeholders - for k, v := range knownPlaceholders { - if strings.Contains(modelLower, "claude") && strings.Contains(k, "claude") { - return v - } - if strings.Contains(modelLower, "gemini") && strings.Contains(k, "gemini") { - return v - } - if strings.Contains(modelLower, "gpt") && strings.Contains(k, "gpt") { - return v - } - } - - // 4. Also check dynamic map if available - dynamicModelMapMu.RLock() - defer dynamicModelMapMu.RUnlock() - for label, v := range dynamicModelMap { - labelLower := strings.ToLower(normalizeLabel(label)) - // Same family: "claude" matches "claude-*", "gemini" matches "gemini-*" - if strings.Contains(modelLower, "claude") && strings.Contains(labelLower, "claude") { - return v - } - if strings.Contains(modelLower, "gemini") && strings.Contains(labelLower, "gemini") { - return v - } - if strings.Contains(modelLower, "gpt") && strings.Contains(labelLower, "gpt") { - return v - } - } - - // Last resort: return first available model from dynamic map - for _, v := range dynamicModelMap { - return v - } - - // No dynamic mapping at all (LS not started yet?) — use gemini-2.5-flash static - return 312 -} - -// labelMatchesModel does fuzzy matching between LS display label and sub2api model name. -// e.g. "Gemini 3 Flash" matches "gemini-3-flash", "Claude Sonnet 4.6 (Thinking)" matches "claude-sonnet-4-6-thinking" -func labelMatchesModel(label, model string) bool { - normalize := func(s string) string { - s = strings.ToLower(s) - s = strings.ReplaceAll(s, " ", "-") - s = strings.ReplaceAll(s, ".", "-") - s = strings.ReplaceAll(s, "(", "") - s = strings.ReplaceAll(s, ")", "") - s = strings.ReplaceAll(s, "--", "-") - return strings.TrimRight(s, "-") - } - return normalize(label) == normalize(model) -} - -// Dynamic model mapping — refreshed from LS at startup -var ( - dynamicModelMapMu sync.RWMutex - dynamicModelMap = map[string]int{} // label -> enum value -) - -// HasDynamicModelMappingPublic is exported for testing. -func HasDynamicModelMappingPublic() bool { - return hasDynamicModelMapping() -} - -// hasDynamicModelMapping returns true if at least one model has been loaded from the LS. -func hasDynamicModelMapping() bool { - dynamicModelMapMu.RLock() - defer dynamicModelMapMu.RUnlock() - return len(dynamicModelMap) > 0 -} - -// RefreshModelMapping queries the LS for available models and builds the mapping. -// Called automatically when an LS instance starts. -func RefreshModelMapping(inst *Instance) bool { - if inst == nil { - return false - } - startedAt := time.Now() - ctx, cancel := context.WithTimeout(context.Background(), lsModelConfigTimeout) - defer cancel() - - resp, err := inst.CallUnaryJSON(ctx, LSService, "GetCascadeModelConfigData", map[string]any{}) - if err != nil { - inst.SetModelMappingReady(false) - if isPermanentModelMappingError(err) { - reason := modelMappingDeniedReason(err) - inst.SetModelMappingUnavailable(reason) - slog.Warn("[LS-POOL] Model mapping unavailable", - "account", inst.AccountID, - "replica", inst.Replica, - "address", inst.Address, - "elapsed", time.Since(startedAt).Truncate(time.Millisecond), - "reason", reason) - return false - } - inst.ClearModelMappingUnavailable() - slog.Warn("[LS-POOL] Failed to get model config", - "account", inst.AccountID, - "replica", inst.Replica, - "address", inst.Address, - "elapsed", time.Since(startedAt).Truncate(time.Millisecond), - "err", err) - return false - } - - var data struct { - ClientModelConfigs []struct { - Label string `json:"label"` - ModelOrAlias map[string]any `json:"modelOrAlias"` - } `json:"clientModelConfigs"` - } - if err := json.Unmarshal(resp, &data); err != nil { - inst.SetModelMappingReady(false) - inst.ClearModelMappingUnavailable() - return false - } - - newMap := make(map[string]int) - for _, cfg := range data.ClientModelConfigs { - label := cfg.Label - if label == "" { - continue - } - // modelOrAlias is {"model": "MODEL_PLACEHOLDER_M37"} in JSON - modelStr, _ := cfg.ModelOrAlias["model"].(string) - if modelStr == "" { - continue - } - // Parse "MODEL_PLACEHOLDER_M37" → 1037 - enumVal := parseModelEnumString(modelStr) - if enumVal > 0 { - // Store both the display label and a normalized form - newMap[label] = enumVal - // Also store kebab-case version: "Gemini 3 Flash" → "gemini-3-flash" - normalized := normalizeLabel(label) - if normalized != "" { - newMap[normalized] = enumVal - } - } - } - - if len(newMap) > 0 { - dynamicModelMapMu.Lock() - dynamicModelMap = newMap - dynamicModelMapMu.Unlock() - inst.SetModelMappingReady(true) - inst.ClearModelMappingUnavailable() - slog.Info("[LS-POOL] Model mapping refreshed", - "account", inst.AccountID, - "replica", inst.Replica, - "address", inst.Address, - "count", len(newMap)/2, - "elapsed", time.Since(startedAt).Truncate(time.Millisecond)) - return true - } - inst.SetModelMappingReady(false) - inst.ClearModelMappingUnavailable() - return false -} - -func parseModelEnumString(s string) int { - // Named enums - named := map[string]int{ - "MODEL_CLAUDE_4_SONNET": 281, - "MODEL_CLAUDE_4_SONNET_THINKING": 282, - "MODEL_CLAUDE_4_OPUS": 290, - "MODEL_CLAUDE_4_OPUS_THINKING": 291, - "MODEL_CLAUDE_4_5_SONNET": 333, - "MODEL_CLAUDE_4_5_SONNET_THINKING": 334, - "MODEL_CLAUDE_4_5_HAIKU": 340, - "MODEL_CLAUDE_4_5_HAIKU_THINKING": 341, - "MODEL_OPENAI_GPT_OSS_120B_MEDIUM": 342, - "MODEL_GOOGLE_GEMINI_2_5_FLASH": 312, - "MODEL_GOOGLE_GEMINI_2_5_FLASH_THINKING": 313, - "MODEL_GOOGLE_GEMINI_2_5_FLASH_LITE": 330, - "MODEL_GOOGLE_GEMINI_2_5_PRO": 246, - } - if v, ok := named[s]; ok { - return v - } - // "MODEL_PLACEHOLDER_M37" → 1037 - if strings.HasPrefix(s, "MODEL_PLACEHOLDER_M") { - numStr := strings.TrimPrefix(s, "MODEL_PLACEHOLDER_M") - n, err := strconv.Atoi(numStr) - if err == nil { - return 1000 + n - } - } - return 0 -} - -func normalizeLabel(label string) string { - s := strings.ToLower(label) - s = strings.ReplaceAll(s, " ", "-") - s = strings.ReplaceAll(s, ".", "-") - s = strings.ReplaceAll(s, "(", "") - s = strings.ReplaceAll(s, ")", "") - s = strings.ReplaceAll(s, "--", "-") - return strings.TrimRight(s, "-") -} - -func normalizeRequestedModelName(model string) string { - normalized := strings.ToLower(strings.TrimSpace(model)) - normalized = strings.TrimPrefix(normalized, "models/") - return normalized -} - -func isGeminiPlannerModel(model string) bool { - return strings.Contains(normalizeRequestedModelName(model), "gemini") -} - -func systemTextCompatible(stored, current string) bool { - stored = strings.TrimSpace(stored) - current = strings.TrimSpace(current) - return current == "" || current == stored -} - -// ============================================================ -// Helpers -// ============================================================ - -func buildSessionCacheKey(accountID int64, namespace, sessionID string) string { - return fmt.Sprintf("%d:%s:%s", accountID, namespace, sessionID) -} - -func userNamespace(req *http.Request) string { - if req == nil { - return "anon" - } - for _, value := range []string{ - req.Header.Get(userNamespaceHeader), - req.Header.Get("X-Api-Key"), - req.Header.Get("X-Goog-Api-Key"), - req.Header.Get("Authorization"), - } { - if strings.TrimSpace(value) != "" { - sum := sha256.Sum256([]byte(value)) - return fmt.Sprintf("%x", sum[:8]) - } - } - return "anon" -} - -func (u *LSPoolUpstream) getSessionState(key string) *cascadeSessionState { - u.sessionMu.Lock() - defer u.sessionMu.Unlock() - u.pruneExpiredSessionsLocked() - state := u.sessions[key] - if state == nil { - return nil - } - cloned := &cascadeSessionState{ - CascadeID: state.CascadeID, - SystemText: state.SystemText, - History: cloneConversationTurns(state.History), - UpdatedAt: state.UpdatedAt, - } - return cloned -} - -func (u *LSPoolUpstream) putSessionState(key string, state *cascadeSessionState) { - if state == nil { - return - } - u.sessionMu.Lock() - defer u.sessionMu.Unlock() - u.pruneExpiredSessionsLocked() - u.sessions[key] = &cascadeSessionState{ - CascadeID: state.CascadeID, - SystemText: state.SystemText, - History: cloneConversationTurns(state.History), - UpdatedAt: state.UpdatedAt, - } -} - -func (u *LSPoolUpstream) pruneExpiredSessionsLocked() { - now := time.Now() - for key, state := range u.sessions { - if state == nil || now.Sub(state.UpdatedAt) > sessionStateTTL { - delete(u.sessions, key) - } - } -} - -func isStreamGenerate(path string) bool { - return strings.Contains(path, "streamGenerateContent") -} - -// isQuotaExhaustedError detects 429 QUOTA_EXHAUSTED errors from LS cascade trajectory. -// When detected, the caller should fall back to direct HTTP so the gateway can -// inject enabledCreditTypes for AI Credits retry. -func isQuotaExhaustedError(msg string) bool { - lower := strings.ToLower(msg) - return (strings.Contains(lower, "resource_exhausted") || strings.Contains(lower, "quota_exhausted")) && - (strings.Contains(lower, "429") || strings.Contains(lower, "exhausted your capacity")) -} - -func isImageGenerationModelName(model string) bool { - modelLower := normalizeRequestedModelName(model) - return modelLower == "gemini-3.1-flash-image" || - modelLower == "gemini-3.1-flash-image-preview" || - strings.HasPrefix(modelLower, "gemini-3.1-flash-image-") || - modelLower == "gemini-3-pro-image" || - modelLower == "gemini-3-pro-image-preview" || - strings.HasPrefix(modelLower, "gemini-3-pro-image-") || - modelLower == "gemini-2.5-flash-image" || - modelLower == "gemini-2.5-flash-image-preview" || - strings.HasPrefix(modelLower, "gemini-2.5-flash-image-") -} - -// requestHasTools checks if the Gemini request body contains tools/function declarations. -// These are not supported through the Cascade path and must use direct HTTP. -func requestHasTools(body []byte) bool { - // Check both the wrapped format {"request": {"tools": [...]}} and direct {"tools": [...]} - var outer map[string]json.RawMessage - if err := json.Unmarshal(body, &outer); err != nil { - return false - } - - // Check in wrapped request - if reqRaw, ok := outer["request"]; ok { - var inner map[string]json.RawMessage - if json.Unmarshal(reqRaw, &inner) == nil { - if tools, ok := inner["tools"]; ok && len(tools) > 4 { // > "[]" or "null" - return true - } - } - } - - // Check at top level - if tools, ok := outer["tools"]; ok && len(tools) > 4 { - return true - } - return false -} - -func snapshotRequestBody(req *http.Request) ([]byte, error) { - if req.Body == nil { - return nil, nil - } - body, err := io.ReadAll(req.Body) - if err != nil { - return nil, err - } - req.Body.Close() - req.Body = io.NopCloser(bytes.NewReader(body)) - return body, nil -} - -// unused but needed for compilation -var _ sync.Mutex diff --git a/backend/internal/pkg/lspool/worker_manager.go b/backend/internal/pkg/lspool/worker_manager.go deleted file mode 100644 index 19121e12..00000000 --- a/backend/internal/pkg/lspool/worker_manager.go +++ /dev/null @@ -1,680 +0,0 @@ -package lspool - -import ( - "bytes" - "context" - "crypto/sha256" - "encoding/json" - "fmt" - "io" - "log/slog" - "net" - "net/http" - "net/url" - "os" - "strconv" - "strings" - "sync" - "time" - - "github.com/Wei-Shaw/sub2api/internal/config" - "github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl" - "github.com/docker/docker/api/types/container" - "github.com/docker/docker/api/types/filters" - "github.com/docker/docker/api/types/network" - "github.com/docker/docker/client" - ocispec "github.com/opencontainers/image-spec/specs-go/v1" -) - -const ( - lsWorkerManagedByLabel = "managed-by" - lsWorkerManagedByValue = "sub2api" - lsWorkerAccountLabel = "account_id" - lsWorkerProxyHashLabel = "proxy_hash" - lsWorkerImageTagLabel = "image_tag" - lsWorkerControlPort = 18081 -) - -type workerManagerConfig struct { - Image string - Network string - DockerSocket string - IdleTTL time.Duration - MaxActive int - StartupTimeout time.Duration - RequestTimeout time.Duration -} - -type dockerClient interface { - ContainerList(ctx context.Context, options container.ListOptions) ([]container.Summary, error) - ContainerCreate(ctx context.Context, config *container.Config, hostConfig *container.HostConfig, networkingConfig *network.NetworkingConfig, platform *ocispec.Platform, containerName string) (container.CreateResponse, error) - ContainerStart(ctx context.Context, containerID string, options container.StartOptions) error - ContainerInspect(ctx context.Context, containerID string) (container.InspectResponse, error) - ContainerStop(ctx context.Context, containerID string, options container.StopOptions) error - ContainerRemove(ctx context.Context, containerID string, options container.RemoveOptions) error - Close() error -} - -type workerManager struct { - cfg workerManagerConfig - docker dockerClient - http *http.Client - - mu sync.Mutex - workers map[string]*workerHandle - state map[string]*workerAccountState - - ctx context.Context - cancel context.CancelFunc - logger *slog.Logger -} - -type workerHandle struct { - Key string - AccountID string - ProxyURL string - ProxyHash string - ContainerID string - Container string - Address string - AuthToken string - LastUsed time.Time - LastStateSHA string -} - -type workerAccountState struct { - HasToken bool `json:"has_token"` - AccessToken string `json:"access_token,omitempty"` - RefreshToken string `json:"refresh_token,omitempty"` - ExpiresAt *time.Time `json:"expires_at,omitempty"` - HasModelCredits bool `json:"has_model_credits"` - UseAICredits bool `json:"use_ai_credits"` - AvailableCredits *int32 `json:"available_credits,omitempty"` - MinimumCreditAmount *int32 `json:"minimum_credit_amount,omitempty"` -} - -func NewWorkerManagerFromConfig(cfg *config.Config) (Backend, error) { - if cfg == nil { - return nil, fmt.Errorf("config is nil") - } - - managerCfg := workerManagerConfig{ - Image: strings.TrimSpace(cfg.Gateway.AntigravityLSWorker.Image), - Network: strings.TrimSpace(cfg.Gateway.AntigravityLSWorker.Network), - DockerSocket: strings.TrimSpace(cfg.Gateway.AntigravityLSWorker.DockerSocket), - IdleTTL: cfg.Gateway.AntigravityLSWorker.IdleTTL, - MaxActive: cfg.Gateway.AntigravityLSWorker.MaxActive, - StartupTimeout: cfg.Gateway.AntigravityLSWorker.StartupTimeout, - RequestTimeout: cfg.Gateway.AntigravityLSWorker.RequestTimeout, - } - - if managerCfg.Image == "" { - managerCfg.Image = "weishaw/sub2api-lsworker:latest" - } - if managerCfg.Network == "" { - managerCfg.Network = "sub2api-network" - } - if managerCfg.DockerSocket == "" { - managerCfg.DockerSocket = "unix:///var/run/docker.sock" - } - if managerCfg.IdleTTL <= 0 { - managerCfg.IdleTTL = 15 * time.Minute - } - if managerCfg.MaxActive < 1 { - managerCfg.MaxActive = 50 - } - if managerCfg.StartupTimeout <= 0 { - managerCfg.StartupTimeout = 45 * time.Second - } - if managerCfg.RequestTimeout <= 0 { - managerCfg.RequestTimeout = 60 * time.Second - } - - opts := []client.Opt{client.WithAPIVersionNegotiation()} - if managerCfg.DockerSocket != "" { - opts = append(opts, client.WithHost(managerCfg.DockerSocket)) - } else { - opts = append(opts, client.FromEnv) - } - - dockerClient, err := client.NewClientWithOpts(opts...) - if err != nil { - return nil, fmt.Errorf("create docker client: %w", err) - } - - return newWorkerManager(managerCfg, dockerClient) -} - -func newWorkerManager(cfg workerManagerConfig, docker dockerClient) (*workerManager, error) { - ctx, cancel := context.WithCancel(context.Background()) - mgr := &workerManager{ - cfg: cfg, - docker: docker, - http: &http.Client{ - Timeout: cfg.RequestTimeout, - Transport: &http.Transport{ - Proxy: nil, - DialContext: (&net.Dialer{ - Timeout: 5 * time.Second, - KeepAlive: 30 * time.Second, - }).DialContext, - MaxIdleConnsPerHost: 8, - }, - }, - workers: make(map[string]*workerHandle), - state: make(map[string]*workerAccountState), - ctx: ctx, - cancel: cancel, - logger: slog.Default().With("component", "lspool-worker-manager"), - } - if err := mgr.reconcileManagedContainers(ctx); err != nil { - cancel() - _ = docker.Close() - return nil, err - } - go mgr.cleanupLoop() - return mgr, nil -} - -func (m *workerManager) Close() { - m.cancel() - - m.mu.Lock() - workers := make([]*workerHandle, 0, len(m.workers)) - for _, handle := range m.workers { - workers = append(workers, handle) - } - m.workers = make(map[string]*workerHandle) - m.mu.Unlock() - - for _, handle := range workers { - m.removeWorkerContainer(context.Background(), handle) - } - if m.docker != nil { - _ = m.docker.Close() - } -} - -func (m *workerManager) Stats() map[string]any { - m.mu.Lock() - defer m.mu.Unlock() - return map[string]any{ - "accounts": len(m.state), - "total": len(m.workers), - "active": len(m.workers), - } -} - -func (m *workerManager) SetAccountToken(accountID, accessToken, refreshToken string, expiresAt time.Time) { - m.mu.Lock() - defer m.mu.Unlock() - state := m.ensureStateLocked(accountID) - state.HasToken = true - state.AccessToken = accessToken - state.RefreshToken = refreshToken - if expiresAt.IsZero() { - state.ExpiresAt = nil - } else { - ts := expiresAt.UTC() - state.ExpiresAt = &ts - } -} - -func (m *workerManager) SetAccountModelCredits(accountID string, useAICredits bool, availableCredits, minimumCreditAmountForUsage *int32) { - m.mu.Lock() - defer m.mu.Unlock() - state := m.ensureStateLocked(accountID) - state.HasModelCredits = true - state.UseAICredits = useAICredits - state.AvailableCredits = cloneInt32Ptr(availableCredits) - state.MinimumCreditAmount = cloneInt32Ptr(minimumCreditAmountForUsage) -} - -func (m *workerManager) GetOrCreate(accountID, routingKey string, proxyURL ...string) (*Instance, error) { - rawProxy := "" - if len(proxyURL) > 0 { - rawProxy = proxyURL[0] - } - normalizedProxy, parsedProxy, err := resolveWorkerProxy(rawProxy) - if err != nil { - return nil, err - } - if parsedProxy == nil { - return nil, fmt.Errorf("ls worker requires a socks5/socks5h proxy for account %s", accountID) - } - - replica := replicaSlotIndex(routingKey, parseLSReplicaCount()) - proxyHash := proxyHash(normalizedProxy) - workerKey := buildWorkerKey(accountID, proxyHash) - - m.mu.Lock() - state := cloneWorkerAccountState(m.state[accountID]) - if state == nil || !state.HasToken || strings.TrimSpace(state.AccessToken) == "" { - m.mu.Unlock() - return nil, fmt.Errorf("ls worker missing access token for account %s", accountID) - } - - handle := m.workers[workerKey] - if handle == nil { - if len(m.workers) >= m.cfg.MaxActive { - m.mu.Unlock() - return nil, fmt.Errorf("ls worker limit reached (%d active)", m.cfg.MaxActive) - } - handle, err = m.createWorkerLocked(accountID, normalizedProxy, proxyHash, parsedProxy) - if err != nil { - m.mu.Unlock() - return nil, err - } - m.workers[workerKey] = handle - } - handle.LastUsed = time.Now() - m.mu.Unlock() - - if err := m.waitForWorkerHealthy(handle); err != nil { - return nil, err - } - if err := m.syncWorkerState(handle, state); err != nil { - return nil, err - } - if err := m.waitForWorkerReady(handle, routingKey); err != nil { - return nil, err - } - - inst := &Instance{ - AccountID: accountID, - Replica: replica, - Address: handle.Address, - client: m.http, - healthy: true, - lastUsed: time.Now(), - modelMapReady: 1, - remote: true, - workerToken: handle.AuthToken, - routingKey: routingKey, - } - return inst, nil -} - -func (m *workerManager) cleanupLoop() { - ticker := time.NewTicker(time.Minute) - defer ticker.Stop() - - for { - select { - case <-m.ctx.Done(): - return - case <-ticker.C: - m.collectIdleWorkers() - } - } -} - -func (m *workerManager) collectIdleWorkers() { - now := time.Now() - var expired []*workerHandle - - m.mu.Lock() - for key, handle := range m.workers { - if handle == nil { - delete(m.workers, key) - continue - } - if now.Sub(handle.LastUsed) <= m.cfg.IdleTTL { - continue - } - expired = append(expired, handle) - delete(m.workers, key) - } - m.mu.Unlock() - - for _, handle := range expired { - m.removeWorkerContainer(context.Background(), handle) - } -} - -func (m *workerManager) reconcileManagedContainers(ctx context.Context) error { - args := filters.NewArgs() - args.Add("label", fmt.Sprintf("%s=%s", lsWorkerManagedByLabel, lsWorkerManagedByValue)) - - containers, err := m.docker.ContainerList(ctx, container.ListOptions{ - All: true, - Filters: args, - }) - if err != nil { - return fmt.Errorf("list managed ls workers: %w", err) - } - - for _, summary := range containers { - handle := &workerHandle{ - ContainerID: summary.ID, - Container: strings.TrimPrefix(firstContainerName(summary.Names), "/"), - } - if err := m.removeWorkerContainer(ctx, handle); err != nil { - return err - } - } - return nil -} - -func (m *workerManager) createWorkerLocked(accountID, proxyURL, proxyHash string, parsedProxy *url.URL) (*workerHandle, error) { - containerName := fmt.Sprintf("sub2api-ls-%s-%s", accountID, proxyHash[:8]) - authToken := generateUUID() - - proxyHost := parsedProxy.Hostname() - proxyPort := parsedProxy.Port() - if proxyPort == "" { - proxyPort = "1080" - } - proxyUser := parsedProxy.User.Username() - proxyPass, _ := parsedProxy.User.Password() - - labels := map[string]string{ - lsWorkerManagedByLabel: lsWorkerManagedByValue, - lsWorkerAccountLabel: accountID, - lsWorkerProxyHashLabel: proxyHash, - lsWorkerImageTagLabel: m.cfg.Image, - } - - env := []string{ - "ANTIGRAVITY_APP_ROOT=/app/ls", - fmt.Sprintf("LSWORKER_ACCOUNT_ID=%s", accountID), - fmt.Sprintf("LSWORKER_AUTH_TOKEN=%s", authToken), - fmt.Sprintf("LSWORKER_LISTEN_ADDR=0.0.0.0:%d", lsWorkerControlPort), - fmt.Sprintf("LSWORKER_NETWORK_READY_FILE=%s", "/run/lsworker/network-ready"), - fmt.Sprintf("LSWORKER_PROXY_URL=%s", proxyURL), - fmt.Sprintf("LSWORKER_PROXY_HOST=%s", proxyHost), - fmt.Sprintf("LSWORKER_PROXY_PORT=%s", proxyPort), - fmt.Sprintf("LSWORKER_PROXY_USER=%s", proxyUser), - fmt.Sprintf("LSWORKER_PROXY_PASS=%s", proxyPass), - fmt.Sprintf("LSWORKER_CONTROL_PORT=%d", lsWorkerControlPort), - fmt.Sprintf("ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT=%d", parseLSReplicaCount()), - } - if tz := strings.TrimSpace(os.Getenv("TZ")); tz != "" { - env = append(env, "TZ="+tz) - } - - createResp, err := m.docker.ContainerCreate(m.ctx, &container.Config{ - Image: m.cfg.Image, - Labels: labels, - Env: env, - }, &container.HostConfig{ - CapAdd: []string{"NET_ADMIN"}, - }, &network.NetworkingConfig{ - EndpointsConfig: map[string]*network.EndpointSettings{ - m.cfg.Network: {}, - }, - }, nil, containerName) - if err != nil { - return nil, fmt.Errorf("create ls worker container: %w", err) - } - - if err := m.docker.ContainerStart(m.ctx, createResp.ID, container.StartOptions{}); err != nil { - _ = m.docker.ContainerRemove(m.ctx, createResp.ID, container.RemoveOptions{Force: true}) - return nil, fmt.Errorf("start ls worker container: %w", err) - } - - inspect, err := m.docker.ContainerInspect(m.ctx, createResp.ID) - if err != nil { - _ = m.docker.ContainerRemove(m.ctx, createResp.ID, container.RemoveOptions{Force: true}) - return nil, fmt.Errorf("inspect ls worker container: %w", err) - } - - address, err := workerAddressFromInspect(inspect, m.cfg.Network) - if err != nil { - _ = m.docker.ContainerRemove(m.ctx, createResp.ID, container.RemoveOptions{Force: true}) - return nil, err - } - - m.logger.Info("created ls worker", - "account", shortAccountID(accountID), - "container", containerName, - "address", address, - "proxy_hash", proxyHash[:8]) - - return &workerHandle{ - Key: buildWorkerKey(accountID, proxyHash), - AccountID: accountID, - ProxyURL: proxyURL, - ProxyHash: proxyHash, - ContainerID: createResp.ID, - Container: containerName, - Address: address, - AuthToken: authToken, - LastUsed: time.Now(), - }, nil -} - -func workerAddressFromInspect(inspect container.InspectResponse, networkName string) (string, error) { - if inspect.NetworkSettings == nil { - return "", fmt.Errorf("ls worker inspect missing network settings") - } - if endpoint, ok := inspect.NetworkSettings.Networks[networkName]; ok && endpoint != nil && strings.TrimSpace(endpoint.IPAddress) != "" { - return net.JoinHostPort(strings.TrimSpace(endpoint.IPAddress), strconv.Itoa(lsWorkerControlPort)), nil - } - for _, endpoint := range inspect.NetworkSettings.Networks { - if endpoint != nil && strings.TrimSpace(endpoint.IPAddress) != "" { - return net.JoinHostPort(strings.TrimSpace(endpoint.IPAddress), strconv.Itoa(lsWorkerControlPort)), nil - } - } - return "", fmt.Errorf("ls worker missing IP address on network %s", networkName) -} - -func firstContainerName(names []string) string { - if len(names) == 0 { - return "" - } - return names[0] -} - -func (m *workerManager) waitForWorkerHealthy(handle *workerHandle) error { - ctx, cancel := context.WithTimeout(context.Background(), m.cfg.StartupTimeout) - defer cancel() - - for { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, workerURL(handle, "/healthz", nil), nil) - if err != nil { - return err - } - req.Header.Set("X-Worker-Token", handle.AuthToken) - resp, err := m.http.Do(req) - if err == nil { - _ = resp.Body.Close() - if resp.StatusCode == http.StatusOK { - return nil - } - } - - select { - case <-ctx.Done(): - return fmt.Errorf("worker %s failed health check: %w", handle.Container, ctx.Err()) - case <-time.After(500 * time.Millisecond): - } - } -} - -func (m *workerManager) waitForWorkerReady(handle *workerHandle, routingKey string) error { - ctx, cancel := context.WithTimeout(context.Background(), m.cfg.StartupTimeout) - defer cancel() - - values := url.Values{} - if strings.TrimSpace(routingKey) != "" { - values.Set("routing_key", routingKey) - } - - var ( - lastStatus int - lastBody string - ) - - for { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, workerURL(handle, "/readyz", values), nil) - if err != nil { - return err - } - req.Header.Set("X-Worker-Token", handle.AuthToken) - resp, err := m.http.Do(req) - if err == nil { - body, _ := io.ReadAll(resp.Body) - _ = resp.Body.Close() - lastStatus = resp.StatusCode - lastBody = truncate(string(body), 200) - if resp.StatusCode == http.StatusOK { - return nil - } - if isWorkerModelMappingUnavailable(resp.StatusCode, lastBody) { - return fmt.Errorf("%w: worker %s %s", errLSModelMapDenied, handle.Container, strings.TrimSpace(lastBody)) - } - if len(body) > 0 && shouldWarnWorkerNotReady(resp.StatusCode, lastBody) { - m.logger.Warn("ls worker not ready yet", "container", handle.Container, "status", resp.StatusCode, "body", truncate(string(body), 200)) - } - } - - select { - case <-ctx.Done(): - if lastStatus > 0 || lastBody != "" { - return fmt.Errorf("worker %s not ready for routing key %q (last_status=%d last_body=%q): %w", handle.Container, routingKey, lastStatus, lastBody, ctx.Err()) - } - return fmt.Errorf("worker %s not ready for routing key %q: %w", handle.Container, routingKey, ctx.Err()) - case <-time.After(500 * time.Millisecond): - } - } -} - -func shouldWarnWorkerNotReady(status int, body string) bool { - if status == http.StatusServiceUnavailable { - normalized := strings.ToLower(strings.TrimSpace(body)) - if strings.Contains(normalized, "model mapping not ready") { - return false - } - } - return true -} - -func isWorkerModelMappingUnavailable(status int, body string) bool { - if status != http.StatusServiceUnavailable { - return false - } - normalized := strings.ToLower(strings.TrimSpace(body)) - return strings.Contains(normalized, errLSModelMapDenied.Error()) -} - -func (m *workerManager) syncWorkerState(handle *workerHandle, state *workerAccountState) error { - if state == nil { - return fmt.Errorf("ls worker state is nil") - } - body, err := json.Marshal(state) - if err != nil { - return fmt.Errorf("marshal worker state: %w", err) - } - - sum := fmt.Sprintf("%x", sha256.Sum256(body)) - if handle.LastStateSHA == sum { - return nil - } - - ctx, cancel := context.WithTimeout(context.Background(), m.cfg.RequestTimeout) - defer cancel() - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, workerURL(handle, "/account/state", nil), bytes.NewReader(body)) - if err != nil { - return err - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("X-Worker-Token", handle.AuthToken) - - resp, err := m.http.Do(req) - if err != nil { - return fmt.Errorf("sync worker state: %w", err) - } - defer resp.Body.Close() - respBody, _ := io.ReadAll(resp.Body) - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("sync worker state HTTP %d: %s", resp.StatusCode, truncate(string(respBody), 200)) - } - handle.LastStateSHA = sum - return nil -} - -func workerURL(handle *workerHandle, path string, values url.Values) string { - base := url.URL{ - Scheme: "http", - Host: handle.Address, - Path: path, - } - if values != nil { - base.RawQuery = values.Encode() - } - return base.String() -} - -func (m *workerManager) removeWorkerContainer(ctx context.Context, handle *workerHandle) error { - if handle == nil || strings.TrimSpace(handle.ContainerID) == "" { - return nil - } - timeout := 5 - _ = m.docker.ContainerStop(ctx, handle.ContainerID, container.StopOptions{Timeout: &timeout}) - if err := m.docker.ContainerRemove(ctx, handle.ContainerID, container.RemoveOptions{Force: true}); err != nil { - return fmt.Errorf("remove ls worker container %s: %w", handle.ContainerID, err) - } - return nil -} - -func (m *workerManager) ensureStateLocked(accountID string) *workerAccountState { - state := m.state[accountID] - if state == nil { - state = &workerAccountState{} - m.state[accountID] = state - } - return state -} - -func resolveWorkerProxy(proxyURL string) (string, *url.URL, error) { - resolved := resolveLSProxy(proxyURL) - normalized, parsed, err := proxyurl.Parse(resolved) - if err != nil { - return "", nil, err - } - if parsed == nil { - return "", nil, nil - } - switch strings.ToLower(parsed.Scheme) { - case "socks5", "socks5h": - return normalized, parsed, nil - default: - return "", nil, fmt.Errorf("ls worker only supports socks5/socks5h proxies, got %s", parsed.Scheme) - } -} - -func proxyHash(proxyURL string) string { - if strings.TrimSpace(proxyURL) == "" { - return "direct" - } - sum := sha256.Sum256([]byte(strings.TrimSpace(proxyURL))) - return fmt.Sprintf("%x", sum[:]) -} - -func buildWorkerKey(accountID, proxyHash string) string { - return accountID + ":" + proxyHash -} - -func cloneInt32Ptr(v *int32) *int32 { - if v == nil { - return nil - } - cp := *v - return &cp -} - -func cloneWorkerAccountState(state *workerAccountState) *workerAccountState { - if state == nil { - return nil - } - cp := *state - cp.AvailableCredits = cloneInt32Ptr(state.AvailableCredits) - cp.MinimumCreditAmount = cloneInt32Ptr(state.MinimumCreditAmount) - if state.ExpiresAt != nil { - ts := *state.ExpiresAt - cp.ExpiresAt = &ts - } - return &cp -} diff --git a/backend/internal/pkg/lspool/worker_manager_test.go b/backend/internal/pkg/lspool/worker_manager_test.go deleted file mode 100644 index 61fbe18d..00000000 --- a/backend/internal/pkg/lspool/worker_manager_test.go +++ /dev/null @@ -1,335 +0,0 @@ -package lspool - -import ( - "context" - "encoding/json" - "io" - "net/http" - "net/http/httptest" - "strings" - "sync" - "testing" - "time" - - "github.com/docker/docker/api/types/container" - "github.com/docker/docker/api/types/filters" - "github.com/docker/docker/api/types/network" - ocispec "github.com/opencontainers/image-spec/specs-go/v1" - "github.com/stretchr/testify/require" -) - -type fakeDockerClient struct { - mu sync.Mutex - - listResp []container.Summary - listCalls int - createCalls int - startCalls int - stopCalls int - removeCalls int - inspectCalls int - removedIDs []string - createdConfigs []*container.Config - inspectResp container.InspectResponse -} - -func (f *fakeDockerClient) ContainerList(ctx context.Context, options container.ListOptions) ([]container.Summary, error) { - f.mu.Lock() - defer f.mu.Unlock() - f.listCalls++ - return append([]container.Summary(nil), f.listResp...), nil -} - -func (f *fakeDockerClient) ContainerCreate(ctx context.Context, cfg *container.Config, hostConfig *container.HostConfig, networkingConfig *network.NetworkingConfig, platform *ocispec.Platform, containerName string) (container.CreateResponse, error) { - f.mu.Lock() - defer f.mu.Unlock() - f.createCalls++ - f.createdConfigs = append(f.createdConfigs, cfg) - return container.CreateResponse{ID: "worker-created"}, nil -} - -func (f *fakeDockerClient) ContainerStart(ctx context.Context, containerID string, options container.StartOptions) error { - f.mu.Lock() - defer f.mu.Unlock() - f.startCalls++ - return nil -} - -func (f *fakeDockerClient) ContainerInspect(ctx context.Context, containerID string) (container.InspectResponse, error) { - f.mu.Lock() - defer f.mu.Unlock() - f.inspectCalls++ - return f.inspectResp, nil -} - -func (f *fakeDockerClient) ContainerStop(ctx context.Context, containerID string, options container.StopOptions) error { - f.mu.Lock() - defer f.mu.Unlock() - f.stopCalls++ - return nil -} - -func (f *fakeDockerClient) ContainerRemove(ctx context.Context, containerID string, options container.RemoveOptions) error { - f.mu.Lock() - defer f.mu.Unlock() - f.removeCalls++ - f.removedIDs = append(f.removedIDs, containerID) - return nil -} - -func (f *fakeDockerClient) Close() error { return nil } - -func TestResolveWorkerProxyRejectsHTTP(t *testing.T) { - _, _, err := resolveWorkerProxy("http://127.0.0.1:7890") - require.Error(t, err) - require.Contains(t, err.Error(), "only supports socks5/socks5h") -} - -func TestProxyHashUsesNormalizedProxy(t *testing.T) { - normalized, _, err := resolveWorkerProxy("socks5://user:pass@127.0.0.1:1080") - require.NoError(t, err) - require.Equal(t, "socks5h://user:pass@127.0.0.1:1080", normalized) - - hash1 := proxyHash(normalized) - hash2 := proxyHash("socks5h://user:pass@127.0.0.1:1080") - require.Equal(t, hash1, hash2) -} - -func TestWorkerManagerRequiresToken(t *testing.T) { - fakeDocker := &fakeDockerClient{} - manager, err := newWorkerManager(workerManagerConfig{ - Image: "worker:latest", - Network: "sub2api-network", - DockerSocket: "unix:///var/run/docker.sock", - IdleTTL: time.Minute, - MaxActive: 2, - StartupTimeout: time.Second, - RequestTimeout: time.Second, - }, fakeDocker) - require.NoError(t, err) - defer manager.Close() - - _, err = manager.GetOrCreate("9", "rk-1", "socks5h://user:pass@127.0.0.1:1080") - require.Error(t, err) - require.Contains(t, err.Error(), "missing access token") -} - -func TestWorkerManagerReusesExistingHandleAndDedupesStateSync(t *testing.T) { - var mu sync.Mutex - var healthCalls int - var readyCalls int - var stateCalls int - var stateBodies [][]byte - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - switch r.URL.Path { - case "/healthz": - mu.Lock() - healthCalls++ - mu.Unlock() - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte("ok")) - case "/readyz": - mu.Lock() - readyCalls++ - mu.Unlock() - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte("ready")) - case "/account/state": - body, _ := io.ReadAll(r.Body) - mu.Lock() - stateCalls++ - stateBodies = append(stateBodies, body) - mu.Unlock() - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte("ok")) - default: - http.NotFound(w, r) - } - })) - defer server.Close() - - fakeDocker := &fakeDockerClient{} - manager, err := newWorkerManager(workerManagerConfig{ - Image: "worker:latest", - Network: "sub2api-network", - DockerSocket: "unix:///var/run/docker.sock", - IdleTTL: time.Minute, - MaxActive: 4, - StartupTimeout: time.Second, - RequestTimeout: time.Second, - }, fakeDocker) - require.NoError(t, err) - defer manager.Close() - - accountID := "9" - proxyURL := "socks5h://user:pass@127.0.0.1:1080" - hash := proxyHash(proxyURL) - key := buildWorkerKey(accountID, hash) - - manager.SetAccountToken(accountID, "ya29.test", "refresh", time.Now().Add(time.Hour)) - manager.mu.Lock() - manager.workers[key] = &workerHandle{ - Key: key, - AccountID: accountID, - ProxyURL: proxyURL, - ProxyHash: hash, - ContainerID: "existing-worker", - Container: "sub2api-ls-9-test", - Address: strings.TrimPrefix(server.URL, "http://"), - AuthToken: "worker-token", - LastUsed: time.Now(), - } - manager.mu.Unlock() - - inst1, err := manager.GetOrCreate(accountID, "rk-1", proxyURL) - require.NoError(t, err) - require.True(t, inst1.remote) - require.Equal(t, replicaSlotIndex("rk-1", parseLSReplicaCount()), inst1.Replica) - - inst2, err := manager.GetOrCreate(accountID, "rk-1", proxyURL) - require.NoError(t, err) - require.True(t, inst2.remote) - - mu.Lock() - defer mu.Unlock() - require.GreaterOrEqual(t, healthCalls, 2) - require.GreaterOrEqual(t, readyCalls, 2) - require.Equal(t, 1, stateCalls, "state sync should be skipped when the payload hash is unchanged") - require.Len(t, stateBodies, 1) - - var synced workerAccountState - require.NoError(t, json.Unmarshal(stateBodies[0], &synced)) - require.True(t, synced.HasToken) - require.Equal(t, "ya29.test", synced.AccessToken) -} - -func TestWorkerManagerMaxActiveStopsNewWorkerCreation(t *testing.T) { - fakeDocker := &fakeDockerClient{} - manager, err := newWorkerManager(workerManagerConfig{ - Image: "worker:latest", - Network: "sub2api-network", - DockerSocket: "unix:///var/run/docker.sock", - IdleTTL: time.Minute, - MaxActive: 1, - StartupTimeout: time.Second, - RequestTimeout: time.Second, - }, fakeDocker) - require.NoError(t, err) - defer manager.Close() - - manager.SetAccountToken("9", "ya29.test", "refresh", time.Now().Add(time.Hour)) - manager.mu.Lock() - manager.workers["existing"] = &workerHandle{ContainerID: "existing", Container: "existing", LastUsed: time.Now()} - manager.mu.Unlock() - - _, err = manager.GetOrCreate("9", "rk-new", "socks5h://user:pass@127.0.0.1:1080") - require.Error(t, err) - require.Contains(t, err.Error(), "limit reached") - require.Equal(t, 0, fakeDocker.createCalls) -} - -func TestWorkerManagerReconcileRemovesManagedContainers(t *testing.T) { - fakeDocker := &fakeDockerClient{ - listResp: []container.Summary{ - { - ID: "old-worker-1", - Names: []string{"/sub2api-ls-9-deadbeef"}, - }, - { - ID: "old-worker-2", - Names: []string{"/sub2api-ls-10-beadfeed"}, - }, - }, - } - - manager, err := newWorkerManager(workerManagerConfig{ - Image: "worker:latest", - Network: "sub2api-network", - DockerSocket: "unix:///var/run/docker.sock", - IdleTTL: time.Minute, - MaxActive: 4, - StartupTimeout: time.Second, - RequestTimeout: time.Second, - }, fakeDocker) - require.NoError(t, err) - defer manager.Close() - - require.Equal(t, 1, fakeDocker.listCalls) - require.ElementsMatch(t, []string{"old-worker-1", "old-worker-2"}, fakeDocker.removedIDs) -} - -func TestFakeDockerClientImplementsFilterAwareList(t *testing.T) { - fakeDocker := &fakeDockerClient{} - _, err := fakeDocker.ContainerList(context.Background(), container.ListOptions{Filters: filters.NewArgs()}) - require.NoError(t, err) -} - -func TestShouldWarnWorkerNotReadySuppressesModelMappingPending(t *testing.T) { - require.False(t, shouldWarnWorkerNotReady(http.StatusServiceUnavailable, "worker model mapping not ready for replica 0")) - require.True(t, shouldWarnWorkerNotReady(http.StatusServiceUnavailable, "worker access token not configured")) - require.True(t, shouldWarnWorkerNotReady(http.StatusBadGateway, "upstream failed")) -} - -func TestWorkerManagerWaitForWorkerReadyStopsOnModelMappingUnavailable(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - require.Equal(t, "/readyz", r.URL.Path) - w.WriteHeader(http.StatusServiceUnavailable) - _, _ = w.Write([]byte(`model mapping unavailable for replica 0: oauth2: "unauthorized_client" "Unauthorized"`)) - })) - defer server.Close() - - manager, err := newWorkerManager(workerManagerConfig{ - Image: "worker:latest", - Network: "sub2api-network", - DockerSocket: "unix:///var/run/docker.sock", - IdleTTL: time.Minute, - MaxActive: 1, - StartupTimeout: time.Second, - RequestTimeout: time.Second, - }, &fakeDockerClient{}) - require.NoError(t, err) - defer manager.Close() - - handle := &workerHandle{ - Container: "sub2api-ls-test", - Address: strings.TrimPrefix(server.URL, "http://"), - AuthToken: "worker-token", - } - - err = manager.waitForWorkerReady(handle, "") - require.Error(t, err) - require.ErrorIs(t, err, errLSModelMapDenied) -} - -func TestWorkerManagerWaitForWorkerReadyIncludesLastBodyOnTimeout(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - require.Equal(t, "/readyz", r.URL.Path) - w.WriteHeader(http.StatusServiceUnavailable) - _, _ = w.Write([]byte("worker model mapping not ready for replica 0\n")) - })) - defer server.Close() - - manager, err := newWorkerManager(workerManagerConfig{ - Image: "worker:latest", - Network: "sub2api-network", - DockerSocket: "unix:///var/run/docker.sock", - IdleTTL: time.Minute, - MaxActive: 1, - StartupTimeout: 100 * time.Millisecond, - RequestTimeout: time.Second, - }, &fakeDockerClient{}) - require.NoError(t, err) - defer manager.Close() - - handle := &workerHandle{ - Container: "sub2api-ls-test", - Address: strings.TrimPrefix(server.URL, "http://"), - AuthToken: "worker-token", - } - - err = manager.waitForWorkerReady(handle, "") - require.Error(t, err) - require.Contains(t, err.Error(), `last_status=503`) - require.Contains(t, err.Error(), `last_body="worker model mapping not ready for replica 0`) -} diff --git a/backend/internal/pkg/lspool/worker_server.go b/backend/internal/pkg/lspool/worker_server.go deleted file mode 100644 index f41b05de..00000000 --- a/backend/internal/pkg/lspool/worker_server.go +++ /dev/null @@ -1,374 +0,0 @@ -package lspool - -import ( - "context" - "encoding/json" - "fmt" - "io" - "log/slog" - "net/http" - "os" - "strconv" - "strings" - "sync" - "time" -) - -type WorkerServerConfig struct { - AccountID string - AuthToken string - ListenAddr string - AppRoot string - NetworkReadyFile string - MaxIdleTime time.Duration - HealthInterval time.Duration -} - -type WorkerServer struct { - cfg WorkerServerConfig - pool *Pool - logger *slog.Logger - - mu sync.RWMutex - state workerAccountState -} - -func NewWorkerServer(cfg WorkerServerConfig) (*WorkerServer, error) { - if strings.TrimSpace(cfg.AccountID) == "" { - return nil, fmt.Errorf("worker account id is required") - } - if strings.TrimSpace(cfg.AuthToken) == "" { - return nil, fmt.Errorf("worker auth token is required") - } - if strings.TrimSpace(cfg.ListenAddr) == "" { - cfg.ListenAddr = fmt.Sprintf("0.0.0.0:%d", lsWorkerControlPort) - } - if strings.TrimSpace(cfg.AppRoot) == "" { - cfg.AppRoot = "/app/ls" - } - if cfg.MaxIdleTime <= 0 { - cfg.MaxIdleTime = 15 * time.Minute - } - if cfg.HealthInterval <= 0 { - cfg.HealthInterval = 30 * time.Second - } - - poolCfg := DefaultConfig() - poolCfg.AppRoot = cfg.AppRoot - poolCfg.MaxIdleTime = cfg.MaxIdleTime - poolCfg.HealthCheckInterval = cfg.HealthInterval - - return &WorkerServer{ - cfg: cfg, - pool: NewPool(poolCfg), - logger: slog.Default().With("component", "lsworker"), - }, nil -} - -func NewWorkerServerFromEnv() (*WorkerServer, error) { - maxIdleTime := 15 * time.Minute - if raw := strings.TrimSpace(os.Getenv("LSWORKER_POOL_MAX_IDLE_TIME")); raw != "" { - if parsed, err := time.ParseDuration(raw); err == nil { - maxIdleTime = parsed - } - } - healthInterval := 30 * time.Second - if raw := strings.TrimSpace(os.Getenv("LSWORKER_POOL_HEALTH_INTERVAL")); raw != "" { - if parsed, err := time.ParseDuration(raw); err == nil { - healthInterval = parsed - } - } - - return NewWorkerServer(WorkerServerConfig{ - AccountID: strings.TrimSpace(os.Getenv("LSWORKER_ACCOUNT_ID")), - AuthToken: strings.TrimSpace(os.Getenv("LSWORKER_AUTH_TOKEN")), - ListenAddr: strings.TrimSpace(os.Getenv("LSWORKER_LISTEN_ADDR")), - AppRoot: strings.TrimSpace(os.Getenv("ANTIGRAVITY_APP_ROOT")), - NetworkReadyFile: strings.TrimSpace(os.Getenv("LSWORKER_NETWORK_READY_FILE")), - MaxIdleTime: maxIdleTime, - HealthInterval: healthInterval, - }) -} - -func (s *WorkerServer) Close() { - if s.pool != nil { - s.pool.Close() - } -} - -func (s *WorkerServer) Handler() http.Handler { - mux := http.NewServeMux() - mux.HandleFunc("/healthz", s.handleHealthz) - mux.HandleFunc("/readyz", s.handleReadyz) - mux.HandleFunc("/account/state", s.handleAccountState) - mux.HandleFunc("/rpc/unary", s.handleRPCUnary) - mux.HandleFunc("/rpc/stream", s.handleRPCStream) - return mux -} - -func (s *WorkerServer) handleHealthz(w http.ResponseWriter, r *http.Request) { - if !s.authorize(w, r) { - return - } - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte("ok")) -} - -func (s *WorkerServer) handleReadyz(w http.ResponseWriter, r *http.Request) { - if !s.authorize(w, r) { - return - } - routingKey := strings.TrimSpace(r.URL.Query().Get("routing_key")) - inst, err := s.ensureReady(r.Context(), routingKey) - if err != nil { - http.Error(w, err.Error(), http.StatusServiceUnavailable) - return - } - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte(fmt.Sprintf("ready replica=%d", inst.Replica))) -} - -func (s *WorkerServer) handleAccountState(w http.ResponseWriter, r *http.Request) { - if !s.authorize(w, r) { - return - } - if r.Method != http.MethodPost { - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) - return - } - defer r.Body.Close() - var payload workerAccountState - if err := json.NewDecoder(r.Body).Decode(&payload); err != nil { - http.Error(w, "invalid account state payload", http.StatusBadRequest) - return - } - - s.mu.Lock() - s.state = *cloneWorkerAccountState(&payload) - s.mu.Unlock() - s.applyState() - - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte("ok")) -} - -func (s *WorkerServer) handleRPCUnary(w http.ResponseWriter, r *http.Request) { - if !s.authorize(w, r) { - return - } - service, method, mode, routingKey, ok := parseRPCRequest(w, r) - if !ok { - return - } - - inst, err := s.ensureReady(r.Context(), routingKey) - if err != nil { - http.Error(w, err.Error(), http.StatusServiceUnavailable) - return - } - - body, err := io.ReadAll(r.Body) - if err != nil { - http.Error(w, "read request body failed", http.StatusBadRequest) - return - } - if len(body) == 0 { - body = []byte("{}") - } - - var respBody []byte - switch mode { - case "json": - var input any - if err := json.Unmarshal(body, &input); err != nil { - http.Error(w, "invalid json rpc body", http.StatusBadRequest) - return - } - respBody, err = inst.CallUnaryJSON(r.Context(), service, method, input) - case "proto": - respBody, err = inst.CallRPC(r.Context(), service, method, body) - default: - http.Error(w, "unsupported rpc mode", http.StatusBadRequest) - return - } - if err != nil { - http.Error(w, err.Error(), http.StatusBadGateway) - return - } - - w.WriteHeader(http.StatusOK) - _, _ = w.Write(respBody) -} - -func (s *WorkerServer) handleRPCStream(w http.ResponseWriter, r *http.Request) { - if !s.authorize(w, r) { - return - } - service, method, mode, routingKey, ok := parseRPCRequest(w, r) - if !ok { - return - } - - inst, err := s.ensureReady(r.Context(), routingKey) - if err != nil { - http.Error(w, err.Error(), http.StatusServiceUnavailable) - return - } - - body, err := io.ReadAll(r.Body) - if err != nil { - http.Error(w, "read request body failed", http.StatusBadRequest) - return - } - - var resp *http.Response - switch mode { - case "json": - var input any - if len(body) == 0 { - body = []byte("{}") - } - if err := json.Unmarshal(body, &input); err != nil { - http.Error(w, "invalid json rpc body", http.StatusBadRequest) - return - } - resp, err = inst.StreamRPCJSON(r.Context(), service, method, input) - case "proto": - resp, err = inst.StreamRPC(r.Context(), service, method, body) - default: - http.Error(w, "unsupported rpc mode", http.StatusBadRequest) - return - } - if err != nil { - http.Error(w, err.Error(), http.StatusBadGateway) - return - } - defer resp.Body.Close() - - for key, values := range resp.Header { - for _, value := range values { - w.Header().Add(key, value) - } - } - w.WriteHeader(resp.StatusCode) - _, _ = io.Copy(w, resp.Body) -} - -func (s *WorkerServer) authorize(w http.ResponseWriter, r *http.Request) bool { - if subtleHeaderEqual(r.Header.Get("X-Worker-Token"), s.cfg.AuthToken) { - return true - } - http.Error(w, "unauthorized", http.StatusUnauthorized) - return false -} - -func subtleHeaderEqual(left, right string) bool { - if left == "" || right == "" { - return false - } - return left == right -} - -func parseRPCRequest(w http.ResponseWriter, r *http.Request) (service, method, mode, routingKey string, ok bool) { - if r.Method != http.MethodPost { - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) - return "", "", "", "", false - } - query := r.URL.Query() - service = strings.TrimSpace(query.Get("service")) - method = strings.TrimSpace(query.Get("method")) - mode = strings.ToLower(strings.TrimSpace(query.Get("mode"))) - routingKey = strings.TrimSpace(query.Get("routing_key")) - if service == "" || method == "" { - http.Error(w, "missing rpc target", http.StatusBadRequest) - return "", "", "", "", false - } - if mode == "" { - mode = "proto" - } - return service, method, mode, routingKey, true -} - -func (s *WorkerServer) ensureReady(ctx context.Context, routingKey string) (*Instance, error) { - if path := strings.TrimSpace(s.cfg.NetworkReadyFile); path != "" { - if _, err := os.Stat(path); err != nil { - return nil, fmt.Errorf("worker network not ready: %w", err) - } - } - - s.applyState() - s.mu.RLock() - state := cloneWorkerAccountState(&s.state) - s.mu.RUnlock() - if state == nil || !state.HasToken || strings.TrimSpace(state.AccessToken) == "" { - return nil, fmt.Errorf("worker access token not configured") - } - - inst, err := s.pool.GetOrCreate(s.cfg.AccountID, routingKey, "") - if err != nil { - return nil, err - } - if inst.HasModelMappingUnavailable() { - return nil, fmt.Errorf("%w for replica %d: %s", errLSModelMapDenied, inst.Replica, inst.ModelMappingUnavailableReason()) - } - if inst.HasModelMappingReady() { - return inst, nil - } - - modelCtx, cancel := context.WithTimeout(ctx, lsModelConfigTimeout) - defer cancel() - _ = modelCtx - if !RefreshModelMapping(inst) { - if inst.HasModelMappingUnavailable() { - return nil, fmt.Errorf("%w for replica %d: %s", errLSModelMapDenied, inst.Replica, inst.ModelMappingUnavailableReason()) - } - return nil, fmt.Errorf("worker model mapping not ready for replica %d", inst.Replica) - } - return inst, nil -} - -func (s *WorkerServer) applyState() { - s.mu.RLock() - state := cloneWorkerAccountState(&s.state) - s.mu.RUnlock() - if state == nil { - return - } - if state.HasToken { - expiresAt := time.Time{} - if state.ExpiresAt != nil { - expiresAt = state.ExpiresAt.UTC() - } - s.pool.SetAccountToken(s.cfg.AccountID, state.AccessToken, state.RefreshToken, expiresAt) - } - if state.HasModelCredits { - s.pool.SetAccountModelCredits(s.cfg.AccountID, state.UseAICredits, state.AvailableCredits, state.MinimumCreditAmount) - } -} - -func workerHTTPServer(listenAddr string, handler http.Handler) *http.Server { - return &http.Server{ - Addr: listenAddr, - Handler: handler, - ReadHeaderTimeout: 10 * time.Second, - } -} - -func workerExitCode(err error) int { - if err == nil { - return 0 - } - return 1 -} - -func parseWorkerControlPort() int { - raw := strings.TrimSpace(os.Getenv("LSWORKER_CONTROL_PORT")) - if raw == "" { - return lsWorkerControlPort - } - port, err := strconv.Atoi(raw) - if err != nil || port < 1 { - return lsWorkerControlPort - } - return port -} diff --git a/backend/internal/server/routes/common.go b/backend/internal/server/routes/common.go index 4989358d..b2861ddb 100644 --- a/backend/internal/server/routes/common.go +++ b/backend/internal/server/routes/common.go @@ -1,11 +1,20 @@ package routes import ( + "bytes" + "context" + "io" "net/http" + "time" "github.com/gin-gonic/gin" ) +const ( + anthropicEventLoggingURL = "https://api.anthropic.com/api/event_logging/batch" + eventLoggingForwardTimeout = 8 * time.Second +) + // RegisterCommonRoutes 注册通用路由(健康检查、状态等) func RegisterCommonRoutes(r *gin.Engine) { // 健康检查 @@ -13,8 +22,36 @@ func RegisterCommonRoutes(r *gin.Engine) { c.JSON(http.StatusOK, gin.H{"status": "ok"}) }) - // Claude Code 遥测日志(忽略,直接返回200) + // Claude Code 遥测日志:清理敏感字段后转发给 Anthropic。 + // 删除 baseUrl/gateway 字段防止网关地址暴露(见 FINGERPRINT_SECURITY_REPORT.md §GAP-1/2)。 + // 转发而非丢弃,避免"高流量零遥测"异常被检测。 r.POST("/api/event_logging/batch", func(c *gin.Context) { + body, err := io.ReadAll(c.Request.Body) + if err != nil || len(body) == 0 { + c.Status(http.StatusOK) + return + } + + sanitized := sanitizeEventBatch(body) + + ctx, cancel := context.WithTimeout(c.Request.Context(), eventLoggingForwardTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, anthropicEventLoggingURL, bytes.NewReader(sanitized)) + if err != nil { + c.Status(http.StatusOK) + return + } + req.Header.Set("Content-Type", "application/json") + // 透传客户端的 Authorization header(OAuth Bearer token) + if auth := c.GetHeader("Authorization"); auth != "" { + req.Header.Set("Authorization", auth) + } + + resp, err := http.DefaultClient.Do(req) + if err == nil { + resp.Body.Close() + } c.Status(http.StatusOK) }) diff --git a/backend/internal/service/gateway_attribution.go b/backend/internal/service/gateway_attribution.go index 26a47f90..e8beb871 100644 --- a/backend/internal/service/gateway_attribution.go +++ b/backend/internal/service/gateway_attribution.go @@ -4,6 +4,9 @@ import ( "crypto/sha256" "encoding/hex" "fmt" + "regexp" + "sync" + "time" "github.com/google/uuid" "github.com/tidwall/gjson" @@ -12,7 +15,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/logger" ) -// Attribution block constants matching real Claude Code 2.1.88. +// Attribution block constants matching real Claude Code 2.1.89. // Source: src/constants/system.ts + src/utils/fingerprint.ts const ( // fingerprintSalt must match the hardcoded salt in the real CLI. @@ -81,11 +84,10 @@ func extractFirstUserMessageText(body []byte) string { // Source: extracted/src/constants/system.ts:73-95 func buildAttributionBlock(cliVersion, fingerprint string) string { version := cliVersion + "." + fingerprint - // 注意:cch 字段由 Bun 的 NATIVE_CLIENT_ATTESTATION 编译时 feature 控制。 - // npm 安装版本(非原生二进制)此 feature 为 false,所以不包含 cch 字段。 - // 只有原生二进制安装(Bun 打包)才会有 cch,且其值会被 Bun 的 Zig 层替换为真实 hash。 - // 我们模拟 npm 安装版本的行为:不包含 cch。 - return fmt.Sprintf("x-anthropic-billing-header: cc_version=%s; cc_entrypoint=cli;", version) + // 2.1.89 起 cch=00000 出现在所有安装模式(含 npm 版),不再只限于原生二进制。 + // 原生二进制由 Bun 的 Zig 层在运行时将 00000 替换为真实 attestation hash; + // 普通安装版保持 00000 占位符不变。 + return fmt.Sprintf("x-anthropic-billing-header: cc_version=%s; cc_entrypoint=cli; cch=00000;", version) } // injectAttributionBlock prepends the x-anthropic-billing-header attribution block @@ -163,20 +165,89 @@ func injectAttributionBlock(body []byte, cliVersion string) []byte { } } -// generateSessionIDForAccount generates a deterministic per-account session UUID -// that remains stable within a process-like timeframe. -// Uses instanceSalt + accountID to ensure uniqueness across sub2api instances. -func generateSessionIDForAccount(instanceSalt string, accountID int64) string { - // Use a per-account stable UUID (like real CLI's per-process UUID). - // We use accountID as the base — each account gets a different "session". - seed := fmt.Sprintf("session:%s:%d", instanceSalt, accountID) - hash := sha256.Sum256([]byte(seed)) - sessionUUID, err := uuid.FromBytes(hash[:16]) - if err != nil { - return uuid.New().String() - } - // Set UUID v4 variant - sessionUUID[6] = (sessionUUID[6] & 0x0f) | 0x40 - sessionUUID[8] = (sessionUUID[8] & 0x3f) | 0x80 - return sessionUUID.String() +// cliSessionEntry holds a cached session UUID with an expiration time. +type cliSessionEntry struct { + id string + expiresAt time.Time +} + +// cliSessionCache stores per-account session UUIDs that rotate on a TTL. +// Real CLI creates a new random UUID per process invocation; we approximate +// this by rotating every 30-60 minutes (jittered per account). +var ( + cliSessionCache = make(map[int64]cliSessionEntry) + cliSessionCacheMu sync.Mutex +) + +// sessionTTLBase is the base TTL for session ID rotation. +const sessionTTLBase = 30 * time.Minute + +// generateSessionIDForAccount returns a per-account session UUID that rotates +// periodically. Each account gets a random TTL jitter (0-30 min on top of +// the 30 min base) so accounts don't all rotate simultaneously. +func generateSessionIDForAccount(instanceSalt string, accountID int64) string { + cliSessionCacheMu.Lock() + defer cliSessionCacheMu.Unlock() + + now := time.Now() + if entry, ok := cliSessionCache[accountID]; ok && now.Before(entry.expiresAt) { + return entry.id + } + + // Compute per-account jitter from a hash so the same account always gets + // the same jitter within a process (avoids re-rolling on every rotation). + jitterSeed := fmt.Sprintf("jitter:%s:%d", instanceSalt, accountID) + h := sha256.Sum256([]byte(jitterSeed)) + jitterMinutes := int(h[0]) % 31 // 0-30 minutes + ttl := sessionTTLBase + time.Duration(jitterMinutes)*time.Minute + + newID := uuid.New().String() + cliSessionCache[accountID] = cliSessionEntry{ + id: newID, + expiresAt: now.Add(ttl), + } + return newID +} + +// reUserHome matches /Users// or /home// path segments. +// Captures the prefix (/Users/ or /home/) so we can preserve it while replacing the username. +var reUserHome = regexp.MustCompile(`(/(Users|home)/)[^/\s"']+/`) + +// reEnvLine matches lines of the form "Key: value" for the environment block +// fields injected by Claude Code's CLAUDE.md / sysprompt machinery. +var reEnvLine = regexp.MustCompile(`(?m)^(Platform|Shell|OS Version|Working directory):.*$`) + +// canonicalEnvValues maps environment block keys to their canonical replacements. +// Values mirror cc-gateway's prompt_env config and represent a stock macOS dev machine. +var canonicalEnvValues = map[string]string{ + "Platform": "Platform: darwin", + "Shell": "Shell: zsh", + "OS Version": "OS Version: Darwin 24.4.0", + "Working directory": "Working directory: /Users/user/project", +} + +// NormalizeSystemPromptEnv rewrites environment-specific fields in a system +// prompt text block to canonical values, preventing real machine fingerprinting. +// +// Handles two classes of leakage (matching cc-gateway rewriter.ts:rewritePromptText): +// 1. "Platform: Windows / Linux / Darwin 25.x" → canonical darwin/zsh/Darwin 24.4.0 +// 2. "/Users/alice/" or "/home/bob/" → "/Users/user/" +// +// Only called on system prompt text blocks, never on user message content. +func NormalizeSystemPromptEnv(text string) string { + // Replace env-info lines with canonical values + text = reEnvLine.ReplaceAllStringFunc(text, func(line string) string { + for key, canonical := range canonicalEnvValues { + if len(line) >= len(key) && line[:len(key)] == key { + return canonical + } + } + return line + }) + + // Redact real usernames in home directory paths + // e.g. /Users/alice/project -> /Users/user/project + text = reUserHome.ReplaceAllString(text, "${1}user/") + + return text } diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index e59412eb..ed2be3dc 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -895,6 +895,9 @@ func sanitizeSystemText(text string) string { "You are OpenCode, the best coding agent on the planet.", strings.TrimSpace(claudeCodeSystemPrompt), ) + // Normalize environment block fields (Platform/Shell/OS Version/Working directory) + // to canonical values so different client machines don't create fingerprint divergence. + text = NormalizeSystemPromptEnv(text) return text } @@ -5773,7 +5776,7 @@ func (s *GatewayService) getBetaHeader(modelID string, clientBetaHeader string) return claude.HaikuBetaHeader } - return claude.DefaultBetaHeader + return claude.GetOAuthBetaHeader(modelID) } func requestNeedsBetaFeatures(body []byte) bool { @@ -5790,10 +5793,7 @@ func requestNeedsBetaFeatures(body []byte) bool { func defaultAPIKeyBetaHeader(body []byte) string { modelID := gjson.GetBytes(body, "model").String() - if strings.Contains(strings.ToLower(modelID), "haiku") { - return claude.APIKeyHaikuBetaHeader - } - return claude.APIKeyBetaHeader + return claude.GetAPIKeyBetaHeader(modelID) } func applyClaudeOAuthHeaderDefaults(req *http.Request) { diff --git a/backend/internal/service/identity_service.go b/backend/internal/service/identity_service.go index 6a3cdcf4..f5f180b0 100644 --- a/backend/internal/service/identity_service.go +++ b/backend/internal/service/identity_service.go @@ -26,7 +26,7 @@ var ( // 默认指纹值(当客户端未提供时使用) var defaultFingerprint = Fingerprint{ - UserAgent: "claude-cli/2.1.88 (external, cli)", + UserAgent: "claude-cli/2.1.89 (external, cli)", StainlessLang: "js", StainlessPackageVersion: "0.74.0", StainlessOS: "MacOS", diff --git a/backend/internal/service/lspool_bootstrap_service.go b/backend/internal/service/lspool_bootstrap_service.go deleted file mode 100644 index fbe53e28..00000000 --- a/backend/internal/service/lspool_bootstrap_service.go +++ /dev/null @@ -1,225 +0,0 @@ -package service - -import ( - "context" - "fmt" - "log/slog" - "strconv" - "strings" - "sync" - "time" - - "github.com/Wei-Shaw/sub2api/internal/config" - "github.com/Wei-Shaw/sub2api/internal/pkg/lspool" -) - -const ( - defaultLSPoolBootstrapConcurrency = 4 -) - -type lsBootstrapAccountReader interface { - ListByPlatform(ctx context.Context, platform string) ([]Account, error) -} - -// LSPoolBootstrapService pre-creates LS workers for eligible Antigravity accounts on startup. -type LSPoolBootstrapService struct { - accountReader lsBootstrapAccountReader - backend lspool.Backend - cfg *config.Config - logger *slog.Logger - - ctx context.Context - cancel context.CancelFunc - - once sync.Once - wg sync.WaitGroup -} - -func NewLSPoolBootstrapService(accountReader lsBootstrapAccountReader, backend lspool.Backend, cfg *config.Config) *LSPoolBootstrapService { - ctx, cancel := context.WithCancel(context.Background()) - return &LSPoolBootstrapService{ - accountReader: accountReader, - backend: backend, - cfg: cfg, - logger: slog.Default().With("component", "service.lspool_bootstrap"), - ctx: ctx, - cancel: cancel, - } -} - -// ProvideLSPoolBootstrapService creates and starts the LS pool bootstrap worker. -func ProvideLSPoolBootstrapService(accountRepo AccountRepository, cfg *config.Config) *LSPoolBootstrapService { - svc := NewLSPoolBootstrapService(accountRepo, lspool.GlobalPool(cfg), cfg) - svc.Start() - return svc -} - -func (s *LSPoolBootstrapService) Start() { - if s == nil { - return - } - s.once.Do(func() { - if s.backend == nil { - if lspool.IsLSModeEnabled() { - s.logger.Warn("startup bootstrap skipped: ls backend unavailable") - } - return - } - s.wg.Add(1) - go func() { - defer s.wg.Done() - s.bootstrap(s.ctx) - }() - }) -} - -func (s *LSPoolBootstrapService) Stop() { - if s == nil { - return - } - s.cancel() - s.wg.Wait() -} - -func (s *LSPoolBootstrapService) bootstrap(ctx context.Context) { - if s.backend == nil || s.accountReader == nil { - return - } - - accounts, err := s.accountReader.ListByPlatform(ctx, PlatformAntigravity) - if err != nil { - s.logger.Warn("load antigravity accounts for ls bootstrap failed", "error", err) - return - } - - now := time.Now() - candidates := make([]Account, 0, len(accounts)) - for i := range accounts { - if shouldBootstrapLSPoolAccount(&accounts[i], now) { - candidates = append(candidates, accounts[i]) - } - } - - if len(candidates) == 0 { - s.logger.Info("startup bootstrap skipped: no eligible antigravity accounts") - return - } - - s.logger.Info("starting ls worker bootstrap", - "accounts_total", len(accounts), - "accounts_eligible", len(candidates), - "concurrency", s.bootstrapConcurrency()) - - var ( - mu sync.Mutex - started int - failed int - ) - sem := make(chan struct{}, s.bootstrapConcurrency()) - var wg sync.WaitGroup - -loop: - for i := range candidates { - account := candidates[i] - select { - case <-ctx.Done(): - break loop - case sem <- struct{}{}: - } - - wg.Add(1) - go func(account Account) { - defer wg.Done() - defer func() { <-sem }() - - if err := s.bootstrapAccount(&account); err != nil { - mu.Lock() - failed++ - mu.Unlock() - s.logger.Warn("bootstrap ls worker failed", "account_id", account.ID, "error", err) - return - } - - mu.Lock() - started++ - mu.Unlock() - s.logger.Info("bootstrap ls worker ready", "account_id", account.ID) - }(account) - } - - wg.Wait() - s.logger.Info("ls worker bootstrap completed", - "accounts_total", len(accounts), - "accounts_eligible", len(candidates), - "workers_ready", started, - "workers_failed", failed, - "canceled", ctx.Err() != nil) -} - -func (s *LSPoolBootstrapService) bootstrapAccount(account *Account) error { - if s.backend == nil { - return fmt.Errorf("ls backend unavailable") - } - if account == nil { - return fmt.Errorf("account is nil") - } - - accountKey := strconv.FormatInt(account.ID, 10) - accessToken := strings.TrimSpace(account.GetCredential("access_token")) - if accessToken == "" { - return fmt.Errorf("missing access token") - } - refreshToken := strings.TrimSpace(account.GetCredential("refresh_token")) - - expiresAt := time.Time{} - if ts := account.GetCredentialAsTime("expires_at"); ts != nil { - expiresAt = ts.UTC() - } - - s.backend.SetAccountToken(accountKey, accessToken, refreshToken, expiresAt) - availableCredits, minimumCreditAmount := resolveLSPoolModelCreditsState(account) - s.backend.SetAccountModelCredits(accountKey, account.IsOveragesEnabled(), availableCredits, minimumCreditAmount) - - proxyURL := "" - if account.Proxy != nil { - proxyURL = account.Proxy.URL() - } - - if _, err := s.backend.GetOrCreate(accountKey, "", proxyURL); err != nil { - return fmt.Errorf("get or create ls worker: %w", err) - } - return nil -} - -func (s *LSPoolBootstrapService) bootstrapConcurrency() int { - parallelism := defaultLSPoolBootstrapConcurrency - if s.cfg != nil && s.cfg.Gateway.AntigravityLSWorker.MaxActive > 0 && s.cfg.Gateway.AntigravityLSWorker.MaxActive < parallelism { - parallelism = s.cfg.Gateway.AntigravityLSWorker.MaxActive - } - if parallelism < 1 { - return 1 - } - return parallelism -} - -func shouldBootstrapLSPoolAccount(account *Account, now time.Time) bool { - if account == nil { - return false - } - if account.Platform != PlatformAntigravity { - return false - } - if account.Type != AccountTypeOAuth { - return false - } - if account.Status != StatusActive || !account.Schedulable { - return false - } - if account.AutoPauseOnExpired && account.ExpiresAt != nil && !now.Before(*account.ExpiresAt) { - return false - } - if strings.TrimSpace(account.GetCredential("access_token")) == "" { - return false - } - return strings.TrimSpace(account.GetCredential("project_id")) != "" -} diff --git a/backend/internal/service/lspool_bootstrap_service_test.go b/backend/internal/service/lspool_bootstrap_service_test.go deleted file mode 100644 index e16f9ba7..00000000 --- a/backend/internal/service/lspool_bootstrap_service_test.go +++ /dev/null @@ -1,262 +0,0 @@ -package service - -import ( - "context" - "errors" - "sync" - "testing" - "time" - - "github.com/Wei-Shaw/sub2api/internal/config" - "github.com/Wei-Shaw/sub2api/internal/pkg/lspool" - "github.com/stretchr/testify/require" -) - -type fakeLSBootstrapAccountReader struct { - mu sync.Mutex - accounts []Account - err error - platforms []string -} - -func (f *fakeLSBootstrapAccountReader) ListByPlatform(_ context.Context, platform string) ([]Account, error) { - f.mu.Lock() - f.platforms = append(f.platforms, platform) - accounts := append([]Account(nil), f.accounts...) - err := f.err - f.mu.Unlock() - return accounts, err -} - -type fakeLSPoolBackend struct { - mu sync.Mutex - tokenCalls map[string]fakeLSPoolTokenCall - creditCalls map[string]fakeLSPoolCreditCall - getCalls []fakeLSPoolGetCall - getErrs map[string]error -} - -type fakeLSPoolTokenCall struct { - AccessToken string - RefreshToken string - ExpiresAt time.Time -} - -type fakeLSPoolCreditCall struct { - UseAICredits bool - AvailableCredits *int32 - MinimumCreditAmount *int32 -} - -type fakeLSPoolGetCall struct { - AccountID string - RoutingKey string - ProxyURL string -} - -func newFakeLSPoolBackend() *fakeLSPoolBackend { - return &fakeLSPoolBackend{ - tokenCalls: make(map[string]fakeLSPoolTokenCall), - creditCalls: make(map[string]fakeLSPoolCreditCall), - getErrs: make(map[string]error), - } -} - -func (f *fakeLSPoolBackend) GetOrCreate(accountID, routingKey string, proxyURL ...string) (*lspool.Instance, error) { - rawProxy := "" - if len(proxyURL) > 0 { - rawProxy = proxyURL[0] - } - - f.mu.Lock() - defer f.mu.Unlock() - f.getCalls = append(f.getCalls, fakeLSPoolGetCall{ - AccountID: accountID, - RoutingKey: routingKey, - ProxyURL: rawProxy, - }) - if err := f.getErrs[accountID]; err != nil { - return nil, err - } - return &lspool.Instance{AccountID: accountID}, nil -} - -func (f *fakeLSPoolBackend) SetAccountToken(accountID, accessToken, refreshToken string, expiresAt time.Time) { - f.mu.Lock() - defer f.mu.Unlock() - f.tokenCalls[accountID] = fakeLSPoolTokenCall{ - AccessToken: accessToken, - RefreshToken: refreshToken, - ExpiresAt: expiresAt, - } -} - -func (f *fakeLSPoolBackend) SetAccountModelCredits(accountID string, useAICredits bool, availableCredits, minimumCreditAmountForUsage *int32) { - f.mu.Lock() - defer f.mu.Unlock() - f.creditCalls[accountID] = fakeLSPoolCreditCall{ - UseAICredits: useAICredits, - AvailableCredits: copyInt32Ptr(availableCredits), - MinimumCreditAmount: copyInt32Ptr(minimumCreditAmountForUsage), - } -} - -func (f *fakeLSPoolBackend) Stats() map[string]any { return nil } - -func (f *fakeLSPoolBackend) Close() {} - -func copyInt32Ptr(v *int32) *int32 { - if v == nil { - return nil - } - cp := *v - return &cp -} - -func TestLSPoolBootstrapServiceBootstrapEligibleAccounts(t *testing.T) { - expiresAt := time.Now().Add(2 * time.Hour).UTC().Truncate(time.Second) - expiredAt := time.Now().Add(-2 * time.Hour) - reader := &fakeLSBootstrapAccountReader{ - accounts: []Account{ - { - ID: 101, - Platform: PlatformAntigravity, - Type: AccountTypeOAuth, - Status: StatusActive, - Schedulable: true, - Credentials: map[string]any{ - "access_token": "token-101", - "refresh_token": "refresh-101", - "expires_at": expiresAt.Format(time.RFC3339), - "project_id": "proj-101", - }, - Extra: map[string]any{ - "allow_overages": true, - "ai_credits": []any{ - map[string]any{ - "credit_type": "GOOGLE_ONE_AI", - "amount": 120, - "minimum_balance": 55, - }, - }, - }, - Proxy: &Proxy{ - Protocol: "socks5h", - Host: "127.0.0.1", - Port: 1080, - Username: "alice", - Password: "secret", - }, - }, - { - ID: 102, - Platform: PlatformAntigravity, - Type: AccountTypeOAuth, - Status: StatusActive, - Schedulable: false, - Credentials: map[string]any{"access_token": "token-102", "project_id": "proj-102"}, - }, - { - ID: 103, - Platform: PlatformAntigravity, - Type: AccountTypeOAuth, - Status: StatusActive, - Schedulable: true, - Credentials: map[string]any{"access_token": "token-103"}, - }, - { - ID: 104, - Platform: PlatformAntigravity, - Type: AccountTypeOAuth, - Status: StatusActive, - Schedulable: true, - AutoPauseOnExpired: true, - ExpiresAt: &expiredAt, - Credentials: map[string]any{"access_token": "token-104", "project_id": "proj-104"}, - }, - { - ID: 106, - Platform: PlatformAntigravity, - Type: AccountTypeUpstream, - Status: StatusActive, - Schedulable: true, - Credentials: map[string]any{"access_token": "token-106", "project_id": "proj-106"}, - }, - { - ID: 105, - Platform: PlatformOpenAI, - Status: StatusActive, - Schedulable: true, - Credentials: map[string]any{"access_token": "token-105"}, - }, - }, - } - backend := newFakeLSPoolBackend() - svc := NewLSPoolBootstrapService(reader, backend, &config.Config{ - Gateway: config.GatewayConfig{ - AntigravityLSWorker: config.GatewayAntigravityLSWorkerConfig{MaxActive: 3}, - }, - }) - - svc.bootstrap(context.Background()) - - require.Equal(t, []string{PlatformAntigravity}, reader.platforms) - - require.Len(t, backend.getCalls, 1) - require.Equal(t, fakeLSPoolGetCall{ - AccountID: "101", - RoutingKey: "", - ProxyURL: "socks5h://alice:secret@127.0.0.1:1080", - }, backend.getCalls[0]) - - tokenCall, ok := backend.tokenCalls["101"] - require.True(t, ok) - require.Equal(t, "token-101", tokenCall.AccessToken) - require.Equal(t, "refresh-101", tokenCall.RefreshToken) - require.Equal(t, expiresAt, tokenCall.ExpiresAt) - - creditCall, ok := backend.creditCalls["101"] - require.True(t, ok) - require.True(t, creditCall.UseAICredits) - require.NotNil(t, creditCall.AvailableCredits) - require.Equal(t, int32(120), *creditCall.AvailableCredits) - require.NotNil(t, creditCall.MinimumCreditAmount) - require.Equal(t, int32(55), *creditCall.MinimumCreditAmount) - - require.NotContains(t, backend.tokenCalls, "102") - require.NotContains(t, backend.tokenCalls, "103") - require.NotContains(t, backend.tokenCalls, "104") - require.NotContains(t, backend.tokenCalls, "106") -} - -func TestLSPoolBootstrapServiceBootstrapContinuesOnWorkerFailure(t *testing.T) { - reader := &fakeLSBootstrapAccountReader{ - accounts: []Account{ - { - ID: 201, - Platform: PlatformAntigravity, - Type: AccountTypeOAuth, - Status: StatusActive, - Schedulable: true, - Credentials: map[string]any{"access_token": "token-201", "project_id": "proj-201"}, - }, - { - ID: 202, - Platform: PlatformAntigravity, - Type: AccountTypeOAuth, - Status: StatusActive, - Schedulable: true, - Credentials: map[string]any{"access_token": "token-202", "project_id": "proj-202"}, - }, - }, - } - backend := newFakeLSPoolBackend() - backend.getErrs["201"] = errors.New("create failed") - - svc := NewLSPoolBootstrapService(reader, backend, &config.Config{}) - svc.bootstrap(context.Background()) - - require.Len(t, backend.getCalls, 2) - require.Contains(t, backend.tokenCalls, "201") - require.Contains(t, backend.tokenCalls, "202") -} diff --git a/deploy/ls-bin/cert.pem b/deploy/ls-bin/cert.pem deleted file mode 100644 index de9c01de..00000000 --- a/deploy/ls-bin/cert.pem +++ /dev/null @@ -1,21 +0,0 @@ ------BEGIN CERTIFICATE----- -MIIDXTCCAkWgAwIBAgIUVoRddTlTFh3+shRe6g4kSLo2n0MwDQYJKoZIhvcNAQEL -BQAwSTESMBAGA1UEAwwJbG9jYWxob3N0MRYwFAYDVQQKDA1FTkFCTEVTIEhUVFAy -MRswGQYDVQQLDBJidW5kbGVkIG9uIHB1cnBvc2UwHhcNMjUwOTA0MjA1NTA0WhcN -MjYwOTA0MjA1NTA0WjBJMRIwEAYDVQQDDAlsb2NhbGhvc3QxFjAUBgNVBAoMDUVO -QUJMRVMgSFRUUDIxGzAZBgNVBAsMEmJ1bmRsZWQgb24gcHVycG9zZTCCASIwDQYJ -KoZIhvcNAQEBBQADggEPADCCAQoCggEBAJVpU6IyIMgwB6CJHkOeEAgYtzvyH6fM -lkZSbemTrD9RCWZ4Fati1/6vbbMyWsM2XNJQMhJo0JTEoLDddN1iV/xGJCO/3dgw -4+wLqqEeck4R1pHygCkb40TycmyygSWsidkEUH0xp51nCapIdPr/WL6O+Gbpl6DA -onerUmWIO39VG2SpV7x3iXZOSbIGMsOiNZBmGwBZcL8ZejBIDjwvNjnX/d2tejH5 -/Mo4KVEXl5jsqaNbDIkhSs5BXtCMhoi1dqt75M8FyuNZd50AGFSa9Lj6pHTpwepD -k2x4h+czPcvscF7TQG31TK1VYFPUThDim+by0+LQKkpy/UGVWnbC4dsCAwEAAaM9 -MDswGgYDVR0RBBMwEYIJbG9jYWxob3N0hwR/AAABMB0GA1UdDgQWBBSonSKmHCVt -yBoVH1xEb3vtCng80DANBgkqhkiG9w0BAQsFAAOCAQEAinBO/uYe8ExHeiskt2P/ -Oxkd5sHSY9deLVuyX/TFnUEfktMfYKM2Juy+MfH4vfrcEhYkYJJcm25UGrtiT0Jh -bUooDkR53549Xzg/70HU/ls1eNIe0zYqmS12H5W4Q1LAWTVpePscB4dgOrps6xIk -Q4nlF7dst93E3swAe81rgCEd7VZEZy5VQcE9K+CIZXaAUJwUAsAtJbrP+5JMe9pt -q52Zq5ZVkBS+4xeaMrasN0iTgsS4Lxo2a0GFDIJ84V66oeX7a5SXfSNn7rMVIDai -KNZ2Cf2xNXUwq25Z6tjpQCqwYn3SE8b/Yi6fFZmy5D8kmY7dMh8ghVOc7rD+Vsk6 -/Q== ------END CERTIFICATE----- diff --git a/deploy/ls-bin/language_server_linux_arm b/deploy/ls-bin/language_server_linux_arm deleted file mode 100755 index 5463ca16..00000000 Binary files a/deploy/ls-bin/language_server_linux_arm and /dev/null differ diff --git a/deploy/ls-bin/language_server_linux_x64 b/deploy/ls-bin/language_server_linux_x64 deleted file mode 100755 index 79131863..00000000 Binary files a/deploy/ls-bin/language_server_linux_x64 and /dev/null differ diff --git a/deploy/lsworker-entrypoint.sh b/deploy/lsworker-entrypoint.sh deleted file mode 100644 index 215058c3..00000000 --- a/deploy/lsworker-entrypoint.sh +++ /dev/null @@ -1,70 +0,0 @@ -#!/bin/sh -set -eu - -PROXY_HOST="${LSWORKER_PROXY_HOST:-}" -PROXY_PORT="${LSWORKER_PROXY_PORT:-1080}" -PROXY_USER="${LSWORKER_PROXY_USER:-}" -PROXY_PASS="${LSWORKER_PROXY_PASS:-}" -CONTROL_PORT="${LSWORKER_CONTROL_PORT:-18081}" -REDSOCKS_PORT="${LSWORKER_REDSOCKS_PORT:-12345}" -NETWORK_READY_FILE="${LSWORKER_NETWORK_READY_FILE:-/run/lsworker/network-ready}" - -mkdir -p "$(dirname "${NETWORK_READY_FILE}")" - -if [ -z "${PROXY_HOST}" ]; then - echo "LSWORKER_PROXY_HOST is required" >&2 - exit 1 -fi - -PROXY_IP="$(getent ahostsv4 "${PROXY_HOST}" | awk 'NR==1 {print $1}')" -if [ -z "${PROXY_IP}" ]; then - echo "failed to resolve proxy host: ${PROXY_HOST}" >&2 - exit 1 -fi - -cat >/tmp/redsocks.conf <>/tmp/redsocks.conf -fi -if [ -n "${PROXY_PASS}" ]; then - printf ' password = "%s";\n' "${PROXY_PASS}" >>/tmp/redsocks.conf -fi - -cat >>/tmp/redsocks.conf </tmp/redsocks.log 2>&1 & -REDSOCKS_PID="$!" -trap 'kill "${REDSOCKS_PID}" >/dev/null 2>&1 || true' EXIT - -sleep 1 - -iptables -t nat -N REDSOCKS 2>/dev/null || true -iptables -t nat -F REDSOCKS -iptables -t nat -A REDSOCKS -d 127.0.0.0/8 -j RETURN -iptables -t nat -A REDSOCKS -d 127.0.0.11/32 -j RETURN -iptables -t nat -A REDSOCKS -d "${PROXY_IP}/32" -j RETURN -iptables -t nat -A REDSOCKS -p tcp --dport "${CONTROL_PORT}" -j RETURN -iptables -t nat -A REDSOCKS -p tcp -j REDIRECT --to-ports "${REDSOCKS_PORT}" -iptables -t nat -D OUTPUT -p tcp -j REDSOCKS 2>/dev/null || true -iptables -t nat -A OUTPUT -p tcp -j REDSOCKS - -touch "${NETWORK_READY_FILE}" - -exec gosu sub2api /app/lsworker diff --git a/deploy/lsworker.Dockerfile b/deploy/lsworker.Dockerfile deleted file mode 100644 index 49178874..00000000 --- a/deploy/lsworker.Dockerfile +++ /dev/null @@ -1,52 +0,0 @@ -ARG GOLANG_IMAGE=golang:1.26.1-alpine -ARG DEBIAN_IMAGE=debian:bookworm-slim - -FROM ${GOLANG_IMAGE} AS builder - -WORKDIR /app/backend -RUN apk add --no-cache git ca-certificates tzdata - -COPY backend/go.mod backend/go.sum ./ -RUN go mod download - -COPY backend/ ./ -RUN CGO_ENABLED=0 GOOS=linux go build -trimpath -ldflags="-s -w" -o /app/lsworker ./cmd/lsworker - -FROM ${DEBIAN_IMAGE} - -RUN apt-get update && apt-get install -y --no-install-recommends \ - ca-certificates \ - curl \ - gosu \ - iproute2 \ - iptables \ - redsocks \ - tzdata \ - && rm -rf /var/lib/apt/lists/* - -RUN groupadd -g 1000 sub2api && \ - useradd -u 1000 -g sub2api -m -s /bin/sh sub2api - -WORKDIR /app - -COPY --from=builder /app/lsworker /app/lsworker -COPY deploy/ls-bin/language_server_linux_* /tmp/ls-bin/ -COPY deploy/ls-bin/cert.pem /app/ls/extensions/antigravity/dist/languageServer/ - -ARG TARGETARCH -RUN mkdir -p /app/ls/extensions/antigravity/bin /run/lsworker && \ - if [ "${TARGETARCH:-amd64}" = "arm64" ]; then \ - cp /tmp/ls-bin/language_server_linux_arm /app/ls/extensions/antigravity/bin/language_server_linux_arm; \ - else \ - cp /tmp/ls-bin/language_server_linux_x64 /app/ls/extensions/antigravity/bin/language_server_linux_x64; \ - fi && \ - chmod +x /app/lsworker /app/ls/extensions/antigravity/bin/language_server_linux_* && \ - chown -R sub2api:sub2api /app /run/lsworker && \ - rm -rf /tmp/ls-bin - -COPY deploy/lsworker-entrypoint.sh /app/lsworker-entrypoint.sh -RUN chmod +x /app/lsworker-entrypoint.sh - -EXPOSE 18081 - -ENTRYPOINT ["/app/lsworker-entrypoint.sh"]