diff --git a/Dockerfile b/Dockerfile index a16eb958..368e56fc 100644 --- a/Dockerfile +++ b/Dockerfile @@ -9,6 +9,7 @@ ARG NODE_IMAGE=node:24-alpine ARG GOLANG_IMAGE=golang:1.26.1-alpine ARG ALPINE_IMAGE=alpine:3.21 +ARG DEBIAN_IMAGE=debian:bookworm-slim ARG POSTGRES_IMAGE=postgres:18-alpine ARG GOPROXY=https://goproxy.cn,direct ARG GOSUMDB=sum.golang.google.cn @@ -63,10 +64,12 @@ COPY --from=frontend-builder /app/backend/internal/web/dist ./internal/web/dist # Build the binary (BuildType=release for CI builds, embed frontend) # Version precedence: build arg VERSION > cmd/server/VERSION +ARG TARGETARCH +ARG TARGETOS=linux RUN VERSION_VALUE="${VERSION}" && \ if [ -z "${VERSION_VALUE}" ]; then VERSION_VALUE="$(tr -d '\r\n' < ./cmd/server/VERSION)"; fi && \ DATE_VALUE="${DATE:-$(date -u +%Y-%m-%dT%H:%M:%SZ)}" && \ - CGO_ENABLED=0 GOOS=linux go build \ + CGO_ENABLED=0 GOOS=${TARGETOS} GOARCH=${TARGETARCH} go build \ -tags embed \ -ldflags="-s -w -X main.Version=${VERSION_VALUE} -X main.Commit=${COMMIT} -X main.Date=${DATE_VALUE} -X main.BuildType=release" \ -trimpath \ @@ -79,9 +82,9 @@ RUN VERSION_VALUE="${VERSION}" && \ FROM ${POSTGRES_IMAGE} AS pg-client # ----------------------------------------------------------------------------- -# Stage 4: Final Runtime Image +# Stage 4: Final Runtime Image (Debian for glibc — LS binary requires it) # ----------------------------------------------------------------------------- -FROM ${ALPINE_IMAGE} +FROM ${DEBIAN_IMAGE} # Labels LABEL maintainer="Wei-Shaw " @@ -89,27 +92,25 @@ LABEL description="Sub2API - AI API Gateway Platform" LABEL org.opencontainers.image.source="https://github.com/Wei-Shaw/sub2api" # Install runtime dependencies -RUN apk add --no-cache \ +RUN apt-get update && apt-get install -y --no-install-recommends \ ca-certificates \ + curl \ + wget \ + gosu \ + proxychains4 \ tzdata \ - su-exec \ - libpq \ - zstd-libs \ - lz4-libs \ - krb5-libs \ - libldap \ - libedit \ - && rm -rf /var/cache/apk/* + libpq5 \ + && 
rm -rf /var/lib/apt/lists/* # Copy pg_dump and psql from the same postgres image used in docker-compose -# This ensures version consistency between backup tools and the database server COPY --from=pg-client /usr/local/bin/pg_dump /usr/local/bin/pg_dump COPY --from=pg-client /usr/local/bin/psql /usr/local/bin/psql COPY --from=pg-client /usr/local/lib/libpq.so.5* /usr/local/lib/ +RUN ldconfig # Create non-root user -RUN addgroup -g 1000 sub2api && \ - adduser -u 1000 -G sub2api -s /bin/sh -D sub2api +RUN groupadd -g 1000 sub2api && \ + useradd -u 1000 -g sub2api -m -s /bin/sh sub2api # Set working directory WORKDIR /app @@ -118,6 +119,21 @@ WORKDIR /app COPY --from=backend-builder --chown=sub2api:sub2api /app/sub2api /app/sub2api COPY --from=backend-builder --chown=sub2api:sub2api /app/backend/resources /app/resources +# Copy Language Server binary and cert (for LS pool mode) +# Enable with: ANTIGRAVITY_LS_MODE=true ANTIGRAVITY_APP_ROOT=/app/ls +# TARGETARCH is set automatically by buildx (amd64 or arm64) +ARG TARGETARCH +COPY --chown=sub2api:sub2api deploy/ls-bin/language_server_linux_* /tmp/ls-bin/ +COPY --chown=sub2api:sub2api deploy/ls-bin/cert.pem /app/ls/extensions/antigravity/dist/languageServer/ +RUN mkdir -p /app/ls/extensions/antigravity/bin && \ + if [ "$TARGETARCH" = "arm64" ]; then \ + cp /tmp/ls-bin/language_server_linux_arm /app/ls/extensions/antigravity/bin/language_server_linux_arm; \ + else \ + cp /tmp/ls-bin/language_server_linux_x64 /app/ls/extensions/antigravity/bin/language_server_linux_x64; \ + fi && \ + chmod +x /app/ls/extensions/antigravity/bin/language_server_linux_* && \ + rm -rf /tmp/ls-bin + # Create data directory RUN mkdir -p /app/data && chown sub2api:sub2api /app/data
diff --git a/backend/cmd/lsworker/main.go b/backend/cmd/lsworker/main.go
new file mode 100644
index 00000000..deeb0649
--- /dev/null
+++ b/backend/cmd/lsworker/main.go
@@ -0,0 +1,88 @@
+package main
+
+import (
+	"context"
+	"errors"
+	"log/slog"
+	"net/http"
+	"os"
+	"os/signal"
+	"syscall"
+	"time"
+
+	"github.com/Wei-Shaw/sub2api/internal/pkg/lspool"
+)
+
+// shutdownTimeout bounds how long in-flight requests may drain after a
+// termination signal before the process stops waiting and exits.
+const shutdownTimeout = 30 * time.Second
+
+// main starts the LS worker and exits with the status code reported by run.
+func main() {
+	os.Exit(run())
+}
+
+// run builds the worker from the environment, serves HTTP until SIGINT/SIGTERM
+// or a listener failure, and returns the process exit code: 0 after a
+// signal-initiated stop, 1 after an initialization or serve error.
+//
+// It returns rather than calling os.Exit so the deferred server.Close and stop
+// calls run on every path.
+func run() int {
+	server, err := lspool.NewWorkerServerFromEnv()
+	if err != nil {
+		slog.Error("failed to initialize lsworker", "err", err)
+		return 1
+	}
+	defer server.Close()
+
+	httpServer := &http.Server{
+		Addr:              envOrDefault("LSWORKER_LISTEN_ADDR", "0.0.0.0:18081"),
+		Handler:           server.Handler(),
+		ReadHeaderTimeout: 10 * time.Second,
+	}
+
+	ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
+	defer stop()
+
+	// Serve in the background so this goroutine can react to whichever comes
+	// first: a listener failure or a termination signal. The buffer lets the
+	// goroutine finish even when nobody reads the result after a shutdown.
+	serveErr := make(chan error, 1)
+	go func() {
+		serveErr <- httpServer.ListenAndServe()
+	}()
+	slog.Info("lsworker listening", "addr", httpServer.Addr)
+
+	select {
+	case err := <-serveErr:
+		if errors.Is(err, http.ErrServerClosed) {
+			return 0
+		}
+		slog.Error("lsworker exited with error", "err", err)
+		return 1
+	case <-ctx.Done():
+		// Unregister the handler so a second signal terminates the process
+		// immediately instead of being swallowed during the drain.
+		stop()
+	}
+
+	// ListenAndServe returns ErrServerClosed the moment Shutdown is called, so
+	// the drain has to be awaited here; otherwise the process would exit while
+	// requests are still in flight.
+	shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
+	defer cancel()
+	if err := httpServer.Shutdown(shutdownCtx); err != nil {
+		slog.Error("lsworker graceful shutdown incomplete", "err", err)
+	}
+	return 0
+}
+
+// envOrDefault returns the value of the environment variable key, or fallback
+// when the variable is unset or empty.
+func envOrDefault(key, fallback string) string {
+	if value := os.Getenv(key); value != "" {
+		return value
+	}
+	return fallback
+}
diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 3ee5d6cd..49669604 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -379,6 +379,8 @@ type GatewayConfig struct { OpenAIPassthroughAllowTimeoutHeaders bool `mapstructure:"openai_passthrough_allow_timeout_headers"` // OpenAIWS: OpenAI Responses WebSocket 配置(默认开启,可按需回滚到 HTTP) OpenAIWS GatewayOpenAIWSConfig `mapstructure:"openai_ws"` + // AntigravityLSWorker: LS worker 容器控制平面配置 + AntigravityLSWorker GatewayAntigravityLSWorkerConfig `mapstructure:"antigravity_ls_worker"` // HTTP 上游连接池配置(性能优化:支持高并发场景调优) // MaxIdleConns: 所有主机的最大空闲连接总数 @@ -469,6 +471,16 @@ type GatewayConfig struct { UserMessageQueue UserMessageQueueConfig `mapstructure:"user_message_queue"` } +type GatewayAntigravityLSWorkerConfig struct { + Image string `mapstructure:"image"` + Network string `mapstructure:"network"` + DockerSocket string `mapstructure:"docker_socket"` + IdleTTL time.Duration
`mapstructure:"idle_ttl"` + MaxActive int `mapstructure:"max_active"` + StartupTimeout time.Duration `mapstructure:"startup_timeout"` + RequestTimeout time.Duration `mapstructure:"request_timeout"` +} + // UserMessageQueueConfig 用户消息串行队列配置 // 用于 Anthropic OAuth/SetupToken 账号的用户消息串行化发送 type UserMessageQueueConfig struct { @@ -1278,6 +1290,15 @@ func setDefaults() { // RateLimit viper.SetDefault("rate_limit.overload_cooldown_minutes", 10) + + // Gateway LS worker + viper.SetDefault("gateway.antigravity_ls_worker.image", "weishaw/sub2api-lsworker:latest") + viper.SetDefault("gateway.antigravity_ls_worker.network", "sub2api-network") + viper.SetDefault("gateway.antigravity_ls_worker.docker_socket", "unix:///var/run/docker.sock") + viper.SetDefault("gateway.antigravity_ls_worker.idle_ttl", 15*time.Minute) + viper.SetDefault("gateway.antigravity_ls_worker.max_active", 50) + viper.SetDefault("gateway.antigravity_ls_worker.startup_timeout", 45*time.Second) + viper.SetDefault("gateway.antigravity_ls_worker.request_timeout", 60*time.Second) viper.SetDefault("rate_limit.oauth_401_cooldown_minutes", 10) // Pricing - 从 model-price-repo 同步模型定价和上下文窗口数据(固定到 commit,避免分支漂移) diff --git a/backend/internal/pkg/antigravity/client.go b/backend/internal/pkg/antigravity/client.go index fdd7fea1..eeb59bdd 100644 --- a/backend/internal/pkg/antigravity/client.go +++ b/backend/internal/pkg/antigravity/client.go @@ -440,7 +440,7 @@ func (c *Client) GetUserInfo(ctx context.Context, accessToken string) (*UserInfo func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadCodeAssistResponse, map[string]any, error) { reqBody := LoadCodeAssistRequest{} reqBody.Metadata.IDEType = "ANTIGRAVITY" - reqBody.Metadata.IDEVersion = "1.20.6" + reqBody.Metadata.IDEVersion = "1.107.0" reqBody.Metadata.IDEName = "antigravity" bodyBytes, err := json.Marshal(reqBody) diff --git a/backend/internal/pkg/antigravity/oauth.go b/backend/internal/pkg/antigravity/oauth.go index 8a8bed92..4b042005 
100644 --- a/backend/internal/pkg/antigravity/oauth.go +++ b/backend/internal/pkg/antigravity/oauth.go @@ -9,6 +9,7 @@ import ( "net/http" "net/url" "os" + "runtime" "strings" "sync" "time" @@ -49,8 +50,8 @@ const ( antigravityDailyBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com" ) -// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.20.5 -var defaultUserAgentVersion = "1.20.5" +// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.107.0 +var defaultUserAgentVersion = "1.107.0" // defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置 var defaultClientSecret = "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" @@ -66,9 +67,9 @@ func init() { } } -// GetUserAgent 返回当前配置的 User-Agent +// GetUserAgent 返回当前配置的 User-Agent(自动检测平台,匹配真实 IDE 行为) func GetUserAgent() string { - return fmt.Sprintf("antigravity/%s windows/amd64", defaultUserAgentVersion) + return fmt.Sprintf("antigravity/%s %s/%s", defaultUserAgentVersion, runtime.GOOS, runtime.GOARCH) } func getClientSecret() (string, error) {
diff --git a/backend/internal/pkg/antigravity/request_transformer.go b/backend/internal/pkg/antigravity/request_transformer.go
index 1b45e507..98a39c30 100644
--- a/backend/internal/pkg/antigravity/request_transformer.go
+++ b/backend/internal/pkg/antigravity/request_transformer.go
@@ -39,6 +39,51 @@ func generateStableSessionID(contents []GeminiContent) string {
 	return "-" + strconv.FormatInt(n, 10)
 }
 
+// EnsureGeminiRequestSessionID fills request.sessionId when the caller omitted it.
+// preferredSessionID wins; otherwise we derive a stable value from the first user turn.
+//
+// A body whose sessionId is already a non-blank string is returned unchanged, and
+// so is a body for which no session ID can be determined. Only the top-level
+// object is decoded; every other field is carried through as raw JSON, so
+// re-encoding cannot alter it (for example by rounding large integers through
+// float64). Invalid JSON and non-object bodies yield an error, except the
+// literal null, which is passed through untouched.
+func EnsureGeminiRequestSessionID(body []byte, preferredSessionID string) ([]byte, error) {
+	var payload map[string]json.RawMessage
+	if err := json.Unmarshal(body, &payload); err != nil {
+		return nil, err
+	}
+	if payload == nil {
+		// The JSON literal null decodes without error and leaves the map nil;
+		// assigning into it below would panic, so pass the body through.
+		return body, nil
+	}
+
+	var existing string
+	if raw, ok := payload["sessionId"]; ok && json.Unmarshal(raw, &existing) == nil && strings.TrimSpace(existing) != "" {
+		return body, nil
+	}
+
+	sessionID := strings.TrimSpace(preferredSessionID)
+	if sessionID == "" {
+		var req GeminiRequest
+		if err := json.Unmarshal(body, &req); err != nil {
+			return nil, err
+		}
+		sessionID = generateStableSessionID(req.Contents)
+	}
+	if sessionID == "" {
+		return body, nil
+	}
+
+	encoded, err := json.Marshal(sessionID)
+	if err != nil {
+		return nil, err
+	}
+	payload["sessionId"] = encoded
+	return json.Marshal(payload)
+}
+
 type TransformOptions struct { EnableIdentityPatch bool // IdentityPatch 可选:自定义注入到 systemInstruction 开头的身份防护提示词; diff --git a/backend/internal/pkg/antigravity/request_transformer_test.go b/backend/internal/pkg/antigravity/request_transformer_test.go index 9e46295a..aaf8d72a 100644 --- a/backend/internal/pkg/antigravity/request_transformer_test.go +++ b/backend/internal/pkg/antigravity/request_transformer_test.go @@ -8,6 +8,43 @@ import ( "github.com/stretchr/testify/require" ) +func TestEnsureGeminiRequestSessionID(t *testing.T) { + t.Run("prefers provided session id", func(t *testing.T) { + body := []byte(`{"contents":[{"role":"user","parts":[{"text":"hello"}]}]}`) + updated, err := EnsureGeminiRequestSessionID(body, "session-from-header") + require.NoError(t, err) + + var payload map[string]any + require.NoError(t, json.Unmarshal(updated, &payload)) + require.Equal(t, "session-from-header", payload["sessionId"]) + }) + + t.Run("keeps existing session id", func(t *testing.T) { + body := []byte(`{"sessionId":"session-in-body","contents":[{"role":"user","parts":[{"text":"hello"}]}]}`) + updated, err := EnsureGeminiRequestSessionID(body, "session-from-header") + require.NoError(t, err) + + var payload map[string]any + require.NoError(t, json.Unmarshal(updated, &payload)) + require.Equal(t, "session-in-body",
payload["sessionId"]) + }) + + t.Run("derives stable fallback from contents", func(t *testing.T) { + body := []byte(`{"contents":[{"role":"user","parts":[{"text":"hello"}]}]}`) + first, err := EnsureGeminiRequestSessionID(body, "") + require.NoError(t, err) + second, err := EnsureGeminiRequestSessionID(body, "") + require.NoError(t, err) + + var firstPayload map[string]any + var secondPayload map[string]any + require.NoError(t, json.Unmarshal(first, &firstPayload)) + require.NoError(t, json.Unmarshal(second, &secondPayload)) + require.NotEmpty(t, firstPayload["sessionId"]) + require.Equal(t, firstPayload["sessionId"], secondPayload["sessionId"]) + }) +} + // TestBuildParts_ThinkingBlockWithoutSignature 测试thinking block无signature时的处理 func TestBuildParts_ThinkingBlockWithoutSignature(t *testing.T) { tests := []struct { diff --git a/backend/internal/pkg/lspool/backend.go b/backend/internal/pkg/lspool/backend.go new file mode 100644 index 00000000..ba77e3d7 --- /dev/null +++ b/backend/internal/pkg/lspool/backend.go @@ -0,0 +1,13 @@ +package lspool + +import "time" + +// Backend is the control-plane abstraction used by the HTTP upstream wrapper. +// It may be backed by a local in-process Pool or by remote LS workers. +type Backend interface { + GetOrCreate(accountID, routingKey string, proxyURL ...string) (*Instance, error) + SetAccountToken(accountID, accessToken, refreshToken string, expiresAt time.Time) + SetAccountModelCredits(accountID string, useAICredits bool, availableCredits, minimumCreditAmountForUsage *int32) + Stats() map[string]any + Close() +} diff --git a/backend/internal/pkg/lspool/global.go b/backend/internal/pkg/lspool/global.go new file mode 100644 index 00000000..926f69d7 --- /dev/null +++ b/backend/internal/pkg/lspool/global.go @@ -0,0 +1,94 @@ +// Package lspool provides LS-mode integration for the antigravity gateway. 
+// +// When LS mode is enabled (via ANTIGRAVITY_LS_MODE=true), requests to +// streamGenerateContent are routed through a real Language Server instance +// instead of directly to cloudcode-pa. This provides: +// +// - Authentic TLS fingerprint (Google's own Go binary) +// - Real session management and Heartbeat +// - Indistinguishable from a real IDE instance +// +// To enable: set environment variable ANTIGRAVITY_LS_MODE=true +// To configure: set ANTIGRAVITY_APP_ROOT to the AntiGravity.app path +package lspool + +import ( + "log/slog" + "os" + "strings" + "sync" + + "github.com/Wei-Shaw/sub2api/internal/config" +) + +var ( + globalBackend Backend + globalPoolOnce sync.Once + lsModeEnabled bool +) + +func init() { + lsModeEnabled = os.Getenv("ANTIGRAVITY_LS_MODE") == "true" +} + +// IsLSModeEnabled returns whether LS mode is active +func IsLSModeEnabled() bool { + return lsModeEnabled +} + +const ( + LSStrategyDirect = "direct" + LSStrategyJSParity = "js-parity" +) + +// CurrentLSStrategy returns the active LS routing strategy. +// Unknown values are treated as "direct" for safety. 
+func CurrentLSStrategy() string {
+	switch strings.ToLower(strings.TrimSpace(os.Getenv("ANTIGRAVITY_LS_STRATEGY"))) {
+	case "", LSStrategyDirect:
+		return LSStrategyDirect
+	case LSStrategyJSParity:
+		return LSStrategyJSParity
+	default:
+		return LSStrategyDirect
+	}
+}
+
+// GlobalPool returns the process-wide LS backend, or nil when LS mode is disabled or initialization failed.
+// The backend is built from cfg on the first call only; because of sync.Once a failed first attempt is never retried.
+func GlobalPool(cfg *config.Config) Backend {
+	if !lsModeEnabled {
+		return nil
+	}
+	globalPoolOnce.Do(func() {
+		manager, err := NewWorkerManagerFromConfig(cfg)
+		if err != nil {
+			slog.Default().Error("failed to initialize LS worker manager", "err", err)
+			return
+		}
+		globalBackend = manager
+	})
+	return globalBackend
+}
+
+// Shutdown closes the global backend if one was created. NOTE(review): globalBackend is read here outside the Once that guards its write; confirm callers serialize this with GlobalPool.
+func Shutdown() {
+	if globalBackend != nil {
+		globalBackend.Close()
+	}
+}
+
+// StatusInfo returns LS diagnostics: the mode flag, fixed build and user_agent labels, and pool totals when a backend exists.
+func StatusInfo() map[string]any {
+	info := map[string]any{
+		"ls_mode_enabled": lsModeEnabled,
+		"build":           "enhanced",
+		"user_agent":      "antigravity/1.107.0",
+	}
+	if lsModeEnabled && globalBackend != nil {
+		stats := globalBackend.Stats()
+		info["pool_total"] = stats["total"]
+		info["pool_active"] = stats["active"]
+	}
+	return info
+}
diff --git a/backend/internal/pkg/lspool/integration_test.go b/backend/internal/pkg/lspool/integration_test.go
new file mode 100644
index 00000000..4281d66c
--- /dev/null
+++ b/backend/internal/pkg/lspool/integration_test.go
@@ -0,0 +1,805 @@
+package lspool
+
+import (
+	"bytes"
+	"context"
+	"encoding/binary"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"net/http/httptest"
+	"strings"
+	"sync"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/require"
+)
+
+// readConnectFrame reads one framed message from r: a 5-byte header whose last
+// four bytes are the big-endian payload length (the first byte is not
+// inspected), followed by the payload, which is returned without the header.
+func readConnectFrame(r io.Reader) ([]byte, error) {
+	header := make([]byte, 5)
+	if _, err := io.ReadFull(r, header); err != nil {
+		return nil, err
+	}
+	payloadLen := binary.BigEndian.Uint32(header[1:5])
+	payload := make([]byte, payloadLen)
+	if _, err := io.ReadFull(r, payload); err != nil {
+		return nil, err
+	}
+	return payload, nil
+}
+
+// decodeProtoBytesField scans a protobuf-encoded message and returns the
+// payload of the first length-delimited (wire type 2) field numbered
+// targetField. Varint, fixed64 and fixed32 fields are skipped. It returns nil
+// when the field is absent or the input is truncated or malformed.
+func decodeProtoBytesField(data []byte, targetField int) []byte {
+	i := 0
+	for i < len(data) {
+		tag, n := binary.Uvarint(data[i:])
+		if n <= 0 {
+			return nil
+		}
+		i += n
+		fieldNum := int(tag >> 3)
+		wireType := tag & 0x7
+		switch wireType {
+		case 0:
+			_, n = binary.Uvarint(data[i:])
+			if n <= 0 {
+				return nil
+			}
+			i += n
+		case 2:
+			length, n := binary.Uvarint(data[i:])
+			if n <= 0 {
+				return nil
+			}
+			i += n
+			// Compare in uint64 space. Converting an oversized length to int
+			// first can wrap negative, pass the bounds check, and then panic
+			// in the slice expression below.
+			if length > uint64(len(data)-i) {
+				return nil
+			}
+			end := i + int(length)
+			if fieldNum == targetField {
+				return data[i:end]
+			}
+			i = end
+		case 1:
+			i += 8
+		case 5:
+			i += 4
+		default:
+			return nil
+		}
+	}
+	return nil
+}
+
+// TestMockExtensionServerTokenInjection verifies the token injection flow: +// Extension → MockExtensionServer → LS subscribes uss-oauth → gets OAuthTokenInfo +func TestMockExtensionServerTokenInjection(t *testing.T) { + csrf := "test-csrf-token" + srv, err := NewMockExtensionServer(csrf) + require.NoError(t, err) + defer srv.Close() + + // 1. Set token for an account + srv.SetToken("account-1", &TokenInfo{ + AccessToken: "ya29.test-access-token", + RefreshToken: "1//test-refresh-token", + ExpiresAt: time.Now().Add(1 * time.Hour), + }) + + // 2. Verify token is stored + srv.mu.RLock() + info, ok := srv.tokens["account-1"] + srv.mu.RUnlock() + require.True(t, ok) + require.Equal(t, "ya29.test-access-token", info.AccessToken) + require.Equal(t, "1//test-refresh-token", info.RefreshToken) + require.False(t, info.ExpiresAt.IsZero()) + + // 3.
Simulate LS subscribing to uss-oauth (HTTP request to mock server) + req, _ := http.NewRequest("POST", + fmt.Sprintf("http://127.0.0.1:%d/exa.extension_server_pb.ExtensionServerService/SubscribeToUnifiedStateSyncTopic", srv.Port()), + bytes.NewReader(frameConnectMessage(encodeProtoString(1, "uss-oauth")))) + req.Header.Set("x-codeium-csrf-token", csrf) + req.Header.Set("Content-Type", "application/connect+proto") + + // The stream handler will block, so run in background and cancel after we confirm connection + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + req = req.WithContext(ctx) + + client := &http.Client{} + resp, err := client.Do(req) + if err == nil { + defer resp.Body.Close() + require.Equal(t, 200, resp.StatusCode) + require.Equal(t, "application/connect+proto", resp.Header.Get("Content-Type")) + + // Read the first envelope frame (initial state) + header := make([]byte, 5) + n, readErr := resp.Body.Read(header) + if readErr == nil && n == 5 { + require.Equal(t, byte(0x00), header[0], "first byte should be 0x00 (data frame)") + t.Logf("Received initial state frame: flags=%d, payload_len=%d", header[0], header[1:5]) + } + } +} + +// TestMockExtensionServerCSRF verifies CSRF token validation +func TestMockExtensionServerCSRF(t *testing.T) { + csrf := "correct-csrf" + srv, err := NewMockExtensionServer(csrf) + require.NoError(t, err) + defer srv.Close() + + base := fmt.Sprintf("http://127.0.0.1:%d/exa.extension_server_pb.ExtensionServerService/Heartbeat", srv.Port()) + + // Wrong CSRF → 403 + req, _ := http.NewRequest("POST", base, nil) + req.Header.Set("x-codeium-csrf-token", "wrong-csrf") + req.Header.Set("Content-Type", "application/proto") + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + require.Equal(t, 403, resp.StatusCode) + + // Correct CSRF → 200 + req2, _ := http.NewRequest("POST", base, nil) + req2.Header.Set("x-codeium-csrf-token", csrf) + 
req2.Header.Set("Content-Type", "application/proto") + resp2, err := http.DefaultClient.Do(req2) + require.NoError(t, err) + defer resp2.Body.Close() + require.Equal(t, 200, resp2.StatusCode) +} + +// TestMockExtensionServerGetSecretValue verifies the fallback token path +func TestMockExtensionServerGetSecretValue(t *testing.T) { + csrf := "test-csrf" + srv, err := NewMockExtensionServer(csrf) + require.NoError(t, err) + defer srv.Close() + + srv.SetToken("acc", &TokenInfo{AccessToken: "ya29.secret-token"}) + + // GetSecretValue should return the token + req, _ := http.NewRequest("POST", + fmt.Sprintf("http://127.0.0.1:%d/exa.extension_server_pb.ExtensionServerService/GetSecretValue", srv.Port()), + nil) + req.Header.Set("x-codeium-csrf-token", csrf) + req.Header.Set("Content-Type", "application/proto") + + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + require.Equal(t, 200, resp.StatusCode) +} + +// TestOAuthTokenInfoProto verifies the proto encoding matches real IDE format +func TestOAuthTokenInfoProto(t *testing.T) { + expiry := time.Date(2026, 3, 29, 19, 0, 0, 0, time.UTC) + bin := buildOAuthTokenInfoBinary("ya29.test", "1//refresh", expiry) + + // Verify fields are present by checking proto wire format + require.True(t, len(bin) > 0, "proto should not be empty") + + // Field 1 (access_token): tag=0x0a, value="ya29.test" + require.Contains(t, string(bin), "ya29.test") + // Field 2 (token_type): tag=0x12, value="Bearer" + require.Contains(t, string(bin), "Bearer") + // Field 3 (refresh_token): tag=0x1a, value="1//refresh" + require.Contains(t, string(bin), "1//refresh") + + // Without refresh_token + binNoRefresh := buildOAuthTokenInfoBinary("ya29.test", "", expiry) + require.NotContains(t, string(binNoRefresh), "1//refresh") +} + +// TestOAuthTokenInfoWithRealExpiry verifies expiry uses real time, not hardcoded +func TestOAuthTokenInfoWithRealExpiry(t *testing.T) { + future := time.Now().Add(2 * time.Hour) + bin := 
buildOAuthTokenInfoBinary("token", "refresh", future) + + // Zero expiry should default to ~1h + binZero := buildOAuthTokenInfoBinary("token", "refresh", time.Time{}) + + // They should be different lengths or content (different expiry timestamps) + // Both should be valid (non-empty) + require.True(t, len(bin) > 0) + require.True(t, len(binZero) > 0) +} + +// TestUSSTopicWithOAuth verifies the full USS topic proto structure +func TestUSSTopicWithOAuth(t *testing.T) { + expiry := time.Now().Add(1 * time.Hour) + topic := buildUSSTopicWithOAuth("ya29.access", "1//refresh", expiry) + + require.True(t, len(topic) > 0) + // The topic should contain the sentinel key + require.Contains(t, string(topic), "oauthTokenInfoSentinelKey") +} + +func TestUSSTopicWithModelCredits(t *testing.T) { + available := int32(123) + minimum := int32(50) + topic := buildUSSTopicWithModelCredits(&ModelCreditsInfo{ + UseAICredits: true, + AvailableCredits: &available, + MinimumCreditAmountForUsage: &minimum, + }) + + require.True(t, len(topic) > 0) + require.Contains(t, string(topic), useAICreditsSentinelKey) + require.Contains(t, string(topic), availableCreditsSentinelKey) + require.Contains(t, string(topic), minimumCreditAmountForUsageKey) +} + +func TestMockExtensionServerModelCreditsDynamicUpdate(t *testing.T) { + csrf := "test-csrf-token" + srv, err := NewMockExtensionServer(csrf) + require.NoError(t, err) + defer srv.Close() + + srv.SetModelCredits("account-1", &ModelCreditsInfo{}) + + req, _ := http.NewRequest("POST", + fmt.Sprintf("http://127.0.0.1:%d/exa.extension_server_pb.ExtensionServerService/SubscribeToUnifiedStateSyncTopic", srv.Port()), + bytes.NewReader(frameConnectMessage(encodeProtoString(1, "uss-modelCredits")))) + req.Header.Set("x-codeium-csrf-token", csrf) + req.Header.Set("Content-Type", "application/connect+proto") + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + req = req.WithContext(ctx) + + resp, err := 
http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + require.Equal(t, 200, resp.StatusCode) + + // Drain the initial_state frame first. + _, err = readConnectFrame(resp.Body) + require.NoError(t, err) + + available := int32(77) + minimum := int32(25) + srv.SetModelCredits("account-1", &ModelCreditsInfo{ + UseAICredits: true, + AvailableCredits: &available, + MinimumCreditAmountForUsage: &minimum, + }) + + keys := make([]string, 0, 3) + for len(keys) < 3 { + frame, readErr := readConnectFrame(resp.Body) + require.NoError(t, readErr) + applied := decodeProtoBytesField(frame, 2) + require.NotEmpty(t, applied) + keys = append(keys, decodeProtoString(applied, 1)) + } + + require.Contains(t, keys, useAICreditsSentinelKey) + require.Contains(t, keys, availableCreditsSentinelKey) + require.Contains(t, keys, minimumCreditAmountForUsageKey) +} + +// TestBuildInitialStateUpdate verifies the USS update wrapper +func TestBuildInitialStateUpdate(t *testing.T) { + topicData := buildEmptyTopic() + update := buildInitialStateUpdate(topicData) + // Should be a valid proto bytes field (field 1 = initial_state) + require.True(t, len(update) >= 0) // empty topic is valid + + topicData2 := buildUSSTopicWithOAuth("token", "refresh", time.Now().Add(1*time.Hour)) + update2 := buildInitialStateUpdate(topicData2) + require.True(t, len(update2) > len(update), "non-empty topic should produce larger update") +} + +// TestPoolSetAccountTokenComplete verifies pool accepts full credential set +func TestPoolSetAccountTokenComplete(t *testing.T) { + csrf := "pool-csrf" + srv, err := NewMockExtensionServer(csrf) + require.NoError(t, err) + defer srv.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + pool := &Pool{ + config: DefaultConfig(), + instances: make(map[string][]*Instance), + extServer: srv, + ctx: ctx, + cancel: cancel, + } + + expiry := time.Now().Add(1 * time.Hour) + pool.SetAccountToken("acc-1", "ya29.full-token", 
"1//full-refresh", expiry) + + srv.mu.RLock() + info := srv.tokens["acc-1"] + srv.mu.RUnlock() + + require.NotNil(t, info) + require.Equal(t, "ya29.full-token", info.AccessToken) + require.Equal(t, "1//full-refresh", info.RefreshToken) + require.False(t, info.ExpiresAt.IsZero()) + require.WithinDuration(t, expiry, info.ExpiresAt, time.Second) +} + +func TestPoolSetAccountModelCreditsComplete(t *testing.T) { + csrf := "pool-csrf" + srv, err := NewMockExtensionServer(csrf) + require.NoError(t, err) + defer srv.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + pool := &Pool{ + config: DefaultConfig(), + instances: make(map[string][]*Instance), + extServer: srv, + ctx: ctx, + cancel: cancel, + } + + available := int32(77) + minimum := int32(25) + pool.SetAccountModelCredits("acc-1", true, &available, &minimum) + + srv.mu.RLock() + info := srv.credits["acc-1"] + srv.mu.RUnlock() + + require.NotNil(t, info) + require.True(t, info.UseAICredits) + require.NotNil(t, info.AvailableCredits) + require.Equal(t, available, *info.AvailableCredits) + require.NotNil(t, info.MinimumCreditAmountForUsage) + require.Equal(t, minimum, *info.MinimumCreditAmountForUsage) +} + +// TestUpstreamAdapterExtractsCredentials verifies internal LS headers are extracted and stripped. 
+func TestUpstreamAdapterExtractsCredentials(t *testing.T) { + // Create a mock upstream that records what it receives + var receivedHeaders http.Header + var mu sync.Mutex + fallback := &recordingUpstreamWithCallback{} + fallback.onDo = func(req *http.Request) { + mu.Lock() + receivedHeaders = req.Header.Clone() + mu.Unlock() + } + + csrf := "test-csrf" + srv, err := NewMockExtensionServer(csrf) + require.NoError(t, err) + defer srv.Close() + + pool := &Pool{ + config: DefaultConfig(), + instances: make(map[string][]*Instance), + extServer: srv, + } + + upstream := NewLSPoolUpstream(pool, fallback) + + // Non-streamGenerateContent request → should pass through to fallback + req, _ := http.NewRequest("POST", "https://example.com/v1beta/models/gemini:generateContent", nil) + req.Header.Set("Authorization", "Bearer ya29.test") + req.Header.Set("X-Antigravity-Refresh-Token", "1//secret-refresh") + req.Header.Set("X-Antigravity-Token-Expiry", "2026-03-29T19:00:00Z") + req.Header.Set(useAICreditsHeader, "true") + req.Header.Set(availableCreditsHeader, "42") + req.Header.Set(minimumCreditAmountHeader, "50") + + resp, err := upstream.Do(req, "", 1, 1) + require.NoError(t, err) + require.NotNil(t, resp) + + // Internal headers should never leak to the direct upstream. 
+ mu.Lock() + require.Empty(t, receivedHeaders.Get("X-Antigravity-Refresh-Token")) + require.Empty(t, receivedHeaders.Get("X-Antigravity-Token-Expiry")) + require.Empty(t, receivedHeaders.Get(useAICreditsHeader)) + require.Empty(t, receivedHeaders.Get(availableCreditsHeader)) + require.Empty(t, receivedHeaders.Get(minimumCreditAmountHeader)) + mu.Unlock() + + srv.mu.RLock() + tokenInfo := srv.tokens["1"] + creditsInfo := srv.credits["1"] + srv.mu.RUnlock() + + require.NotNil(t, tokenInfo) + require.Equal(t, "ya29.test", tokenInfo.AccessToken) + require.NotNil(t, creditsInfo) + require.True(t, creditsInfo.UseAICredits) + require.NotNil(t, creditsInfo.AvailableCredits) + require.Equal(t, int32(42), *creditsInfo.AvailableCredits) + require.NotNil(t, creditsInfo.MinimumCreditAmountForUsage) + require.Equal(t, int32(50), *creditsInfo.MinimumCreditAmountForUsage) +} + +// TestExtractPromptAndModelMultiTurn verifies multi-turn prompt extraction +func TestExtractPromptAndModelMultiTurn(t *testing.T) { + body := `{ + "model": "claude-sonnet-4-6", + "request": { + "systemInstruction": {"parts": [{"text": "You are helpful"}]}, + "contents": [ + {"role": "user", "parts": [{"text": "Hello"}]}, + {"role": "model", "parts": [{"text": "Hi there!"}]}, + {"role": "user", "parts": [{"text": "How are you?"}]} + ] + } + }` + prompt, model := extractPromptAndModel([]byte(body)) + require.Equal(t, "claude-sonnet-4-6", model) + require.Contains(t, prompt, "You are helpful") + require.Contains(t, prompt, "Hello") + require.Contains(t, prompt, "Hi there!") + require.Contains(t, prompt, "How are you?") +} + +// TestExtractUsageFromTrajectory verifies token usage extraction +func TestExtractUsageFromTrajectory(t *testing.T) { + resp := `{ + "trajectory": { + "steps": [{ + "type": "CORTEX_STEP_TYPE_PLANNER_RESPONSE", + "status": "CORTEX_STEP_STATUS_DONE", + "plannerResponse": {"response": "OK"}, + "metadata": { + "modelUsage": { + "inputTokens": "150", + "outputTokens": "5" + } + } + }] + } + 
}` + usage := extractUsageFromTrajectory([]byte(resp)) + require.NotNil(t, usage) + require.Equal(t, 150, usage["promptTokenCount"]) + require.Equal(t, 5, usage["candidatesTokenCount"]) + require.Equal(t, 155, usage["totalTokenCount"]) +} + +// TestSSEChunkFormat verifies the Gemini SSE output format +func TestSSEChunkFormat(t *testing.T) { + chunk := buildGeminiSSEChunk("Hello world") + require.True(t, len(chunk) > 0) + require.Contains(t, chunk, "data: ") + require.Contains(t, chunk, `"text":"Hello world"`) + require.Contains(t, chunk, `"role":"model"`) + require.True(t, chunk[len(chunk)-2:] == "\n\n") + + // Verify it's valid JSON after stripping "data: " prefix + jsonStr := chunk[len("data: ") : len(chunk)-2] + var parsed map[string]any + err := json.Unmarshal([]byte(jsonStr), &parsed) + require.NoError(t, err) + response := parsed["response"].(map[string]any) + candidates := response["candidates"].([]any) + require.Len(t, candidates, 1) +} + +// TestSSEFinalChunkFormat verifies the final SSE chunk with usage +func TestSSEFinalChunkFormat(t *testing.T) { + usage := map[string]any{ + "promptTokenCount": 100, + "candidatesTokenCount": 50, + "totalTokenCount": 150, + } + chunk := buildGeminiSSEFinalChunk(usage) + require.Contains(t, chunk, "data: ") + require.Contains(t, chunk, `"finishReason":"STOP"`) + require.Contains(t, chunk, `"usageMetadata"`) +} + +func TestStreamCascadeResponsePollsImmediately(t *testing.T) { + var getCalls atomic.Int32 + + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "test-csrf", r.Header.Get("x-codeium-csrf-token")) + + if strings.HasSuffix(r.URL.Path, "/GetCascadeTrajectory") { + getCalls.Add(1) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"status":"CASCADE_RUN_STATUS_IDLE","trajectory":{"steps":[{"type":"CORTEX_STEP_TYPE_PLANNER_RESPONSE","status":"CORTEX_STEP_STATUS_DONE","plannerResponse":{"response":"hello from ls"}}]}}`)) + 
return + } + + http.NotFound(w, r) + })) + defer server.Close() + + inst := &Instance{ + AccountID: "42", + CSRF: "test-csrf", + Address: strings.TrimPrefix(server.URL, "https://"), + client: server.Client(), + healthy: true, + lastUsed: time.Now(), + } + upstream := NewLSPoolUpstream(&Pool{}, &recordingUpstream{}) + + pr, pw := io.Pipe() + ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel() + + done := make(chan struct{}) + go func() { + upstream.streamCascadeResponse(ctx, inst, "cid-1", pw, nil, nil) + close(done) + }() + + body, err := io.ReadAll(pr) + require.NoError(t, err) + <-done + + require.GreaterOrEqual(t, getCalls.Load(), int32(1)) + require.Contains(t, string(body), "hello from ls") +} + +// TestRequestHasToolsEdgeCases verifies tool detection edge cases +func TestRequestHasToolsEdgeCases(t *testing.T) { + // null tools + require.False(t, requestHasTools([]byte(`{"contents":[],"tools":null}`))) + // tools with empty function declarations + require.True(t, requestHasTools([]byte(`{"contents":[],"tools":[{"functionDeclarations":[]}]}`))) + // deeply nested wrapped format + require.True(t, requestHasTools([]byte(`{"model":"m","project":"p","request":{"contents":[],"tools":[{"codeExecution":{}}]}}`))) +} + +func TestJSParityRouteReusesCascadeSession(t *testing.T) { + t.Setenv("ANTIGRAVITY_LS_STRATEGY", LSStrategyJSParity) + + var startCalls atomic.Int32 + var sendCalls atomic.Int32 + var getCalls atomic.Int32 + var sendBodiesMu sync.Mutex + var sendBodies []map[string]any + + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "test-csrf", r.Header.Get("x-codeium-csrf-token")) + + switch { + case strings.HasSuffix(r.URL.Path, "/StartCascade"): + startCalls.Add(1) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"cascadeId":"cid-1"}`)) + case strings.HasSuffix(r.URL.Path, "/SendUserCascadeMessage"): + sendCalls.Add(1) + 
var payload map[string]any + err := json.NewDecoder(r.Body).Decode(&payload) + require.NoError(t, err) + sendBodiesMu.Lock() + sendBodies = append(sendBodies, payload) + sendBodiesMu.Unlock() + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"queued":false}`)) + case strings.HasSuffix(r.URL.Path, "/GetCascadeTrajectory"): + call := getCalls.Add(1) + w.Header().Set("Content-Type", "application/json") + text := "hello from ls" + if call > 1 { + text = "follow up from ls" + } + _, _ = w.Write([]byte(fmt.Sprintf(`{"status":"CASCADE_RUN_STATUS_IDLE","trajectory":{"steps":[{"type":"CORTEX_STEP_TYPE_PLANNER_RESPONSE","status":"CORTEX_STEP_STATUS_DONE","plannerResponse":{"response":"%s"}}]}}`, text))) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + inst := &Instance{ + AccountID: "42", + CSRF: "test-csrf", + Address: strings.TrimPrefix(server.URL, "https://"), + client: server.Client(), + healthy: true, + lastUsed: time.Now(), + } + inst.SetModelMappingReady(true) + pool := &Pool{ + config: Config{ReplicasPerAccount: 1}, + instances: map[string][]*Instance{"42": []*Instance{inst}}, + } + upstream := NewLSPoolUpstream(pool, &recordingUpstream{}) + + req1Body := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"session-a","contents":[{"role":"user","parts":[{"text":"hello"}]}]}}`) + req1, err := http.NewRequest(http.MethodPost, "https://example.com/v1internal:streamGenerateContent?alt=sse", bytes.NewReader(req1Body)) + require.NoError(t, err) + req1.Header.Set("Authorization", "Bearer downstream-a") + + resp1, err := upstream.Do(req1, "", 42, 1) + require.NoError(t, err) + body1, err := io.ReadAll(resp1.Body) + require.NoError(t, err) + require.Contains(t, string(body1), `"text":"hello from ls"`) + + req2Body := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"session-a","contents":[{"role":"user","parts":[{"text":"hello"}]},{"role":"model","parts":[{"text":"hello from 
ls"}]},{"role":"user","parts":[{"text":"follow up"}]}]}}`) + req2, err := http.NewRequest(http.MethodPost, "https://example.com/v1internal:streamGenerateContent?alt=sse", bytes.NewReader(req2Body)) + require.NoError(t, err) + req2.Header.Set("Authorization", "Bearer downstream-a") + + resp2, err := upstream.Do(req2, "", 42, 1) + require.NoError(t, err) + body2, err := io.ReadAll(resp2.Body) + require.NoError(t, err) + require.Contains(t, string(body2), `"text":"follow up from ls"`) + + require.Equal(t, int32(1), startCalls.Load(), "cascade should be reused for append-only transcript") + require.Equal(t, int32(2), sendCalls.Load()) + + sendBodiesMu.Lock() + require.Len(t, sendBodies, 2) + firstSend := sendBodies[0] + sendBodiesMu.Unlock() + + require.Equal(t, "cid-1", firstSend["cascadeId"]) + require.Equal(t, false, firstSend["blocking"]) + metadata, ok := firstSend["metadata"].(map[string]any) + require.True(t, ok) + require.Equal(t, "antigravity", metadata["ideName"]) + require.Equal(t, "1.107.0", metadata["ideVersion"]) + require.NotContains(t, firstSend, "clientType") + require.NotContains(t, firstSend, "messageOrigin") + cascadeConfig, ok := firstSend["cascadeConfig"].(map[string]any) + require.True(t, ok) + plannerConfig, ok := cascadeConfig["plannerConfig"].(map[string]any) + require.True(t, ok) + requestedModel, ok := plannerConfig["requestedModel"].(map[string]any) + require.True(t, ok) + require.NotEmpty(t, requestedModel["model"]) + require.Len(t, plannerConfig, 1) + require.Len(t, cascadeConfig, 1) +} + +func TestJSParityRouteFallsBackOnSystemInstructionDrift(t *testing.T) { + t.Setenv("ANTIGRAVITY_LS_STRATEGY", LSStrategyJSParity) + + var startCalls atomic.Int32 + var sendCalls atomic.Int32 + + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "test-csrf", r.Header.Get("x-codeium-csrf-token")) + + switch { + case strings.HasSuffix(r.URL.Path, "/StartCascade"): + startCalls.Add(1) + 
w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"cascadeId":"cid-1"}`)) + case strings.HasSuffix(r.URL.Path, "/SendUserCascadeMessage"): + sendCalls.Add(1) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"queued":false}`)) + case strings.HasSuffix(r.URL.Path, "/GetCascadeTrajectory"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"status":"CASCADE_RUN_STATUS_IDLE","trajectory":{"steps":[{"type":"CORTEX_STEP_TYPE_PLANNER_RESPONSE","status":"CORTEX_STEP_STATUS_DONE","plannerResponse":{"response":"hello from ls"}}]}}`)) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + inst := &Instance{ + AccountID: "42", + CSRF: "test-csrf", + Address: strings.TrimPrefix(server.URL, "https://"), + client: server.Client(), + healthy: true, + lastUsed: time.Now(), + } + inst.SetModelMappingReady(true) + fallback := &recordingUpstream{} + pool := &Pool{ + config: Config{ReplicasPerAccount: 1}, + instances: map[string][]*Instance{"42": []*Instance{inst}}, + } + upstream := NewLSPoolUpstream(pool, fallback) + + req1Body := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"session-a","systemInstruction":{"parts":[{"text":"You are helpful"}]},"contents":[{"role":"user","parts":[{"text":"hello"}]}]}}`) + req1, err := http.NewRequest(http.MethodPost, "https://example.com/v1internal:streamGenerateContent?alt=sse", bytes.NewReader(req1Body)) + require.NoError(t, err) + req1.Header.Set("Authorization", "Bearer downstream-a") + + resp1, err := upstream.Do(req1, "", 42, 1) + require.NoError(t, err) + body1, err := io.ReadAll(resp1.Body) + require.NoError(t, err) + require.Contains(t, string(body1), `"text":"hello from ls"`) + + req2Body := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"session-a","systemInstruction":{"parts":[{"text":"You are different"}]},"contents":[{"role":"user","parts":[{"text":"hello"}]},{"role":"model","parts":[{"text":"hello from 
ls"}]},{"role":"user","parts":[{"text":"follow up"}]}]}}`) + req2, err := http.NewRequest(http.MethodPost, "https://example.com/v1internal:streamGenerateContent?alt=sse", bytes.NewReader(req2Body)) + require.NoError(t, err) + req2.Header.Set("Authorization", "Bearer downstream-a") + + resp2, err := upstream.Do(req2, "", 42, 1) + require.NoError(t, err) + body2, err := io.ReadAll(resp2.Body) + require.NoError(t, err) + + require.Equal(t, "ok", string(body2)) + require.Equal(t, 1, fallback.doCalls) + require.Equal(t, int32(1), startCalls.Load()) + require.Equal(t, int32(1), sendCalls.Load()) +} + +func TestJSParityRouteErrorsWhenModelMappingPending(t *testing.T) { + t.Setenv("ANTIGRAVITY_LS_STRATEGY", LSStrategyJSParity) + + var startCalls atomic.Int32 + var sendCalls atomic.Int32 + + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.HasSuffix(r.URL.Path, "/StartCascade"): + startCalls.Add(1) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"cascadeId":"cid-1"}`)) + case strings.HasSuffix(r.URL.Path, "/SendUserCascadeMessage"): + sendCalls.Add(1) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"queued":false}`)) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + inst := &Instance{ + AccountID: "42", + CSRF: "test-csrf", + Address: strings.TrimPrefix(server.URL, "https://"), + client: server.Client(), + healthy: true, + lastUsed: time.Now(), + } + + fallback := &recordingUpstream{} + pool := &Pool{ + config: Config{ReplicasPerAccount: 1}, + instances: map[string][]*Instance{"42": []*Instance{inst}}, + } + upstream := NewLSPoolUpstream(pool, fallback) + + reqBody := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"session-a","contents":[{"role":"user","parts":[{"text":"hello"}]}]}}`) + req, err := http.NewRequest(http.MethodPost, "https://example.com/v1internal:streamGenerateContent?alt=sse", 
bytes.NewReader(reqBody)) + require.NoError(t, err) + req.Header.Set("Authorization", "Bearer downstream-a") + + resp, err := upstream.Do(req, "", 42, 1) + require.Nil(t, resp) + require.ErrorIs(t, err, errLSModelMapPending) + require.Equal(t, int32(0), startCalls.Load()) + require.Equal(t, int32(0), sendCalls.Load()) + require.Equal(t, 0, fallback.doCalls) +} + +// recordingUpstreamWithCallback extends the base recordingUpstream with a callback +type recordingUpstreamWithCallback struct { + recordingUpstream + onDo func(req *http.Request) +} + +func (r *recordingUpstreamWithCallback) Do(req *http.Request, proxyURL string, accountID int64, c int) (*http.Response, error) { + if r.onDo != nil { + r.onDo(req) + } + return r.recordingUpstream.Do(req, proxyURL, accountID, c) +} diff --git a/backend/internal/pkg/lspool/mock_extension_server.go b/backend/internal/pkg/lspool/mock_extension_server.go new file mode 100644 index 00000000..b5f3702a --- /dev/null +++ b/backend/internal/pkg/lspool/mock_extension_server.go @@ -0,0 +1,908 @@ +// Package lspool provides a mock Extension Server that the LS binary connects +// to at startup. The real IDE's extension.js runs a ConnectRPC HTTP/1.1 server +// using connectNodeAdapter. We replicate that protocol here. +// +// Protocol details (from extension.js source): +// - Transport: HTTP/1.1 on 127.0.0.1 (no TLS) +// - Auth: x-codeium-csrf-token header on every request +// - Unary request Content-Type: application/proto (binary protobuf, no envelope) +// OR application/connect+proto (with 5-byte envelope) +// - Unary response Content-Type: application/proto (raw binary protobuf, no envelope) +// - Stream request Content-Type: application/connect+proto (with 5-byte envelope) +// - Stream response Content-Type: application/connect+proto (envelope-framed messages) +// +// The LS sends requests with content-type "application/connect+proto" for BOTH +// unary and streaming RPCs. 
// encodeProtoString encodes val as a protobuf length-delimited field
// (wire type 2) with the given field number and returns the encoded bytes:
// varint(fieldNum<<3|2), varint(len(val)), then the raw bytes of val.
func encodeProtoString(fieldNum int, val string) []byte {
	// Reserve room for two worst-case varints plus the payload so the whole
	// field is built with a single allocation.
	out := make([]byte, 0, 2*binary.MaxVarintLen64+len(val))
	out = binary.AppendUvarint(out, uint64(fieldNum<<3|2))
	out = binary.AppendUvarint(out, uint64(len(val)))
	return append(out, val...)
}
// decodeProtoString scans raw protobuf bytes and returns the payload of the
// first length-delimited (wire type 2) field whose number equals targetField,
// converted to a string.
//
// Varint, 64-bit and 32-bit fields are skipped. The function returns "" when
// the field is not present, when the data is truncated or malformed, or when
// an unsupported wire type is encountered.
func decodeProtoString(data []byte, targetField int) string {
	i := 0
	for i < len(data) {
		tag, n := binary.Uvarint(data[i:])
		if n <= 0 {
			break
		}
		i += n
		fieldNum := int(tag >> 3)

		switch tag & 0x7 {
		case 0: // varint
			_, n = binary.Uvarint(data[i:])
			if n <= 0 {
				return ""
			}
			i += n
		case 2: // length-delimited
			length, n := binary.Uvarint(data[i:])
			if n <= 0 {
				return ""
			}
			i += n
			// Compare in uint64 before converting: a declared length above
			// MaxInt64 would otherwise become a negative int and either
			// produce an invalid slice expression or move the cursor
			// backwards.
			if length > uint64(len(data)-i) {
				return ""
			}
			end := i + int(length)
			if fieldNum == targetField {
				return string(data[i:end])
			}
			i = end
		case 1: // 64-bit
			i += 8
		case 5: // 32-bit
			i += 4
		default:
			return ""
		}
	}
	return ""
}
// unwrapConnectEnvelope strips the 5-byte envelope header (1 flag byte plus a
// 4-byte big-endian payload length) from a ConnectRPC message and returns the
// payload.
//
// The input is returned unchanged when it is shorter than the header, when the
// flag byte is greater than 0x02 (so it is treated as not envelope-framed), or
// when the declared payload length exceeds the bytes that follow the header.
// Bytes after the declared payload are dropped.
func unwrapConnectEnvelope(body []byte) []byte {
	const headerLen = 5
	if len(body) < headerLen {
		return body
	}
	// Accept flag values 0x00, 0x01 and 0x02; anything larger is treated as
	// raw, unframed data.
	if body[0] > 0x02 {
		return body
	}
	// Do the bounds check in uint64 so it cannot wrap, regardless of the
	// platform's int size.
	plen := uint64(binary.BigEndian.Uint32(body[1:headerLen]))
	if plen > uint64(len(body)-headerLen) {
		return body // Length mismatch, return raw
	}
	return body[headerLen : headerLen+int(plen)]
}
+// +// message Topic { map data = 1; } +// message Row { string value = 1; int64 e_tag = 2; } +// +// The key in the map is "oauthTokenInfoSentinelKey" and the Row.value is +// base64(toBinary(OAuthTokenInfo)). +func buildUSSTopicWithOAuth(accessToken, refreshToken string, expiresAt time.Time) []byte { + tokenBin := buildOAuthTokenInfoBinary(accessToken, refreshToken, expiresAt) + tokenB64 := base64.StdEncoding.EncodeToString(tokenBin) + + // Row: value=tokenB64 (field 1), e_tag=1 (field 2) + var row []byte + row = append(row, encodeProtoString(1, tokenB64)...) + row = append(row, encodeProtoVarint(2, 1)...) + + // Map entry: key="oauthTokenInfoSentinelKey" (field 1), value=row (field 2) + var entry []byte + entry = append(entry, encodeProtoString(1, "oauthTokenInfoSentinelKey")...) + entry = append(entry, encodeProtoBytes(2, row)...) + + // Topic: data map entries use field 1 + var topic []byte + topic = append(topic, encodeProtoBytes(1, entry)...) + + return topic +} + +func buildPrimitiveBoolBinary(val bool) []byte { + // Primitive.bool_value is field 13 in the proto definition + return encodeProtoBool(13, val) +} + +func buildPrimitiveInt32Binary(val int32) []byte { + // Primitive.int32_value is field 3 in the proto definition + return encodeProtoVarint(3, uint64(uint32(val))) +} + +func buildUSSTopicRow(key string, value string) []byte { + row := buildUSSRowBinary(value) + + var entry []byte + entry = append(entry, encodeProtoString(1, key)...) + entry = append(entry, encodeProtoBytes(2, row)...) + return entry +} + +func buildUSSRowBinary(value string) []byte { + var row []byte + row = append(row, encodeProtoString(1, value)...) + row = append(row, encodeProtoVarint(2, 1)...) 
+ return row +} + +func buildUSSTopicWithModelCredits(info *ModelCreditsInfo) []byte { + if info == nil { + info = &ModelCreditsInfo{} + } + + minimum := defaultMinimumCreditAmountForUsage + if info.MinimumCreditAmountForUsage != nil { + minimum = *info.MinimumCreditAmountForUsage + } + + entries := make([][]byte, 0, 3) + entries = append(entries, buildUSSTopicRow( + useAICreditsSentinelKey, + string(buildPrimitiveBoolBinary(info.UseAICredits)), + )) + // JS protocol: useAICreditsSentinelKey carries the toggle state. + // availableCreditsSentinelKey is only present when credits are enabled. + if info.UseAICredits { + credits := int32(9999) + if info.AvailableCredits != nil { + credits = *info.AvailableCredits + } + entries = append(entries, buildUSSTopicRow(availableCreditsSentinelKey, string(buildPrimitiveInt32Binary(credits)))) + } + entries = append(entries, buildUSSTopicRow(minimumCreditAmountForUsageKey, string(buildPrimitiveInt32Binary(minimum)))) + + var topic []byte + for _, entry := range entries { + topic = append(topic, encodeProtoBytes(1, entry)...) + } + return topic +} + +// buildEmptyTopic returns an empty USS Topic proto (for non-oauth topics). +func buildEmptyTopic() []byte { + return []byte{} // Empty message = no map entries +} + +// ============================================================ +// UnifiedStateSyncUpdate builder +// ============================================================ + +// buildInitialStateUpdate creates a UnifiedStateSyncUpdate with initial_state set. +// +// message UnifiedStateSyncUpdate { +// oneof update_type { +// Topic initial_state = 1; +// AppliedUpdate applied_update = 2; +// } +// } +func buildInitialStateUpdate(topicData []byte) []byte { + return encodeProtoBytes(1, topicData) +} + +func buildAppliedUpdate(key string, row []byte) []byte { + var applied []byte + applied = append(applied, encodeProtoString(1, key)...) + if len(row) > 0 { + applied = append(applied, encodeProtoBytes(2, row)...) 
// MockExtensionServer provides a ConnectRPC-compatible HTTP server that the
// Language Server binary connects to. It implements just enough of the
// ExtensionServerService to keep the LS operational.
//
// The fields tokens, credits, subscribers, nextSubID and lastAccountID are
// accessed under mu by the methods in this file.
type MockExtensionServer struct {
	listener net.Listener // TCP listener the HTTP server is served on
	server   *http.Server // HTTP server hosting the RPC mux
	port     int          // port taken from listener's address at construction
	csrf     string       // expected x-codeium-csrf-token; empty disables the check
	mu       sync.RWMutex // guards the maps and counters below

	tokens  map[string]*TokenInfo        // account_id -> token info
	credits map[string]*ModelCreditsInfo // account_id -> model credits info
	// subscribers maps topic name -> subscriber id -> active subscription.
	subscribers map[string]map[int]*stateSubscriber
	// nextSubID is the id handed to the next subscriber; incremented per subscription.
	nextSubID int
	// lastAccountID is the account most recently passed to SetToken or SetModelCredits.
	lastAccountID string
	logger        *slog.Logger

	// Trajectory callback — when LS pushes trajectory updates, we forward them
	// NOTE(review): this field is written by SetTrajectoryCallback and read by
	// onPushUnifiedStateSyncUpdate without holding mu — confirm it is only set
	// before the server starts receiving requests.
	onTrajectoryUpdate func(topic, key string, data []byte)
}
+func NewMockExtensionServer(csrf string) (*MockExtensionServer, error) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return nil, fmt.Errorf("listen: %w", err) + } + + m := &MockExtensionServer{ + listener: listener, + port: listener.Addr().(*net.TCPAddr).Port, + csrf: csrf, + tokens: make(map[string]*TokenInfo), + credits: make(map[string]*ModelCreditsInfo), + subscribers: make(map[string]map[int]*stateSubscriber), + logger: slog.Default().With("component", "mock-ext-server"), + } + + mux := http.NewServeMux() + extService := "/exa.extension_server_pb.ExtensionServerService/" + + // Register all RPCs the LS calls on the Extension Server. + // Unary RPCs — return application/proto + mux.HandleFunc(extService+"LanguageServerStarted", m.handleUnary(m.onLanguageServerStarted)) + mux.HandleFunc(extService+"Heartbeat", m.handleUnary(m.onHeartbeat)) + mux.HandleFunc(extService+"GetSecretValue", m.handleUnary(m.onGetSecretValue)) + mux.HandleFunc(extService+"StoreSecretValue", m.handleUnary(m.onStoreSecretValue)) + mux.HandleFunc(extService+"IsAgentManagerEnabled", m.handleUnary(m.onIsAgentManagerEnabled)) + mux.HandleFunc(extService+"PushUnifiedStateSyncUpdate", m.handleUnary(m.onPushUnifiedStateSyncUpdate)) + mux.HandleFunc(extService+"RecordError", m.handleUnary(m.onRecordError)) + mux.HandleFunc(extService+"LogEvent", m.handleUnary(m.onLogEvent)) + mux.HandleFunc(extService+"UpdateCascadeTrajectorySummaries", m.handleUnary(m.onUpdateTrajectorySummaries)) + mux.HandleFunc(extService+"BroadcastConversationDeletion", m.handleUnary(m.onDefault)) + mux.HandleFunc(extService+"WriteCascadeEdit", m.handleUnary(m.onDefault)) + mux.HandleFunc(extService+"OpenDiffZones", m.handleUnary(m.onDefault)) + mux.HandleFunc(extService+"HandleAsyncPostMessage", m.handleUnary(m.onDefault)) + mux.HandleFunc(extService+"OpenFilePointer", m.handleUnary(m.onDefault)) + mux.HandleFunc(extService+"OpenVirtualFile", m.handleUnary(m.onDefault)) + 
mux.HandleFunc(extService+"SaveDocument", m.handleUnary(m.onDefault)) + mux.HandleFunc(extService+"RestartUserStatusUpdater", m.handleUnary(m.onDefault)) + mux.HandleFunc(extService+"FocusIDEWindow", m.handleUnary(m.onDefault)) + mux.HandleFunc(extService+"SmartFocusConversation", m.handleUnary(m.onDefault)) + mux.HandleFunc(extService+"RunExtensionCode", m.handleUnary(m.onDefault)) + mux.HandleFunc(extService+"UpdateDetailedViewWithCascadeInput", m.handleUnary(m.onDefault)) + mux.HandleFunc(extService+"FindAllReferences", m.handleUnary(m.onDefault)) + mux.HandleFunc(extService+"GetDefinition", m.handleUnary(m.onDefault)) + mux.HandleFunc(extService+"GetLintErrors", m.handleUnary(m.onDefault)) + + // Server-streaming RPCs — return application/connect+proto + mux.HandleFunc(extService+"SubscribeToUnifiedStateSyncTopic", m.handleStream(m.onSubscribeStateSyncTopic)) + mux.HandleFunc(extService+"ExecuteCommand", m.handleStream(m.onExecuteCommand)) + + // Catch-all for any unregistered RPCs + mux.HandleFunc("/", m.handleCatchAll) + + m.server = &http.Server{Handler: mux} + + go func() { + if err := m.server.Serve(listener); err != http.ErrServerClosed { + m.logger.Error("extension server error", "err", err) + } + }() + + m.logger.Info("mock extension server started", "port", m.port, "csrf_len", len(csrf)) + return m, nil +} + +// Port returns the listening port. +func (m *MockExtensionServer) Port() int { + return m.port +} + +// SetToken sets the OAuth token for an account. 
+func (m *MockExtensionServer) SetToken(accountID string, info *TokenInfo) { + m.mu.Lock() + m.tokens[accountID] = info + m.lastAccountID = accountID + subscribers := m.snapshotSubscribersLocked("uss-oauth", accountID) + m.mu.Unlock() + + if info == nil { + return + } + tokenBin := buildOAuthTokenInfoBinary(info.AccessToken, info.RefreshToken, info.ExpiresAt) + tokenB64 := base64.StdEncoding.EncodeToString(tokenBin) + m.publishTopicUpdate(subscribers, buildAppliedUpdate("oauthTokenInfoSentinelKey", buildUSSRowBinary(tokenB64))) +} + +// SetModelCredits sets the uss-modelCredits state for an account. +func (m *MockExtensionServer) SetModelCredits(accountID string, info *ModelCreditsInfo) { + if info == nil { + info = &ModelCreditsInfo{} + } + copyInfo := *info + m.mu.Lock() + m.credits[accountID] = ©Info + m.lastAccountID = accountID + subscribers := m.snapshotSubscribersLocked("uss-modelCredits", accountID) + m.mu.Unlock() + + m.publishTopicUpdate(subscribers, buildModelCreditsAppliedUpdates(©Info)...) +} + +// SetTrajectoryCallback registers a callback for when the LS pushes trajectory data. 
+func (m *MockExtensionServer) SetTrajectoryCallback(fn func(topic, key string, data []byte)) { + m.onTrajectoryUpdate = fn +} + +func (m *MockExtensionServer) currentTokenLocked() *TokenInfo { + if m.lastAccountID != "" { + if info := m.tokens[m.lastAccountID]; info != nil { + return info + } + } + for _, info := range m.tokens { + return info + } + return nil +} + +func (m *MockExtensionServer) currentModelCreditsLocked() *ModelCreditsInfo { + if m.lastAccountID != "" { + if info := m.credits[m.lastAccountID]; info != nil { + return info + } + } + for _, info := range m.credits { + return info + } + return nil +} + +func (m *MockExtensionServer) tokenForAccountLocked(accountID string) *TokenInfo { + if accountID != "" { + if info := m.tokens[accountID]; info != nil { + return info + } + } + return m.currentTokenLocked() +} + +func (m *MockExtensionServer) creditsForAccountLocked(accountID string) *ModelCreditsInfo { + if accountID != "" { + if info := m.credits[accountID]; info != nil { + return info + } + } + return m.currentModelCreditsLocked() +} + +func (m *MockExtensionServer) snapshotSubscribersLocked(topic, accountID string) []*stateSubscriber { + topicSubs := m.subscribers[topic] + if len(topicSubs) == 0 { + return nil + } + out := make([]*stateSubscriber, 0, len(topicSubs)) + for _, sub := range topicSubs { + if sub == nil { + continue + } + if accountID != "" && sub.accountID != "" && sub.accountID != accountID { + continue + } + out = append(out, sub) + } + return out +} + +func (m *MockExtensionServer) publishTopicUpdate(subscribers []*stateSubscriber, updates ...[]byte) { + for _, sub := range subscribers { + if sub == nil { + continue + } + for _, update := range updates { + if len(update) == 0 { + continue + } + payload := append([]byte(nil), update...) 
+ select { + case sub.updates <- payload: + default: + m.logger.Warn("dropping USS update", "topic", sub.topic, "account", sub.accountID) + } + } + } +} + +func buildModelCreditsAppliedUpdates(info *ModelCreditsInfo) [][]byte { + if info == nil { + info = &ModelCreditsInfo{} + } + minimum := defaultMinimumCreditAmountForUsage + if info.MinimumCreditAmountForUsage != nil { + minimum = *info.MinimumCreditAmountForUsage + } + + updates := make([][]byte, 0, 3) + updates = append(updates, buildAppliedUpdate( + useAICreditsSentinelKey, + buildUSSRowBinary(string(buildPrimitiveBoolBinary(info.UseAICredits))), + )) + + if info.UseAICredits { + credits := int32(9999) + if info.AvailableCredits != nil { + credits = *info.AvailableCredits + } + updates = append(updates, buildAppliedUpdate( + availableCreditsSentinelKey, + buildUSSRowBinary(string(buildPrimitiveInt32Binary(credits))), + )) + } else { + updates = append(updates, buildAppliedUpdate(availableCreditsSentinelKey, nil)) + } + updates = append(updates, buildAppliedUpdate( + minimumCreditAmountForUsageKey, + buildUSSRowBinary(string(buildPrimitiveInt32Binary(minimum))), + )) + + return updates +} + +// Close shuts down the server. +func (m *MockExtensionServer) Close() error { + return m.server.Close() +} + +// ============================================================ +// Middleware +// ============================================================ + +type unaryHandler func(body []byte) []byte +type streamHandler func(body []byte, w http.ResponseWriter, r *http.Request) + +// handleUnary wraps a unary RPC handler with CSRF check and proper content-type. 
+func (m *MockExtensionServer) handleUnary(handler unaryHandler) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + // CSRF check + if !m.checkCSRF(w, r) { + return + } + + body, err := io.ReadAll(r.Body) + if err != nil { + m.logger.Error("read body", "err", err, "path", r.URL.Path) + w.Header().Set("Content-Type", "application/proto") + w.WriteHeader(200) + return + } + + // The LS might send with envelope framing (application/connect+proto) + // or without (application/proto). Detect and unwrap. + ct := r.Header.Get("Content-Type") + protoBody := body + if strings.Contains(ct, "connect+proto") && len(body) >= 5 { + protoBody = unwrapConnectEnvelope(body) + } + + m.logger.Debug("unary RPC", "path", r.URL.Path, "body_len", len(protoBody), "content_type", ct) + + responseProto := handler(protoBody) + + // Respond with proper unary ConnectRPC content-type. + // If the request used "connect+proto", the response should be "application/proto" + // for unary RPCs (ConnectRPC spec: unary uses application/proto, not connect+proto). + w.Header().Set("Content-Type", "application/proto") + w.WriteHeader(200) + if len(responseProto) > 0 { + w.Write(responseProto) + } + } +} + +// handleStream wraps a server-streaming RPC handler with CSRF and content-type. 
+func (m *MockExtensionServer) handleStream(handler streamHandler) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if !m.checkCSRF(w, r) { + return + } + + body, err := io.ReadAll(r.Body) + if err != nil { + m.logger.Error("read body", "err", err, "path", r.URL.Path) + return + } + + // Unwrap envelope from request + ct := r.Header.Get("Content-Type") + if strings.Contains(ct, "connect+proto") || strings.Contains(ct, "connect+json") { + body = unwrapConnectEnvelope(body) + } + + m.logger.Debug("stream RPC", "path", r.URL.Path, "body_len", len(body)) + + // Set streaming response content-type + w.Header().Set("Content-Type", "application/connect+proto") + w.WriteHeader(200) + + handler(body, w, r) + } +} + +func (m *MockExtensionServer) checkCSRF(w http.ResponseWriter, r *http.Request) bool { + token := r.Header.Get("x-codeium-csrf-token") + if m.csrf != "" && token != m.csrf { + m.logger.Warn("CSRF mismatch", "path", r.URL.Path, "got", token[:min(8, len(token))]) + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(403) + w.Write([]byte("Invalid CSRF token")) + return false + } + return true +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} + +// ============================================================ +// Unary RPC Handlers — each receives raw proto request body, +// returns raw proto response body. +// ============================================================ + +func (m *MockExtensionServer) onLanguageServerStarted(body []byte) []byte { + // LanguageServerStartedRequest has: https_port(1), http_port(2), lsp_port(3), csrf_token(4) + // We just log the ports — they're informational. 
+ m.logger.Info("LanguageServerStarted", + "body_len", len(body)) + // Return empty LanguageServerStartedResponse + return nil +} + +func (m *MockExtensionServer) onHeartbeat(body []byte) []byte { + // Return empty HeartbeatResponse + return nil +} + +func (m *MockExtensionServer) onGetSecretValue(body []byte) []byte { + // GetSecretValueRequest: key = field 1 + key := decodeProtoString(body, 1) + m.logger.Debug("GetSecretValue", "key", key) + + m.mu.RLock() + var token string + if info := m.currentTokenLocked(); info != nil { + token = info.AccessToken + } + m.mu.RUnlock() + + // GetSecretValueResponse: value = field 1 + if token != "" { + return encodeProtoString(1, token) + } + return nil +} + +func (m *MockExtensionServer) onStoreSecretValue(body []byte) []byte { + key := decodeProtoString(body, 1) + m.logger.Debug("StoreSecretValue", "key", key) + return nil +} + +func (m *MockExtensionServer) onIsAgentManagerEnabled(body []byte) []byte { + // IsAgentManagerEnabledResponse: enabled = field 1 (bool) + return encodeProtoBool(1, false) +} + +func (m *MockExtensionServer) onPushUnifiedStateSyncUpdate(body []byte) []byte { + // PushUnifiedStateSyncUpdateRequest: update = field 1 (UpdateRequest message) + // UpdateRequest: topic_name = field 1, applied_update = field 5, key = field 2 + m.logger.Debug("PushUnifiedStateSyncUpdate", "body_len", len(body)) + + // Extract topic name from the embedded UpdateRequest + // The body is PushUnifiedStateSyncUpdateRequest, field 1 is the UpdateRequest + // We need to dig into the nested message to get topic_name + if m.onTrajectoryUpdate != nil { + // For now, just notify that an update was pushed + m.onTrajectoryUpdate("", "", body) + } + + // Return empty PushUnifiedStateSyncUpdateResponse + return nil +} + +func (m *MockExtensionServer) onRecordError(body []byte) []byte { + m.logger.Debug("RecordError", "body_len", len(body)) + return nil +} + +func (m *MockExtensionServer) onLogEvent(body []byte) []byte { + return nil +} + 
+func (m *MockExtensionServer) onUpdateTrajectorySummaries(body []byte) []byte { + m.logger.Debug("UpdateCascadeTrajectorySummaries", "body_len", len(body)) + return nil +} + +func (m *MockExtensionServer) onDefault(body []byte) []byte { + return nil +} + +// ============================================================ +// Streaming RPC Handlers +// ============================================================ + +func (m *MockExtensionServer) onSubscribeStateSyncTopic(body []byte, w http.ResponseWriter, r *http.Request) { + // SubscribeToUnifiedStateSyncTopicRequest: topic = field 1 + topic := decodeProtoString(body, 1) + m.logger.Info("SubscribeToUnifiedStateSyncTopic", "topic", topic) + + flusher, ok := w.(http.Flusher) + if !ok { + m.logger.Error("ResponseWriter does not support Flush") + return + } + + m.mu.Lock() + accountID := m.lastAccountID + subID := m.nextSubID + m.nextSubID++ + sub := &stateSubscriber{ + id: subID, + accountID: accountID, + topic: topic, + updates: make(chan []byte, 16), + } + if m.subscribers[topic] == nil { + m.subscribers[topic] = make(map[int]*stateSubscriber) + } + m.subscribers[topic][subID] = sub + + // Build initial state based on topic + var topicData []byte + switch topic { + case "uss-oauth": + tokenInfo := m.tokenForAccountLocked(accountID) + if tokenInfo != nil { + topicData = buildUSSTopicWithOAuth(tokenInfo.AccessToken, tokenInfo.RefreshToken, tokenInfo.ExpiresAt) + } else { + topicData = buildEmptyTopic() + } + case "uss-modelCredits": + creditsInfo := m.creditsForAccountLocked(accountID) + if creditsInfo != nil { + topicData = buildUSSTopicWithModelCredits(creditsInfo) + } else { + topicData = buildEmptyTopic() + } + default: + // For all other topics (browserPreferences, enterprisePreferences, etc.), + // return empty topic data. 
+ topicData = buildEmptyTopic() + } + m.mu.Unlock() + defer func() { + m.mu.Lock() + if topicSubs := m.subscribers[topic]; topicSubs != nil { + delete(topicSubs, subID) + if len(topicSubs) == 0 { + delete(m.subscribers, topic) + } + } + m.mu.Unlock() + }() + + // Send initial state as envelope-framed message + initialUpdate := buildInitialStateUpdate(topicData) + frame := connectEnvelope(0x00, initialUpdate) + w.Write(frame) + flusher.Flush() + + for { + select { + case <-r.Context().Done(): + m.logger.Debug("SubscribeToUnifiedStateSyncTopic stream closed", "topic", topic) + return + case update := <-sub.updates: + if len(update) == 0 { + continue + } + if _, err := w.Write(connectEnvelope(0x00, update)); err != nil { + m.logger.Debug("SubscribeToUnifiedStateSyncTopic write failed", "topic", topic, "err", err) + return + } + flusher.Flush() + } + } +} + +func (m *MockExtensionServer) onExecuteCommand(body []byte, w http.ResponseWriter, r *http.Request) { + m.logger.Debug("ExecuteCommand (mock)", "body_len", len(body)) + // Send end-of-stream immediately — we don't execute commands + flusher, ok := w.(http.Flusher) + if !ok { + return + } + w.Write(connectEndOfStream()) + flusher.Flush() +} + +// ============================================================ +// Catch-all handler +// ============================================================ + +func (m *MockExtensionServer) handleCatchAll(w http.ResponseWriter, r *http.Request) { + if !m.checkCSRF(w, r) { + return + } + m.logger.Debug("unhandled RPC (returning empty proto)", "path", r.URL.Path, "method", r.Method) + + // Drain request body + io.ReadAll(r.Body) + + // Determine if this is likely a unary or streaming request based on content-type. 
+ ct := r.Header.Get("Content-Type") + if strings.Contains(ct, "connect+") { + // Could be streaming — respond with unary proto to be safe + // (unary Connect requests can also use connect+ prefix in some client impls) + w.Header().Set("Content-Type", "application/proto") + } else { + w.Header().Set("Content-Type", "application/proto") + } + w.WriteHeader(200) +} diff --git a/backend/internal/pkg/lspool/pool.go b/backend/internal/pkg/lspool/pool.go new file mode 100644 index 00000000..afbc29c5 --- /dev/null +++ b/backend/internal/pkg/lspool/pool.go @@ -0,0 +1,1132 @@ +// Package lspool manages a pool of AntiGravity Language Server instances. +// +// Each Google account gets its own LS instance. The LS binary is Google's own +// compiled Go binary, so all upstream TLS fingerprints, session behavior, +// and protocol patterns are 100% authentic — indistinguishable from real IDE. +// +// Architecture: +// +// sub2API Gateway → LS Pool → LS Instance (per account) → cloudcode-pa +// +// Communication protocol (from JS source analysis): +// +// sub2API → LS: ConnectRPC over HTTPS/2, binary proto, x-codeium-csrf-token header +// LS → ExtServer: ConnectRPC over HTTP/1.1, binary proto, x-codeium-csrf-token header +// +// Unary calls: Content-Type: application/proto (no envelope framing) +// Stream calls: Content-Type: application/connect+proto (envelope-framed) +// Envelope = 1 byte flags + 4 byte BE length + payload +// flags=0x02 means end-of-stream trailer +package lspool + +import ( + "bufio" + "bytes" + "context" + crand "crypto/rand" + "crypto/sha256" + "crypto/tls" + "crypto/x509" + "encoding/binary" + "encoding/json" + "fmt" + "io" + "log/slog" + "net/http" + "os" + "os/exec" + "path/filepath" + "regexp" + "runtime" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "golang.org/x/net/http2" +) + +// ============================================================ +// Configuration +// ============================================================ + +// Config 
for the LS pool +type Config struct { + // AppRoot is the path to AntiGravity.app resources + // e.g., "/Applications/AntiGravity.app/Contents/Resources/app" + AppRoot string + + // CloudCodeEndpoint overrides the default Cloud Code endpoint + CloudCodeEndpoint string + + // MaxIdleTime before shutting down an idle LS instance + MaxIdleTime time.Duration + + // HealthCheckInterval between health checks + HealthCheckInterval time.Duration + + // ReplicasPerAccount controls how many LS processes a single account can use. + ReplicasPerAccount int +} + +// DefaultConfig returns production defaults +func DefaultConfig() Config { + return Config{ + AppRoot: findAppRoot(), + CloudCodeEndpoint: "https://cloudcode-pa.googleapis.com", + MaxIdleTime: 30 * time.Minute, + HealthCheckInterval: 30 * time.Second, + ReplicasPerAccount: parseLSReplicaCount(), + } +} + +func parseLSReplicaCount() int { + raw := strings.TrimSpace(os.Getenv("ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT")) + if raw == "" { + return 5 + } + val, err := strconv.Atoi(raw) + if err != nil || val < 1 { + return 5 + } + return val +} + +func findAppRoot() string { + candidates := []string{ + "/Applications/AntiGravity.app/Contents/Resources/app", + "/Applications/Antigravity.app/Contents/Resources/app", + filepath.Join(os.Getenv("HOME"), ".local/share/antigravity/app"), + } + for _, c := range candidates { + if _, err := os.Stat(filepath.Join(c, "extensions", "antigravity", "bin")); err == nil { + return c + } + } + return candidates[0] +} + +// ============================================================ +// LS Instance +// ============================================================ + +// maxConcurrencyPerInstance limits how many concurrent Cascade calls a single +// LS instance handles. LS is designed for a single IDE user; beyond this +// threshold requests are rejected so the caller can fall back to direct HTTP. 
const (
	// maxConcurrencyPerInstance is the cap enforced by AcquireConcurrency.
	maxConcurrencyPerInstance = 5
	// lsStartupReadyTimeout bounds the whole readiness wait after launch.
	lsStartupReadyTimeout = 6 * time.Second
	// lsStartupProbeInterval is the pause between readiness heartbeats.
	lsStartupProbeInterval = 100 * time.Millisecond
	// lsStartupHeartbeatTimeout bounds each individual readiness heartbeat.
	lsStartupHeartbeatTimeout = 1 * time.Second
)

// Instance represents a single Language Server process bound to one Google account.
type Instance struct {
	AccountID string
	Email     string
	CSRF      string // sent as the x-codeium-csrf-token header on every RPC
	Replica   int
	Address   string // e.g., "127.0.0.1:52444"
	cmd       *exec.Cmd
	cleanup   func() // optional launch cleanup, run when the instance is stopped
	client    *http.Client
	mu        sync.RWMutex // guards healthy and lastUsed
	healthy   bool
	lastUsed  time.Time
	startedAt time.Time
	inflight  int64 // atomic: current number of concurrent cascade calls
	// modelMapReady is accessed atomically: 1 once the model mapping loaded, else 0.
	modelMapReady int32
	remote        bool // when true, RPC methods are routed through the worker helpers
	// NOTE(review): workerToken and routingKey are presumably used by the
	// remote worker path; their use is not visible here, confirm before relying on it.
	workerToken string
	routingKey  string
}

// AcquireConcurrency atomically increments the inflight counter.
// Returns false if the instance is already at max capacity.
func (i *Instance) AcquireConcurrency() bool {
	for {
		cur := atomic.LoadInt64(&i.inflight)
		if cur >= int64(maxConcurrencyPerInstance) {
			return false
		}
		if atomic.CompareAndSwapInt64(&i.inflight, cur, cur+1) {
			return true
		}
	}
}

// ReleaseConcurrency decrements the inflight counter.
//
// The counter never drops below zero: an unmatched release is ignored rather
// than driving the count negative, which would otherwise let
// AcquireConcurrency admit more than maxConcurrencyPerInstance callers.
func (i *Instance) ReleaseConcurrency() {
	for {
		cur := atomic.LoadInt64(&i.inflight)
		if cur <= 0 {
			return
		}
		if atomic.CompareAndSwapInt64(&i.inflight, cur, cur-1) {
			return
		}
	}
}

// ConcurrentCount returns the current number of in-flight cascade calls.
func (i *Instance) ConcurrentCount() int64 {
	return atomic.LoadInt64(&i.inflight)
}

// SetModelMappingReady records whether this LS instance has successfully loaded
// its model config from the upstream service.
func (i *Instance) SetModelMappingReady(ready bool) {
	var flag int32
	if ready {
		flag = 1
	}
	atomic.StoreInt32(&i.modelMapReady, flag)
}

// HasModelMappingReady reports whether this LS instance has completed model
// config loading successfully.
+func (i *Instance) HasModelMappingReady() bool { + return atomic.LoadInt32(&i.modelMapReady) == 1 +} + +// IsHealthy returns whether the instance is healthy +func (i *Instance) IsHealthy() bool { + i.mu.RLock() + defer i.mu.RUnlock() + return i.healthy +} + +// Touch marks the instance as recently used +func (i *Instance) Touch() { + i.mu.Lock() + i.lastUsed = time.Now() + i.mu.Unlock() +} + +// ============================================================ +// RPC Methods — uses ConnectRPC binary proto +// ============================================================ + +const ( + LSService = "exa.language_server_pb.LanguageServerService" +) + +// CallUnaryJSON makes a ConnectRPC JSON unary call (for convenience/debugging). +func (i *Instance) CallUnaryJSON(ctx context.Context, service, method string, input any) ([]byte, error) { + i.Touch() + + if i.remote { + body, err := marshalWorkerJSONBody(input) + if err != nil { + return nil, fmt.Errorf("marshal input: %w", err) + } + return i.callWorkerUnary(ctx, service, method, "json", body) + } + + url := fmt.Sprintf("https://%s/%s/%s", i.Address, service, method) + + var body []byte + if input != nil { + var err error + body, err = json.Marshal(input) + if err != nil { + return nil, fmt.Errorf("marshal input: %w", err) + } + } else { + body = []byte("{}") + } + + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body)) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Connect-Protocol-Version", "1") + req.Header.Set("x-codeium-csrf-token", i.CSRF) + + resp, err := i.client.Do(req) + if err != nil { + return nil, fmt.Errorf("rpc %s/%s: %w", service, method, err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read response: %w", err) + } + + if resp.StatusCode != 200 { + return respBody, fmt.Errorf("rpc %s/%s HTTP %d: %s", + service, method, resp.StatusCode, 
truncate(string(respBody), 200)) + } + + return respBody, nil +} + +// CallRPC makes a ConnectRPC binary proto unary call to the LS. +// Uses Content-Type: application/proto (Connect protocol unary). +func (i *Instance) CallRPC(ctx context.Context, service, method string, protoBody []byte) ([]byte, error) { + i.Touch() + + if i.remote { + return i.callWorkerUnary(ctx, service, method, "proto", protoBody) + } + + url := fmt.Sprintf("https://%s/%s/%s", i.Address, service, method) + + if protoBody == nil { + protoBody = []byte{} + } + + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(protoBody)) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", "application/proto") + req.Header.Set("Connect-Protocol-Version", "1") + req.Header.Set("x-codeium-csrf-token", i.CSRF) + + resp, err := i.client.Do(req) + if err != nil { + return nil, fmt.Errorf("rpc %s/%s: %w", service, method, err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read response: %w", err) + } + + if resp.StatusCode != 200 { + return respBody, fmt.Errorf("rpc %s/%s HTTP %d: %s", + service, method, resp.StatusCode, truncate(string(respBody), 200)) + } + + return respBody, nil +} + +// StreamRPC makes a server-streaming ConnectRPC call, returning the raw response. +// Uses Content-Type: application/connect+proto with envelope framing. 
+func (i *Instance) StreamRPC(ctx context.Context, service, method string, protoBody []byte) (*http.Response, error) { + i.Touch() + + if i.remote { + return i.callWorkerStream(ctx, service, method, "proto", protoBody) + } + + url := fmt.Sprintf("https://%s/%s/%s", i.Address, service, method) + + if protoBody == nil { + protoBody = []byte{} + } + + // Wrap in Connect envelope for streaming request + framedBody := frameConnectMessage(protoBody) + + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(framedBody)) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", "application/connect+proto") + req.Header.Set("Connect-Protocol-Version", "1") + req.Header.Set("Connect-Content-Encoding", "identity") + req.Header.Set("x-codeium-csrf-token", i.CSRF) + + return i.client.Do(req) +} + +// StreamRPCJSON makes a server-streaming ConnectRPC JSON call (for debugging). +func (i *Instance) StreamRPCJSON(ctx context.Context, service, method string, input any) (*http.Response, error) { + i.Touch() + + if i.remote { + body, err := marshalWorkerJSONBody(input) + if err != nil { + return nil, fmt.Errorf("marshal: %w", err) + } + return i.callWorkerStream(ctx, service, method, "json", body) + } + + url := fmt.Sprintf("https://%s/%s/%s", i.Address, service, method) + + var body []byte + if input != nil { + var err error + body, err = json.Marshal(input) + if err != nil { + return nil, fmt.Errorf("marshal: %w", err) + } + } else { + body = []byte("{}") + } + + // Wrap in Connect envelope for streaming request + body = frameConnectMessage(body) + + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body)) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", "application/connect+json") + req.Header.Set("Connect-Protocol-Version", "1") + req.Header.Set("Connect-Content-Encoding", "identity") + req.Header.Set("x-codeium-csrf-token", i.CSRF) + + return i.client.Do(req) +} + +func 
frameConnectMessage(payload []byte) []byte { + framed := make([]byte, 5+len(payload)) + binary.BigEndian.PutUint32(framed[1:5], uint32(len(payload))) + copy(framed[5:], payload) + return framed +} + +// Heartbeat sends a heartbeat to the LS +func (i *Instance) Heartbeat(ctx context.Context) error { + return i.HeartbeatWithTimeout(ctx, 15*time.Second) +} + +// HeartbeatWithTimeout sends a heartbeat to the LS with an explicit deadline. +func (i *Instance) HeartbeatWithTimeout(ctx context.Context, timeout time.Duration) error { + if timeout <= 0 { + timeout = 15 * time.Second + } + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + _, err := i.CallRPC(ctx, LSService, "Heartbeat", nil) + return err +} + +func waitForInstanceReady( + ctx context.Context, + probeInterval time.Duration, + heartbeat func(context.Context) error, +) (int, error) { + if probeInterval <= 0 { + probeInterval = lsStartupProbeInterval + } + if heartbeat == nil { + return 0, fmt.Errorf("heartbeat func is nil") + } + + timer := time.NewTimer(0) + defer timer.Stop() + + attempts := 0 + var lastErr error + + for { + select { + case <-ctx.Done(): + if lastErr == nil { + lastErr = ctx.Err() + } + return attempts, lastErr + case <-timer.C: + } + + attempts++ + if err := heartbeat(ctx); err == nil { + return attempts, nil + } else { + lastErr = err + } + + timer.Reset(probeInterval) + } +} + +// ============================================================ +// Pool Manager +// ============================================================ + +// Pool manages multiple LS instances, with sticky session routing per account. 
+type Pool struct { + config Config + instances map[string][]*Instance // accountID -> replica slot -> instance + extServer *MockExtensionServer // shared mock extension server + mu sync.RWMutex + ctx context.Context + cancel context.CancelFunc + logger *slog.Logger +} + +// NewPool creates a new LS pool with lifecycle management +func NewPool(config Config) *Pool { + ctx, cancel := context.WithCancel(context.Background()) + + // Generate a shared CSRF token for communication between LS and ext server + csrf := generateUUID() + + // Start the mock extension server + extServer, err := NewMockExtensionServer(csrf) + if err != nil { + slog.Error("failed to start mock extension server", "err", err) + } + + p := &Pool{ + config: config, + instances: make(map[string][]*Instance), + extServer: extServer, + ctx: ctx, + cancel: cancel, + logger: slog.Default().With("component", "lspool"), + } + go p.lifecycleLoop() + return p +} + +func (p *Pool) replicaCount() int { + if p.config.ReplicasPerAccount < 1 { + return 1 + } + return p.config.ReplicasPerAccount +} + +func (p *Pool) ensureLogger() { + if p.logger == nil { + p.logger = slog.Default().With("component", "lspool") + } +} + +func replicaSlotIndex(routingKey string, replicaCount int) int { + if replicaCount <= 1 || strings.TrimSpace(routingKey) == "" { + return 0 + } + sum := sha256.Sum256([]byte(routingKey)) + return int(binary.BigEndian.Uint32(sum[:4]) % uint32(replicaCount)) +} + +func chooseLeastBusyHealthy(instances []*Instance) *Instance { + var best *Instance + for _, inst := range instances { + if inst == nil || !inst.IsHealthy() { + continue + } + if best == nil || inst.ConcurrentCount() < best.ConcurrentCount() { + best = inst + } + } + return best +} + +func (p *Pool) ensureReplicaSlotsLocked(accountID string) []*Instance { + slots := p.instances[accountID] + required := p.replicaCount() + if len(slots) < required { + expanded := make([]*Instance, required) + copy(expanded, slots) + slots = expanded + 
p.instances[accountID] = slots + } + return slots +} + +// Get returns an existing healthy LS instance for the account and routing key, or nil. +func (p *Pool) Get(accountID, routingKey string) *Instance { + p.mu.RLock() + defer p.mu.RUnlock() + + instances := p.instances[accountID] + if len(instances) == 0 { + return nil + } + if strings.TrimSpace(routingKey) != "" { + slot := replicaSlotIndex(routingKey, p.replicaCount()) + if slot < len(instances) { + inst := instances[slot] + if inst != nil && inst.IsHealthy() { + inst.Touch() + return inst + } + } + return nil + } + if inst := chooseLeastBusyHealthy(instances); inst != nil { + inst.Touch() + return inst + } + return nil +} + +// GetOrCreate returns an existing LS or starts a new one. +// proxyURL is passed to the LS process as HTTPS_PROXY for Google API connectivity. +func (p *Pool) GetOrCreate(accountID, routingKey string, proxyURL ...string) (*Instance, error) { + p.ensureLogger() + if inst := p.Get(accountID, routingKey); inst != nil { + return inst, nil + } + + p.mu.Lock() + defer p.mu.Unlock() + + slots := p.ensureReplicaSlotsLocked(accountID) + slot := replicaSlotIndex(routingKey, p.replicaCount()) + + if strings.TrimSpace(routingKey) == "" { + if inst := chooseLeastBusyHealthy(slots); inst != nil { + inst.Touch() + return inst, nil + } + for idx, inst := range slots { + if inst == nil { + slot = idx + break + } + } + } + + if slot < len(slots) { + if inst := slots[slot]; inst != nil && inst.IsHealthy() { + inst.Touch() + return inst, nil + } + } + + proxy := "" + if len(proxyURL) > 0 { + proxy = proxyURL[0] + } + + if slot < len(slots) && slots[slot] != nil { + p.stopInstance(slots[slot]) + slots[slot] = nil + } + + inst, err := p.startInstance(accountID, proxy, slot) + if err != nil { + return nil, err + } + + slots[slot] = inst + p.logger.Info("LS instance created", + "account", shortAccountID(accountID), + "replica", slot, + "address", inst.Address, + "pid", inst.cmd.Process.Pid) + return inst, nil 
+} + +// Remove stops and removes all LS instances for an account. +func (p *Pool) Remove(accountID string) { + p.mu.Lock() + defer p.mu.Unlock() + + if slots, ok := p.instances[accountID]; ok { + for _, inst := range slots { + if inst == nil { + continue + } + p.stopInstance(inst) + } + delete(p.instances, accountID) + } +} + +// SetAccountToken updates the OAuth token for an account in the mock extension server +func (p *Pool) SetAccountToken(accountID, accessToken, refreshToken string, expiresAt time.Time) { + if p.extServer != nil { + p.extServer.SetToken(accountID, &TokenInfo{ + AccessToken: accessToken, + RefreshToken: refreshToken, + ExpiresAt: expiresAt, + }) + } +} + +// SetAccountModelCredits updates the JS-parity uss-modelCredits state for an account. +func (p *Pool) SetAccountModelCredits(accountID string, useAICredits bool, availableCredits, minimumCreditAmountForUsage *int32) { + if p.extServer != nil { + p.extServer.SetModelCredits(accountID, &ModelCreditsInfo{ + UseAICredits: useAICredits, + AvailableCredits: availableCredits, + MinimumCreditAmountForUsage: minimumCreditAmountForUsage, + }) + } +} + +// Stats returns pool statistics +func (p *Pool) Stats() map[string]any { + p.mu.RLock() + defer p.mu.RUnlock() + + active := 0 + total := 0 + for _, slots := range p.instances { + for _, inst := range slots { + if inst == nil { + continue + } + total++ + if inst.IsHealthy() { + active++ + } + } + } + return map[string]any{ + "accounts": len(p.instances), + "total": total, + "active": active, + } +} + +// Close shuts down all instances and the extension server +func (p *Pool) Close() { + p.ensureLogger() + p.cancel() + p.mu.Lock() + defer p.mu.Unlock() + + for id, slots := range p.instances { + for _, inst := range slots { + if inst == nil { + continue + } + p.logger.Info("shutting down LS", "account", shortAccountID(id), "replica", inst.Replica) + p.stopInstance(inst) + } + } + p.instances = make(map[string][]*Instance) + + if p.extServer != nil { + 
p.extServer.Close() + } +} + +// ============================================================ +// Instance Lifecycle +// ============================================================ + +var portRe = regexp.MustCompile(`at (\d+) for HTTPS`) + +func (p *Pool) startInstance(accountID string, proxyURL string, replica int) (*Instance, error) { + binPath := filepath.Join(p.config.AppRoot, "extensions", "antigravity", "bin", lsBinaryName()) + if _, err := os.Stat(binPath); err != nil { + return nil, fmt.Errorf("LS binary not found: %s", binPath) + } + + // Each LS instance gets its own CSRF token (like the real IDE) + csrf := generateUUID() + appDataDir := fmt.Sprintf("antigravity-pool-%s-r%d", shortAccountID(accountID), replica) + + args := []string{ + "--csrf_token", csrf, + "--app_data_dir", appDataDir, + "--https_server_port", "0", + } + if p.config.CloudCodeEndpoint != "" { + args = append(args, "--cloud_code_endpoint", p.config.CloudCodeEndpoint) + } + // Connect LS to our mock extension server for token injection. + // The extension server uses a shared CSRF token (set at server creation). 
+ if p.extServer != nil { + args = append(args, + "--extension_server_port", fmt.Sprintf("%d", p.extServer.Port()), + "--extension_server_csrf_token", p.extServer.csrf, + ) + } + + rawProxyURL := resolveLSProxy(proxyURL) + launchPlan, err := prepareLSLaunchPlan(binPath, args, rawProxyURL) + if err != nil { + return nil, fmt.Errorf("prepare LS launch: %w", err) + } + + cmd := launchPlan.cmd + cmd.Env = buildLSEnv(os.Environ(), p.config.AppRoot, launchPlan.effectiveProxyURL) + p.logger.Info("LS starting", + "account", shortAccountID(accountID), + "replica", replica, + "proxy_source", rawProxyURL, + "proxy_mode", launchPlan.proxyMode, + "effective_proxy", launchPlan.effectiveProxyURL) + + stdin, err := cmd.StdinPipe() + if err != nil { + return nil, fmt.Errorf("stdin pipe: %w", err) + } + stderr, err := cmd.StderrPipe() + if err != nil { + return nil, fmt.Errorf("stderr pipe: %w", err) + } + + if err := cmd.Start(); err != nil { + if launchPlan.cleanup != nil { + launchPlan.cleanup() + } + return nil, fmt.Errorf("start LS: %w", err) + } + + // Write metadata proto to stdin (with access token if available) + accessToken := "" + if p.extServer != nil { + p.extServer.mu.RLock() + if info := p.extServer.tokens[accountID]; info != nil { + accessToken = info.AccessToken + } + p.extServer.mu.RUnlock() + } + metaBytes := buildMetadataBytes(accessToken) + p.logger.Info("writing metadata to LS stdin", + "account", shortAccountID(accountID), + "replica", replica, + "meta_len", len(metaBytes), + "has_token", accessToken != "", + "hex_prefix", fmt.Sprintf("%x", metaBytes[:min(40, len(metaBytes))])) + stdin.Write(metaBytes) + stdin.Close() + + // Parse HTTPS port from stderr AND log all LS output + portCh := make(chan string, 1) + go func() { + scanner := bufio.NewScanner(stderr) + for scanner.Scan() { + line := scanner.Text() + p.logger.Warn("LS stderr", "account", shortAccountID(accountID), "replica", replica, "line", line) + if matches := portRe.FindStringSubmatch(line); 
len(matches) > 1 { + portCh <- matches[1] + } + } + p.logger.Warn("LS stderr EOF", "account", shortAccountID(accountID), "replica", replica) + }() + + var address string + select { + case port := <-portCh: + address = "127.0.0.1:" + port + case <-time.After(15 * time.Second): + cmd.Process.Kill() + return nil, fmt.Errorf("timeout: LS did not report HTTPS port within 15s") + } + + // Create HTTPS client with LS self-signed cert + httpClient, err := createHTTPClient(p.config.AppRoot, csrf) + if err != nil { + cmd.Process.Kill() + return nil, fmt.Errorf("create http client: %w", err) + } + + inst := &Instance{ + AccountID: accountID, + CSRF: csrf, + Replica: replica, + Address: address, + cmd: cmd, + cleanup: launchPlan.cleanup, + client: httpClient, + healthy: false, + lastUsed: time.Now(), + startedAt: time.Now(), + } + + // Real IDE waits for LanguageServerStarted callback from ExtServer (timeout 60s). + // Our MockExtServer already received this callback during port detection + // (LS calls LanguageServerStarted right after binding HTTPS port). + // Probe readiness immediately and keep retrying with a short interval instead + // of sleeping a fixed 3-6 seconds on every cold start. 
+ p.logger.Info("waiting for LS readiness", "account", shortAccountID(accountID), "replica", replica, "address", address) + readyStartedAt := time.Now() + readyCtx, cancel := context.WithTimeout(context.Background(), lsStartupReadyTimeout) + attempts, err := waitForInstanceReady(readyCtx, lsStartupProbeInterval, func(callCtx context.Context) error { + return inst.HeartbeatWithTimeout(callCtx, lsStartupHeartbeatTimeout) + }) + cancel() + if err != nil { + cmd.Process.Kill() + return nil, fmt.Errorf("LS not ready after startup: %w", err) + } + inst.mu.Lock() + inst.healthy = true + inst.mu.Unlock() + p.logger.Info("LS ready", + "account", shortAccountID(accountID), + "replica", replica, + "attempts", attempts, + "waited", time.Since(readyStartedAt).Truncate(time.Millisecond)) + + // Refresh model mapping from LS (async with retries — don't block startup) + go func() { + for attempt := 1; attempt <= 5; attempt++ { + if RefreshModelMapping(inst) { + p.logger.Info("model mapping loaded", "account", shortAccountID(accountID), "replica", replica, "attempt", attempt) + return + } + p.logger.Warn("model mapping not loaded, retrying", "account", shortAccountID(accountID), "replica", replica, "attempt", attempt) + time.Sleep(time.Duration(attempt*10) * time.Second) + } + }() + + return inst, nil +} + +func (p *Pool) stopInstance(inst *Instance) { + if inst.cmd != nil && inst.cmd.Process != nil { + inst.cmd.Process.Kill() + inst.cmd.Wait() + } + if inst.cleanup != nil { + inst.cleanup() + } + inst.mu.Lock() + inst.healthy = false + inst.mu.Unlock() +} + +func (p *Pool) lifecycleLoop() { + ticker := time.NewTicker(p.config.HealthCheckInterval) + defer ticker.Stop() + + for { + select { + case <-p.ctx.Done(): + return + case <-ticker.C: + p.doHealthCheck() + } + } +} + +func (p *Pool) doHealthCheck() { + p.mu.Lock() + defer p.mu.Unlock() + + for id, slots := range p.instances { + for replica, inst := range slots { + if inst == nil { + continue + } + + // Check process alive + if 
inst.cmd.ProcessState != nil { + p.logger.Warn("LS process exited", "account", shortAccountID(id), "replica", replica) + slots[replica] = nil + continue + } + + // Check idle timeout + inst.mu.RLock() + idle := time.Since(inst.lastUsed) + inst.mu.RUnlock() + + if idle > p.config.MaxIdleTime { + p.logger.Info("LS idle timeout", "account", shortAccountID(id), "replica", replica, "idle", idle) + p.stopInstance(inst) + slots[replica] = nil + continue + } + + // Heartbeat check + if err := inst.Heartbeat(p.ctx); err != nil { + p.logger.Warn("heartbeat failed", "account", shortAccountID(id), "replica", replica, "err", err) + inst.mu.Lock() + inst.healthy = false + inst.mu.Unlock() + } else { + inst.mu.Lock() + inst.healthy = true + inst.mu.Unlock() + } + } + + allNil := true + for _, inst := range slots { + if inst != nil { + allNil = false + break + } + } + if allNil { + delete(p.instances, id) + } + } +} + +// ============================================================ +// Helpers +// ============================================================ + +var ( + defaultLSCertFileCandidates = []string{ + "/etc/ssl/certs/ca-certificates.crt", + "/etc/pki/tls/certs/ca-bundle.crt", + "/etc/ssl/cert.pem", + } + defaultLSCertDirCandidates = []string{ + "/etc/ssl/certs", + "/etc/pki/tls/certs", + } +) + +func buildLSEnv(baseEnv []string, appRoot string, proxyURL string) []string { + env := append([]string(nil), baseEnv...) + env = setEnvValue(env, "ANTIGRAVITY_EDITOR_APP_ROOT", appRoot) + + // Set proxy for LS to reach Google APIs. + // MUST always override inherited container proxy (which may be Anthropic-only). + // proxyURL is already fully resolved by the caller. 
+ // Always set — even empty string clears inherited container values + env = setEnvValue(env, "HTTPS_PROXY", proxyURL) + env = setEnvValue(env, "HTTP_PROXY", proxyURL) + env = setEnvValue(env, "ALL_PROXY", proxyURL) + env = setEnvValue(env, "https_proxy", proxyURL) + env = setEnvValue(env, "http_proxy", proxyURL) + env = setEnvValue(env, "all_proxy", proxyURL) + + if !hasEnvKey(env, "SSL_CERT_FILE") { + if certFile := firstExistingPath(defaultLSCertFileCandidates); certFile != "" { + env = setEnvValue(env, "SSL_CERT_FILE", certFile) + } + } + if !hasEnvKey(env, "SSL_CERT_DIR") { + if certDir := firstExistingPath(defaultLSCertDirCandidates); certDir != "" { + env = setEnvValue(env, "SSL_CERT_DIR", certDir) + } + } + + return env +} + +func resolveLSProxy(proxyURL string) string { + if strings.TrimSpace(proxyURL) != "" { + return strings.TrimSpace(proxyURL) + } + return strings.TrimSpace(os.Getenv("ANTIGRAVITY_LS_PROXY")) +} + +func firstExistingPath(candidates []string) string { + for _, candidate := range candidates { + if candidate == "" { + continue + } + if _, err := os.Stat(candidate); err == nil { + return candidate + } + } + return "" +} + +func hasEnvKey(env []string, key string) bool { + prefix := key + "=" + for _, entry := range env { + if strings.HasPrefix(entry, prefix) { + return true + } + } + return false +} + +func setEnvValue(env []string, key, value string) []string { + prefix := key + "=" + for i, entry := range env { + if strings.HasPrefix(entry, prefix) { + env[i] = prefix + value + return env + } + } + return append(env, prefix+value) +} + +func createHTTPClient(appRoot, csrf string) (*http.Client, error) { + certPath := filepath.Join(appRoot, "extensions", "antigravity", "dist", "languageServer", "cert.pem") + caCert, err := os.ReadFile(certPath) + if err != nil { + return nil, fmt.Errorf("read cert %s: %w", certPath, err) + } + certPool := x509.NewCertPool() + if !certPool.AppendCertsFromPEM(caCert) { + return nil, fmt.Errorf("failed to 
parse cert") + } + + tlsCfg := &tls.Config{ + RootCAs: certPool, + InsecureSkipVerify: true, // LS uses self-signed cert; trust it unconditionally + } + + return &http.Client{ + Transport: &csrfTransport{ + base: &http2.Transport{ + TLSClientConfig: tlsCfg, + ReadIdleTimeout: 30 * time.Second, + }, + csrf: csrf, + }, + Timeout: 5 * time.Minute, + }, nil +} + +type csrfTransport struct { + base http.RoundTripper + csrf string +} + +func (t *csrfTransport) RoundTrip(req *http.Request) (*http.Response, error) { + req.Header.Set("x-codeium-csrf-token", t.csrf) + return t.base.RoundTrip(req) +} + +// writeMetadata writes the Metadata proto to LS stdin. +// This matches what the real IDE does (extension.js line 42520-42522): +// +// toBinary(MetadataSchema, create(MetadataSchema, { +// ideName, ideVersion, extensionName, extensionPath, +// locale, deviceFingerprint, apiKey, disableTelemetry, userTierId +// })) +func buildMetadataBytes(accessToken string) []byte { + var buf bytes.Buffer + writeProtoStringField(&buf, 1, "antigravity") // ide_name + writeProtoStringField(&buf, 7, "1.107.0") // ide_version + writeProtoStringField(&buf, 12, "antigravity") // extension_name + writeProtoStringField(&buf, 4, "en") // locale + if accessToken != "" { + writeProtoStringField(&buf, 3, accessToken) // api_key = access_token + } + // disable_telemetry = true (field 6, varint, value 1) + buf.Write([]byte{0x30, 0x01}) + return buf.Bytes() +} + +func writeProtoStringField(buf *bytes.Buffer, fieldNum int, val string) { + writeVarInt(buf, uint64(fieldNum<<3|2)) + writeVarInt(buf, uint64(len(val))) + buf.WriteString(val) +} + +func writeVarInt(buf *bytes.Buffer, v uint64) { + b := make([]byte, binary.MaxVarintLen64) + n := binary.PutUvarint(b, v) + buf.Write(b[:n]) +} + +func lsBinaryName() string { + m := map[string]string{ + "darwin/arm64": "language_server_macos_arm", + "darwin/amd64": "language_server_macos_x64", + "linux/amd64": "language_server_linux_x64", + "linux/arm64": 
"language_server_linux_arm", + "windows/amd64": "language_server_windows_x64.exe", + } + if name, ok := m[runtime.GOOS+"/"+runtime.GOARCH]; ok { + return name + } + return "language_server_linux_x64" // fallback +} + +func generateUUID() string { + b := make([]byte, 16) + crand.Read(b) + b[6] = (b[6] & 0x0f) | 0x40 + b[8] = (b[8] & 0x3f) | 0x80 + return fmt.Sprintf("%08x-%04x-%04x-%04x-%012x", b[0:4], b[4:6], b[6:8], b[8:10], b[10:]) +} + +func truncate(s string, max int) string { + if len(s) <= max { + return s + } + return s[:max] + "..." +} + +func shortAccountID(accountID string) string { + if len(accountID) <= 8 { + return accountID + } + return accountID[:8] +} diff --git a/backend/internal/pkg/lspool/pool_test.go b/backend/internal/pkg/lspool/pool_test.go new file mode 100644 index 00000000..c2967dc3 --- /dev/null +++ b/backend/internal/pkg/lspool/pool_test.go @@ -0,0 +1,346 @@ +package lspool + +import ( + "bytes" + "context" + "encoding/binary" + "errors" + "io" + "net/http" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint" + "github.com/stretchr/testify/require" +) + +func TestBuildLSEnvKeepsExistingSSLValues(t *testing.T) { + env := buildLSEnv([]string{ + "SSL_CERT_FILE=/custom/ca.pem", + "SSL_CERT_DIR=/custom/certs", + }, "/opt/antigravity", "") + require.Contains(t, env, "ANTIGRAVITY_EDITOR_APP_ROOT=/opt/antigravity") + require.Contains(t, env, "SSL_CERT_FILE=/custom/ca.pem") + require.Contains(t, env, "SSL_CERT_DIR=/custom/certs") +} + +func TestBuildLSEnvClearsInheritedProxyWhenUnset(t *testing.T) { + env := buildLSEnv([]string{ + "HTTPS_PROXY=http://old-proxy:8080", + "HTTP_PROXY=http://old-proxy:8080", + "ALL_PROXY=socks5://old-proxy:1080", + "https_proxy=http://old-proxy:8080", + "http_proxy=http://old-proxy:8080", + "all_proxy=socks5://old-proxy:1080", + }, "/opt/antigravity", "") + + require.Contains(t, env, "HTTPS_PROXY=") + require.Contains(t, env, "HTTP_PROXY=") + 
require.Contains(t, env, "ALL_PROXY=") + require.Contains(t, env, "https_proxy=") + require.Contains(t, env, "http_proxy=") + require.Contains(t, env, "all_proxy=") +} + +func TestShortAccountID(t *testing.T) { + require.Equal(t, "9", shortAccountID("9")) + require.Equal(t, "12345678", shortAccountID("12345678")) + require.Equal(t, "12345678", shortAccountID("123456789")) +} + +func TestFrameConnectMessage(t *testing.T) { + framed := frameConnectMessage([]byte(`{"x":1}`)) + require.Len(t, framed, 5+len(`{"x":1}`)) + require.Equal(t, byte(0), framed[0]) + require.Equal(t, uint32(len(`{"x":1}`)), binary.BigEndian.Uint32(framed[1:5])) + require.Equal(t, `{"x":1}`, string(framed[5:])) +} + +func TestConnectEnvelope(t *testing.T) { + payload := []byte("hello") + env := connectEnvelope(0x00, payload) + require.Len(t, env, 5+len(payload)) + require.Equal(t, byte(0x00), env[0]) + require.Equal(t, uint32(5), binary.BigEndian.Uint32(env[1:5])) + require.Equal(t, "hello", string(env[5:])) +} + +func TestUnwrapConnectEnvelope(t *testing.T) { + payload := []byte("test data") + env := connectEnvelope(0x00, payload) + unwrapped := unwrapConnectEnvelope(env) + require.Equal(t, payload, unwrapped) + short := []byte{1, 2} + require.Equal(t, short, unwrapConnectEnvelope(short)) +} + +func TestExtractPromptAndModel(t *testing.T) { + body := `{"model":"gemini-2.5-pro","project":"p","request":{"contents":[{"role":"user","parts":[{"text":"hello world"}]}]}}` + prompt, model := extractPromptAndModel([]byte(body)) + require.Equal(t, "hello world", prompt) + require.Equal(t, "gemini-2.5-pro", model) + + body2 := `{"contents":[{"role":"user","parts":[{"text":"test prompt"}]}]}` + prompt2, _ := extractPromptAndModel([]byte(body2)) + require.Equal(t, "test prompt", prompt2) +} + +func TestResolveModelEnum(t *testing.T) { + // Without dynamic mapping loaded, should return fallback (312 = gemini-2.5-flash) + require.True(t, resolveModelEnum("gemini-2.5-flash") > 0) + require.True(t, 
resolveModelEnum("models/gemini-2.5-flash") > 0) + require.True(t, resolveModelEnum("claude-sonnet-4-6") > 0) + require.True(t, resolveModelEnum("unknown-model") > 0) +} + +func TestBuildCascadeConfigIncludesRequestedModel(t *testing.T) { + cfg := buildCascadeConfig("models/gemini-2.5-flash") + require.NotNil(t, cfg) + + plannerConfig, ok := cfg["plannerConfig"].(map[string]any) + require.True(t, ok) + requestedModel, ok := plannerConfig["requestedModel"].(map[string]any) + require.True(t, ok) + require.NotEmpty(t, requestedModel["model"]) + require.Len(t, plannerConfig, 1) +} + +func TestBuildCascadeConfigClaudeIncludesRequestedModel(t *testing.T) { + cfg := buildCascadeConfig("claude-sonnet-4-6") + require.NotNil(t, cfg) + + plannerConfig, ok := cfg["plannerConfig"].(map[string]any) + require.True(t, ok) + requestedModel, ok := plannerConfig["requestedModel"].(map[string]any) + require.True(t, ok) + require.NotEmpty(t, requestedModel["model"]) + require.Len(t, plannerConfig, 1) +} + +func TestDoNonStreamGeneratePassesThrough(t *testing.T) { + fallback := &recordingUpstream{} + upstream := NewLSPoolUpstream(&Pool{}, fallback) + req, _ := http.NewRequest("POST", "https://example.com/v1beta/models/gemini:generateContent", bytes.NewReader([]byte(`{}`))) + resp, err := upstream.Do(req, "", 1, 1) + require.NoError(t, err) + require.NotNil(t, resp) + require.Equal(t, 1, fallback.doCalls) +} + +func TestExtractPlannerResponseText(t *testing.T) { + resp := `{"status":"CASCADE_RUN_STATUS_IDLE","trajectory":{"steps":[ + {"type":"CORTEX_STEP_TYPE_USER_INPUT","status":"CORTEX_STEP_STATUS_DONE"}, + {"type":"CORTEX_STEP_TYPE_PLANNER_RESPONSE","status":"CORTEX_STEP_STATUS_DONE", + "plannerResponse":{"response":"Hello world"}} + ]}}` + text, generating, status := extractPlannerResponseText([]byte(resp)) + require.Equal(t, "Hello world", text) + require.False(t, generating) + require.Equal(t, "CASCADE_RUN_STATUS_IDLE", status) +} + +func 
TestExtractPlannerResponseState_ErrorDetails(t *testing.T) { + resp := `{ + "status":"CASCADE_RUN_STATUS_IDLE", + "trajectory":{ + "steps":[ + {"type":"CORTEX_STEP_TYPE_USER_INPUT","status":"CORTEX_STEP_STATUS_DONE"} + ], + "executorMetadata":{ + "terminationReason":"ERROR", + "errorDetails":{ + "errorCode":429, + "shortError":"Model quota reached", + "details":"You have exhausted your capacity on this model. Your quota will reset after 1h59m40s." + } + } + } + }` + + state := extractPlannerResponseState([]byte(resp)) + require.Equal(t, "CASCADE_RUN_STATUS_IDLE", state.Status) + require.False(t, state.Generating) + require.Empty(t, state.Text) + require.Contains(t, state.ErrorMessage, "Model quota reached") + require.Contains(t, state.ErrorMessage, "quota will reset after") +} + +func TestBuildGeminiSSEChunk(t *testing.T) { + sse := buildGeminiSSEChunk("hello") + require.Contains(t, sse, "data: ") + require.Contains(t, sse, `"text":"hello"`) + require.Contains(t, sse, `"role":"model"`) + require.True(t, strings.HasSuffix(sse, "\n\n")) +} + +func TestRequestHasTools(t *testing.T) { + // Wrapped format with tools + require.True(t, requestHasTools([]byte(`{"model":"m","project":"p","request":{"contents":[],"tools":[{"functionDeclarations":[{"name":"get_weather"}]}]}}`))) + + // Direct format with tools + require.True(t, requestHasTools([]byte(`{"contents":[],"tools":[{"functionDeclarations":[{"name":"f"}]}]}`))) + + // No tools + require.False(t, requestHasTools([]byte(`{"model":"m","project":"p","request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`))) + + // Empty tools array + require.False(t, requestHasTools([]byte(`{"contents":[],"tools":[]}`))) +} + +func TestCurrentLSStrategy(t *testing.T) { + t.Setenv("ANTIGRAVITY_LS_STRATEGY", "js-parity") + require.Equal(t, LSStrategyJSParity, CurrentLSStrategy()) + + t.Setenv("ANTIGRAVITY_LS_STRATEGY", "unknown") + require.Equal(t, LSStrategyDirect, CurrentLSStrategy()) +} + +func 
TestParseLSReplicaCountDefaultAndEnv(t *testing.T) { + t.Setenv("ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT", "") + require.Equal(t, 5, parseLSReplicaCount()) + + t.Setenv("ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT", "3") + require.Equal(t, 3, parseLSReplicaCount()) + + t.Setenv("ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT", "0") + require.Equal(t, 5, parseLSReplicaCount()) +} + +func TestPoolGetUsesStickyReplicaSlot(t *testing.T) { + pool := &Pool{ + config: Config{ReplicasPerAccount: 5}, + instances: map[string][]*Instance{ + "acc-1": { + {AccountID: "acc-1", Replica: 0, healthy: true}, + {AccountID: "acc-1", Replica: 1, healthy: true}, + {AccountID: "acc-1", Replica: 2, healthy: true}, + {AccountID: "acc-1", Replica: 3, healthy: true}, + {AccountID: "acc-1", Replica: 4, healthy: true}, + }, + }, + } + + routingKey := "acc-1:user-a:session-1" + slot := replicaSlotIndex(routingKey, pool.replicaCount()) + inst := pool.Get("acc-1", routingKey) + require.NotNil(t, inst) + require.Equal(t, slot, inst.Replica) +} + +func TestPoolGetWithoutRoutingKeyPrefersLeastBusyReplica(t *testing.T) { + busy := &Instance{AccountID: "acc-1", Replica: 0, healthy: true} + atomic.StoreInt64(&busy.inflight, 4) + idle := &Instance{AccountID: "acc-1", Replica: 1, healthy: true} + atomic.StoreInt64(&idle.inflight, 1) + + pool := &Pool{ + config: Config{ReplicasPerAccount: 5}, + instances: map[string][]*Instance{ + "acc-1": {busy, idle}, + }, + } + + inst := pool.Get("acc-1", "") + require.NotNil(t, inst) + require.Equal(t, 1, inst.Replica) +} + +func TestWaitForInstanceReadyProbesImmediately(t *testing.T) { + startedAt := time.Now() + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + attempts, err := waitForInstanceReady(ctx, 200*time.Millisecond, func(context.Context) error { + return nil + }) + require.NoError(t, err) + require.Equal(t, 1, attempts) + require.Less(t, time.Since(startedAt), 100*time.Millisecond) +} + +func TestWaitForInstanceReadyRetriesUntilSuccess(t 
*testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + calls := 0 + attempts, err := waitForInstanceReady(ctx, 10*time.Millisecond, func(context.Context) error { + calls++ + if calls < 3 { + return errors.New("not ready") + } + return nil + }) + require.NoError(t, err) + require.Equal(t, 3, attempts) + require.Equal(t, 3, calls) +} + +func TestDecideJSParityRoute(t *testing.T) { + body := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"s1","contents":[{"role":"user","parts":[{"text":"hello"}]}]}}`) + parsed, err := parseGeminiRequest(body) + require.NoError(t, err) + decision := decideJSParityRoute(parsed, body) + require.True(t, decision.UseLS) + + imageBody := []byte(`{"model":"gemini-2.5-flash-image","request":{"sessionId":"s1","contents":[{"role":"user","parts":[{"text":"draw"}]}],"generationConfig":{"responseModalities":["TEXT","IMAGE"]}}}`) + parsedImage, err := parseGeminiRequest(imageBody) + require.NoError(t, err) + decisionImage := decideJSParityRoute(parsedImage, imageBody) + require.False(t, decisionImage.UseLS) + require.Contains(t, strings.ToLower(decisionImage.Reason), "image") + + noSessionBody := []byte(`{"model":"gemini-2.5-flash","request":{"contents":[{"role":"user","parts":[{"text":"hello"}]}]}}`) + parsedNoSession, err := parseGeminiRequest(noSessionBody) + require.NoError(t, err) + require.False(t, decideJSParityRoute(parsedNoSession, noSessionBody).UseLS) +} + +func TestUserNamespacePrefersExplicitHeader(t *testing.T) { + req, err := http.NewRequest(http.MethodPost, "https://example.com", nil) + require.NoError(t, err) + req.Header.Set(userNamespaceHeader, "tenant-a") + req.Header.Set("Authorization", "Bearer oauth-token") + + nsWithExplicit := userNamespace(req) + require.NotEqual(t, "anon", nsWithExplicit) + + req.Header.Del(userNamespaceHeader) + nsWithAuth := userNamespace(req) + require.NotEqual(t, "anon", nsWithAuth) + require.NotEqual(t, nsWithExplicit, nsWithAuth) +} + 
+func TestConversationPrefixEqual(t *testing.T) { + prefix := []geminiConversationTurn{ + {Role: "user", Parts: []geminiConversationPart{{Kind: "text", Text: "hello"}}}, + {Role: "model", Parts: []geminiConversationPart{{Kind: "text", Text: "world"}}}, + } + full := append(cloneConversationTurns(prefix), geminiConversationTurn{ + Role: "user", + Parts: []geminiConversationPart{{Kind: "text", Text: "follow up"}}, + }) + require.True(t, conversationPrefixEqual(full, prefix)) + require.False(t, conversationPrefixEqual(prefix, full)) +} + +func TestSystemTextCompatible(t *testing.T) { + require.True(t, systemTextCompatible("You are helpful", "")) + require.True(t, systemTextCompatible("You are helpful", "You are helpful")) + require.False(t, systemTextCompatible("", "You are helpful")) + require.False(t, systemTextCompatible("You are helpful", "You are different")) +} + +type recordingUpstream struct { + doCalls int +} + +func (r *recordingUpstream) Do(req *http.Request, _ string, _ int64, _ int) (*http.Response, error) { + r.doCalls++ + return &http.Response{StatusCode: 200, Body: io.NopCloser(bytes.NewBufferString("ok")), Header: make(http.Header), Request: req}, nil +} + +func (r *recordingUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, c int, _ *tlsfingerprint.Profile) (*http.Response, error) { + return r.Do(req, proxyURL, accountID, c) +} diff --git a/backend/internal/pkg/lspool/proxy_bridge.go b/backend/internal/pkg/lspool/proxy_bridge.go new file mode 100644 index 00000000..36f5348b --- /dev/null +++ b/backend/internal/pkg/lspool/proxy_bridge.go @@ -0,0 +1,268 @@ +package lspool + +import ( + "bufio" + "context" + "fmt" + "io" + "log/slog" + "net" + "net/http" + "net/url" + "strings" + "sync" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl" + "golang.org/x/net/proxy" +) + +type lsProxyBridge struct { + listener net.Listener + server *http.Server + url string + upstream string +} + +type lsProxyBridgeManager struct { + mu 
sync.Mutex + bridges map[string]*lsProxyBridge + logger *slog.Logger +} + +var globalLSProxyBridgeManager = &lsProxyBridgeManager{ + bridges: make(map[string]*lsProxyBridge), + logger: slog.Default().With("component", "lspool-proxy-bridge"), +} + +var ( + lsProxyBridgeDialTimeout = 10 * time.Second + lsProxyBridgeProbeTargets = []string{ + "cloudcode-pa.googleapis.com:443", + "oauthaccountmanager.googleapis.com:443", + } +) + +func prepareLSProxyURL(raw string) (string, error) { + normalized, parsed, err := proxyurl.Parse(raw) + if err != nil { + return "", err + } + if parsed == nil { + return "", nil + } + + switch strings.ToLower(parsed.Scheme) { + case "http", "https": + return normalized, nil + case "socks5", "socks5h": + return globalLSProxyBridgeManager.ensure(normalized, parsed) + default: + return "", fmt.Errorf("unsupported LS proxy scheme: %s", parsed.Scheme) + } +} + +func (m *lsProxyBridgeManager) ensure(key string, upstream *url.URL) (string, error) { + m.mu.Lock() + defer m.mu.Unlock() + + if bridge := m.bridges[key]; bridge != nil { + return bridge.url, nil + } + + bridge, err := newLSProxyBridge(upstream, m.logger) + if err != nil { + return "", err + } + m.bridges[key] = bridge + return bridge.url, nil +} + +func (m *lsProxyBridgeManager) closeAll() { + m.mu.Lock() + defer m.mu.Unlock() + + for key, bridge := range m.bridges { + if bridge != nil { + _ = bridge.server.Close() + _ = bridge.listener.Close() + } + delete(m.bridges, key) + } +} + +func closeAllLSProxyBridgesForTest() { + globalLSProxyBridgeManager.closeAll() +} + +func newLSProxyBridge(upstream *url.URL, logger *slog.Logger) (*lsProxyBridge, error) { + dialer, err := proxy.FromURL(upstream, proxy.Direct) + if err != nil { + return nil, fmt.Errorf("create SOCKS dialer: %w", err) + } + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return nil, fmt.Errorf("listen LS proxy bridge: %w", err) + } + + bridge := &lsProxyBridge{ + listener: listener, + url: "http://" + 
listener.Addr().String(), + upstream: upstream.Redacted(), + } + + server := &http.Server{ + Handler: http.HandlerFunc(bridge.connectHandler(dialer, logger)), + ReadHeaderTimeout: 10 * time.Second, + IdleTimeout: 2 * time.Minute, + } + bridge.server = server + + go func() { + if err := server.Serve(listener); err != nil && err != http.ErrServerClosed { + logger.Error("LS proxy bridge serve failed", "upstream", bridge.upstream, "err", err) + } + }() + + logger.Info("LS proxy bridge started", "upstream", bridge.upstream, "listen", bridge.url) + go bridge.probeConnectivity(dialer, logger) + return bridge, nil +} + +func (b *lsProxyBridge) connectHandler(dialer proxy.Dialer, logger *slog.Logger) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodConnect { + http.Error(w, "CONNECT only", http.StatusMethodNotAllowed) + return + } + + targetAddr := strings.TrimSpace(r.Host) + if targetAddr == "" { + targetAddr = strings.TrimSpace(r.URL.Host) + } + if targetAddr == "" { + http.Error(w, "missing target host", http.StatusBadRequest) + return + } + if _, _, err := net.SplitHostPort(targetAddr); err != nil { + targetAddr = net.JoinHostPort(targetAddr, "443") + } + + startedAt := time.Now() + logger.Info("LS proxy bridge CONNECT", "upstream", b.upstream, "target", targetAddr) + + dialCtx, cancel := context.WithTimeout(r.Context(), lsProxyBridgeDialTimeout) + defer cancel() + + targetConn, err := dialViaProxy(dialCtx, dialer, targetAddr) + if err != nil { + logger.Warn("LS proxy bridge dial failed", + "upstream", b.upstream, + "target", targetAddr, + "elapsed", time.Since(startedAt).Truncate(time.Millisecond), + "err", err) + http.Error(w, "proxy dial failed", http.StatusBadGateway) + return + } + logger.Info("LS proxy bridge CONNECT established", + "upstream", b.upstream, + "target", targetAddr, + "elapsed", time.Since(startedAt).Truncate(time.Millisecond)) + + hijacker, ok := w.(http.Hijacker) + if !ok { + _ = 
targetConn.Close() + http.Error(w, "hijack unsupported", http.StatusInternalServerError) + return + } + + clientConn, rw, err := hijacker.Hijack() + if err != nil { + _ = targetConn.Close() + return + } + + if _, err := clientConn.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n")); err != nil { + _ = targetConn.Close() + _ = clientConn.Close() + return + } + + if rw != nil && rw.Reader.Buffered() > 0 { + if _, err := io.CopyN(targetConn, rw, int64(rw.Reader.Buffered())); err != nil { + _ = targetConn.Close() + _ = clientConn.Close() + return + } + } + + tunnelConns(clientConn, targetConn) + } +} + +func dialViaProxy(ctx context.Context, dialer proxy.Dialer, targetAddr string) (net.Conn, error) { + if contextDialer, ok := dialer.(proxy.ContextDialer); ok { + return contextDialer.DialContext(ctx, "tcp", targetAddr) + } + + type dialResult struct { + conn net.Conn + err error + } + resultCh := make(chan dialResult, 1) + go func() { + conn, err := dialer.Dial("tcp", targetAddr) + resultCh <- dialResult{conn: conn, err: err} + }() + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case result := <-resultCh: + return result.conn, result.err + } +} + +func (b *lsProxyBridge) probeConnectivity(dialer proxy.Dialer, logger *slog.Logger) { + for _, targetAddr := range lsProxyBridgeProbeTargets { + startedAt := time.Now() + ctx, cancel := context.WithTimeout(context.Background(), lsProxyBridgeDialTimeout) + conn, err := dialViaProxy(ctx, dialer, targetAddr) + cancel() + if err != nil { + logger.Warn("LS proxy bridge probe failed", + "upstream", b.upstream, + "target", targetAddr, + "elapsed", time.Since(startedAt).Truncate(time.Millisecond), + "err", err) + continue + } + _ = conn.Close() + logger.Info("LS proxy bridge probe succeeded", + "upstream", b.upstream, + "target", targetAddr, + "elapsed", time.Since(startedAt).Truncate(time.Millisecond)) + } +} + +func tunnelConns(clientConn net.Conn, targetConn net.Conn) { + var once sync.Once + closeBoth := 
func() { + _ = clientConn.Close() + _ = targetConn.Close() + } + + go func() { + _, _ = io.Copy(targetConn, clientConn) + once.Do(closeBoth) + }() + go func() { + _, _ = io.Copy(clientConn, targetConn) + once.Do(closeBoth) + }() +} + +func readConnectResponse(br *bufio.Reader) (*http.Response, error) { + return http.ReadResponse(br, &http.Request{Method: http.MethodConnect}) +} diff --git a/backend/internal/pkg/lspool/proxy_bridge_test.go b/backend/internal/pkg/lspool/proxy_bridge_test.go new file mode 100644 index 00000000..69960e3a --- /dev/null +++ b/backend/internal/pkg/lspool/proxy_bridge_test.go @@ -0,0 +1,193 @@ +package lspool + +import ( + "bufio" + "encoding/binary" + "fmt" + "io" + "net" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestPrepareLSProxyURLPassesThroughHTTPProxy(t *testing.T) { + t.Cleanup(closeAllLSProxyBridgesForTest) + + got, err := prepareLSProxyURL("http://proxy.example.com:8080") + require.NoError(t, err) + require.Equal(t, "http://proxy.example.com:8080", got) +} + +func TestPrepareLSProxyURLBridgesSOCKS5ForLS(t *testing.T) { + t.Cleanup(closeAllLSProxyBridgesForTest) + + targetAddr, closeTarget := startBridgeEchoServer(t) + defer closeTarget() + + socksURL, closeSOCKS := startBridgeSOCKS5Server(t) + defer closeSOCKS() + + bridgeURL, err := prepareLSProxyURL(socksURL) + require.NoError(t, err) + require.True(t, strings.HasPrefix(bridgeURL, "http://127.0.0.1:")) + + // Same SOCKS upstream should reuse the same local bridge. 
+ reusedURL, err := prepareLSProxyURL(socksURL) + require.NoError(t, err) + require.Equal(t, bridgeURL, reusedURL) + + bridgeAddr := strings.TrimPrefix(bridgeURL, "http://") + conn, err := net.Dial("tcp", bridgeAddr) + require.NoError(t, err) + defer conn.Close() + + _, err = fmt.Fprintf(conn, "CONNECT %s HTTP/1.1\r\nHost: %s\r\n\r\n", targetAddr, targetAddr) + require.NoError(t, err) + + reader := bufio.NewReader(conn) + resp, err := readConnectResponse(reader) + require.NoError(t, err) + require.Equal(t, 200, resp.StatusCode) + + _, err = conn.Write([]byte("ping")) + require.NoError(t, err) + + reply := make([]byte, 4) + _, err = io.ReadFull(reader, reply) + require.NoError(t, err) + require.Equal(t, "pong", string(reply)) +} + +func startBridgeEchoServer(t *testing.T) (string, func()) { + t.Helper() + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + done := make(chan struct{}) + go func() { + defer close(done) + for { + conn, err := ln.Accept() + if err != nil { + return + } + go func(c net.Conn) { + defer c.Close() + buf := make([]byte, 4) + if _, err := io.ReadFull(c, buf); err != nil { + return + } + if string(buf) == "ping" { + _, _ = c.Write([]byte("pong")) + } + }(conn) + } + }() + + return ln.Addr().String(), func() { + _ = ln.Close() + <-done + } +} + +func startBridgeSOCKS5Server(t *testing.T) (string, func()) { + t.Helper() + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + done := make(chan struct{}) + go func() { + defer close(done) + for { + conn, err := ln.Accept() + if err != nil { + return + } + go handleBridgeSOCKS5Conn(conn) + } + }() + + return "socks5://" + ln.Addr().String(), func() { + _ = ln.Close() + <-done + } +} + +func handleBridgeSOCKS5Conn(conn net.Conn) { + header := make([]byte, 2) + if _, err := io.ReadFull(conn, header); err != nil { + _ = conn.Close() + return + } + methods := make([]byte, int(header[1])) + if _, err := io.ReadFull(conn, methods); err != nil { + _ = 
conn.Close() + return + } + _, _ = conn.Write([]byte{0x05, 0x00}) + + reqHeader := make([]byte, 4) + if _, err := io.ReadFull(conn, reqHeader); err != nil { + _ = conn.Close() + return + } + if reqHeader[0] != 0x05 || reqHeader[1] != 0x01 { + _ = conn.Close() + return + } + + targetHost, ok := readSOCKS5Addr(conn, reqHeader[3]) + if !ok { + _ = conn.Close() + return + } + portBuf := make([]byte, 2) + if _, err := io.ReadFull(conn, portBuf); err != nil { + _ = conn.Close() + return + } + targetAddr := fmt.Sprintf("%s:%d", targetHost, binary.BigEndian.Uint16(portBuf)) + + targetConn, err := net.Dial("tcp", targetAddr) + if err != nil { + _, _ = conn.Write([]byte{0x05, 0x01, 0x00, 0x01, 0, 0, 0, 0, 0, 0}) + _ = conn.Close() + return + } + + _, _ = conn.Write([]byte{0x05, 0x00, 0x00, 0x01, 0, 0, 0, 0, 0, 0}) + tunnelConns(conn, targetConn) +} + +func readSOCKS5Addr(conn net.Conn, atyp byte) (string, bool) { + switch atyp { + case 0x01: + buf := make([]byte, 4) + if _, err := io.ReadFull(conn, buf); err != nil { + return "", false + } + return net.IP(buf).String(), true + case 0x03: + lenBuf := make([]byte, 1) + if _, err := io.ReadFull(conn, lenBuf); err != nil { + return "", false + } + buf := make([]byte, int(lenBuf[0])) + if _, err := io.ReadFull(conn, buf); err != nil { + return "", false + } + return string(buf), true + case 0x04: + buf := make([]byte, 16) + if _, err := io.ReadFull(conn, buf); err != nil { + return "", false + } + return net.IP(buf).String(), true + default: + return "", false + } +} diff --git a/backend/internal/pkg/lspool/proxy_exec.go b/backend/internal/pkg/lspool/proxy_exec.go new file mode 100644 index 00000000..38a678df --- /dev/null +++ b/backend/internal/pkg/lspool/proxy_exec.go @@ -0,0 +1,138 @@ +package lspool + +import ( + "fmt" + "net/url" + "os" + "os/exec" + "strings" + + "github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl" +) + +type lsLaunchPlan struct { + cmd *exec.Cmd + effectiveProxyURL string + proxyMode string + cleanup func() 
+} + +func prepareLSLaunchPlan(binPath string, args []string, rawProxyURL string) (*lsLaunchPlan, error) { + normalized, parsed, err := proxyurl.Parse(rawProxyURL) + if err != nil { + return nil, err + } + + plan := &lsLaunchPlan{ + cmd: exec.Command(binPath, args...), + proxyMode: "direct", + } + + if parsed == nil { + return plan, nil + } + + switch strings.ToLower(parsed.Scheme) { + case "http", "https": + plan.effectiveProxyURL = normalized + plan.proxyMode = "env-http-proxy" + return plan, nil + + case "socks5", "socks5h": + if proxychainsPath, err := exec.LookPath("proxychains4"); err == nil { + cfgPath, err := writeProxychainsConfig(parsed) + if err != nil { + return nil, err + } + plan.cmd = exec.Command(proxychainsPath, append([]string{"-f", cfgPath, binPath}, args...)...) + plan.proxyMode = "proxychains4" + plan.cleanup = func() { + _ = os.Remove(cfgPath) + } + return plan, nil + } + + effectiveProxyURL, err := prepareLSProxyURL(normalized) + if err != nil { + return nil, err + } + plan.effectiveProxyURL = effectiveProxyURL + plan.proxyMode = "http-connect-bridge" + return plan, nil + + default: + return nil, fmt.Errorf("unsupported LS proxy scheme: %s", parsed.Scheme) + } +} + +func writeProxychainsConfig(proxyURL *url.URL) (string, error) { + content, err := buildProxychainsConfig(proxyURL) + if err != nil { + return "", err + } + + file, err := os.CreateTemp("", "sub2api-proxychains-*.conf") + if err != nil { + return "", fmt.Errorf("create proxychains config: %w", err) + } + + if _, err := file.WriteString(content); err != nil { + _ = file.Close() + _ = os.Remove(file.Name()) + return "", fmt.Errorf("write proxychains config: %w", err) + } + if err := file.Close(); err != nil { + _ = os.Remove(file.Name()) + return "", fmt.Errorf("close proxychains config: %w", err) + } + + return file.Name(), nil +} + +func buildProxychainsConfig(proxyURL *url.URL) (string, error) { + if proxyURL == nil { + return "", fmt.Errorf("proxy url is nil") + } + if scheme := 
strings.ToLower(proxyURL.Scheme); scheme != "socks5" && scheme != "socks5h" { + return "", fmt.Errorf("proxychains only supports socks5/socks5h, got %s", proxyURL.Scheme) + } + + host := strings.TrimSpace(proxyURL.Hostname()) + port := strings.TrimSpace(proxyURL.Port()) + if host == "" { + return "", fmt.Errorf("proxy host is empty") + } + if port == "" { + port = "1080" + } + + username := proxyURL.User.Username() + password, _ := proxyURL.User.Password() + if strings.ContainsAny(username, " \t\r\n") || strings.ContainsAny(password, " \t\r\n") { + return "", fmt.Errorf("proxychains credentials cannot contain whitespace") + } + + var builder strings.Builder + builder.WriteString("strict_chain\n") + builder.WriteString("proxy_dns\n") + builder.WriteString("remote_dns_subnet 224\n") + builder.WriteString("tcp_connect_time_out 8000\n") + builder.WriteString("tcp_read_time_out 15000\n") + builder.WriteString("localnet 127.0.0.0/255.0.0.0\n") + builder.WriteString("localnet ::1/128\n") + builder.WriteString("[ProxyList]\n") + builder.WriteString("socks5 ") + builder.WriteString(host) + builder.WriteString(" ") + builder.WriteString(port) + if username != "" { + builder.WriteString(" ") + builder.WriteString(username) + if password != "" { + builder.WriteString(" ") + builder.WriteString(password) + } + } + builder.WriteString("\n") + return builder.String(), nil +} diff --git a/backend/internal/pkg/lspool/proxy_exec_test.go b/backend/internal/pkg/lspool/proxy_exec_test.go new file mode 100644 index 00000000..850edce7 --- /dev/null +++ b/backend/internal/pkg/lspool/proxy_exec_test.go @@ -0,0 +1,31 @@ +package lspool + +import ( + "net/url" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestBuildProxychainsConfigIncludesAuthAndLocalBypass(t *testing.T) { + proxyURL, err := url.Parse("socks5h://gostuser:fastapipwd@216.167.85.31:1080") + require.NoError(t, err) + + cfg, err := buildProxychainsConfig(proxyURL) + require.NoError(t, err) + 
require.Contains(t, cfg, "proxy_dns\n") + require.Contains(t, cfg, "localnet 127.0.0.0/255.0.0.0\n") + require.Contains(t, cfg, "localnet ::1/128\n") + require.Contains(t, cfg, "[ProxyList]\n") + require.Contains(t, cfg, "socks5 216.167.85.31 1080 gostuser fastapipwd\n") +} + +func TestBuildProxychainsConfigRejectsWhitespaceCredentials(t *testing.T) { + proxyURL, err := url.Parse("socks5h://user:bad%20pass@127.0.0.1:1080") + require.NoError(t, err) + + _, err = buildProxychainsConfig(proxyURL) + require.Error(t, err) + require.True(t, strings.Contains(err.Error(), "whitespace")) +} diff --git a/backend/internal/pkg/lspool/remote_instance.go b/backend/internal/pkg/lspool/remote_instance.go new file mode 100644 index 00000000..5de4bde2 --- /dev/null +++ b/backend/internal/pkg/lspool/remote_instance.go @@ -0,0 +1,99 @@ +package lspool + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" +) + +func (i *Instance) callWorkerUnary(ctx context.Context, service, method, mode string, body []byte) ([]byte, error) { + endpoint, err := i.workerEndpoint("/rpc/unary", service, method, mode) + if err != nil { + return nil, err + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body)) + if err != nil { + return nil, err + } + req.Header.Set("X-Worker-Token", i.workerToken) + if mode == "json" { + req.Header.Set("Content-Type", "application/json") + } else { + req.Header.Set("Content-Type", "application/octet-stream") + } + + resp, err := i.client.Do(req) + if err != nil { + return nil, fmt.Errorf("worker rpc %s/%s: %w", service, method, err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("worker rpc read response: %w", err) + } + if resp.StatusCode != http.StatusOK { + return respBody, fmt.Errorf("worker rpc %s/%s HTTP %d: %s", service, method, resp.StatusCode, truncate(string(respBody), 200)) + } + return respBody, nil +} + +func (i 
// marshalWorkerJSONBody encodes input as a JSON request body.
//
// An untyped nil input is encoded as the empty object "{}" rather than
// "null". Any other value is passed to json.Marshal, whose error is returned
// unchanged with a nil body.
func marshalWorkerJSONBody(input any) ([]byte, error) {
	if input != nil {
		return json.Marshal(input)
	}
	return []byte("{}"), nil
}
// LSPoolUpstream wraps an existing HTTPUpstream and intercepts
// streamGenerateContent requests to route them through the LS pool.
// Requests for any other path are passed to the wrapped upstream.
type LSPoolUpstream struct {
	// pool supplies LS instances and receives per-account tokens and credit
	// settings extracted from request headers.
	pool Backend
	// fallback is the wrapped upstream used for direct (non-LS) forwarding.
	fallback Upstream
	// logger is tagged with component "lspool-upstream" by the constructor.
	logger *slog.Logger
	// sessionMu presumably guards sessions; the accessors are not shown
	// here — NOTE(review): confirm against getSessionState/putSessionState.
	sessionMu sync.Mutex
	// sessions caches cascade state, looked up by the session cache key.
	sessions map[string]*cascadeSessionState
}

// NewLSPoolUpstream creates an LS pool upstream wrapper.
+func NewLSPoolUpstream(pool Backend, fallback Upstream) *LSPoolUpstream { + return &LSPoolUpstream{ + pool: pool, + fallback: fallback, + logger: slog.Default().With("component", "lspool-upstream"), + sessions: make(map[string]*cascadeSessionState), + } +} + +const ( + userNamespaceHeader = "X-Sub2API-User-Key" + useAICreditsHeader = "X-Antigravity-Use-AI-Credits" + availableCreditsHeader = "X-Antigravity-Available-Credits" + minimumCreditAmountHeader = "X-Antigravity-Minimum-Credit-Amount" + sessionStateTTL = 30 * time.Minute + lsSendMessageTimeout = 20 * time.Second + lsModelConfigTimeout = 20 * time.Second +) + +var ( + errLSRouteDirect = errors.New("request should use direct upstream") + errLSTranscriptDrift = errors.New("request transcript diverged from cached cascade session") + errLSQuotaExhausted = errors.New("ls cascade returned quota exhausted") + errLSModelMapPending = errors.New("model mapping not ready") +) + +type cascadeSessionState struct { + CascadeID string + SystemText string + History []geminiConversationTurn + UpdatedAt time.Time +} + +type geminiEnvelope struct { + Model string `json:"model"` + Request json.RawMessage `json:"request"` +} + +type geminiRequestPayload struct { + Contents []geminiWireContent `json:"contents"` + SystemInstruction *geminiWireContent `json:"systemInstruction,omitempty"` + GenerationConfig *geminiWireGenerationConfig `json:"generationConfig,omitempty"` + SessionID string `json:"sessionId,omitempty"` +} + +type geminiWireGenerationConfig struct { + ResponseModalities []string `json:"responseModalities,omitempty"` + ImageConfig json.RawMessage `json:"imageConfig,omitempty"` +} + +type geminiWireContent struct { + Role string `json:"role"` + Parts []geminiWirePart `json:"parts"` +} + +type geminiWirePart struct { + Text string `json:"text,omitempty"` + Thought bool `json:"thought,omitempty"` + ThoughtSignature string `json:"thoughtSignature,omitempty"` + InlineData *geminiWireInlineData `json:"inlineData,omitempty"` + 
// lsRequestTrace collects timing and routing facts for one request served
// through the LS cascade path; logTraceSummary turns it into log attributes.
type lsRequestTrace struct {
	// StartedAt is the reference instant for the latency fields and total time.
	StartedAt time.Time
	AccountID int64
	Model     string
	// SessionIDHash is the short hash of the request's sessionId, so the raw
	// identifier is not logged.
	SessionIDHash string
	// Replica is copied from the acquired instance.
	Replica   int
	CascadeID string
	// NewSession is true when a cascade was started for this request rather
	// than reused from cached session state.
	NewSession bool
	// InflightAtAcquire is the instance's concurrent count read right after
	// this request acquired its slot.
	InflightAtAcquire int64
	// TurnCount is the number of parsed conversation turns in the request.
	TurnCount            int
	GetOrCreateDuration  time.Duration
	StartCascadeDuration time.Duration
	BuildInputDuration   time.Duration
	SendMessageDuration  time.Duration
	// FirstPollLatency is the time from StartedAt to the first successful
	// trajectory poll.
	FirstPollLatency time.Duration
	// FirstTextLatency is the time from StartedAt to the first emitted text.
	FirstTextLatency time.Duration
	// PollCount counts trajectory polls, including failed ones.
	PollCount int
}

// Do routes streamGenerateContent through LS, everything else through fallback.
+func (u *LSPoolUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) { + u.extractAndStripInternalHeaders(req, strconv.FormatInt(accountID, 10)) + + if !isStreamGenerate(req.URL.Path) { + return u.fallback.Do(req, proxyURL, accountID, accountConcurrency) + } + body, err := snapshotRequestBody(req) + if err != nil { + return nil, fmt.Errorf("snapshot request body: %w", err) + } + if len(bytes.TrimSpace(body)) == 0 { + return u.fallback.Do(req, proxyURL, accountID, accountConcurrency) + } + + resp, err := u.doViaLS(req, body, accountID, proxyURL) + if err != nil { + if shouldFallbackDirect(err) { + u.logger.Warn("[LS-POOL] LS fell back to direct", "account", accountID, "err", err) + req.Body = io.NopCloser(bytes.NewReader(body)) + return u.fallback.Do(req, proxyURL, accountID, accountConcurrency) + } + return nil, err + } + return resp, nil +} + +// DoWithTLS — LS handles its own TLS, so profile is ignored for LS requests. +func (u *LSPoolUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*http.Response, error) { + u.extractAndStripInternalHeaders(req, strconv.FormatInt(accountID, 10)) + + if !isStreamGenerate(req.URL.Path) { + return u.fallback.DoWithTLS(req, proxyURL, accountID, accountConcurrency, profile) + } + body, err := snapshotRequestBody(req) + if err != nil { + return nil, fmt.Errorf("snapshot request body: %w", err) + } + if len(bytes.TrimSpace(body)) == 0 { + return u.fallback.DoWithTLS(req, proxyURL, accountID, accountConcurrency, profile) + } + + resp, err := u.doViaLS(req, body, accountID, proxyURL) + if err != nil { + if shouldFallbackDirect(err) { + u.logger.Warn("[LS-POOL] LS fell back to direct+TLS", "account", accountID, "err", err) + req.Body = io.NopCloser(bytes.NewReader(body)) + return u.fallback.DoWithTLS(req, proxyURL, accountID, accountConcurrency, profile) + } + return nil, err + } + return resp, nil 
+} + +func (u *LSPoolUpstream) doViaLS(req *http.Request, body []byte, accountID int64, proxyURL string) (*http.Response, error) { + accountKey := strconv.FormatInt(accountID, 10) + + if CurrentLSStrategy() != LSStrategyJSParity { + return u.forwardDirectWithKeepalive(req, body, accountKey, accountID, proxyURL) + } + + parsed, err := parseGeminiRequest(body) + if err != nil { + return u.forwardDirect(req, body, proxyURL, accountID, "parse request failed") + } + + decision := decideJSParityRoute(parsed, body) + if !decision.UseLS { + return u.forwardDirect(req, body, proxyURL, accountID, decision.Reason) + } + + resp, err := u.forwardChatViaLS(req, body, parsed, accountKey, accountID, proxyURL) + if err != nil { + if shouldFallbackDirect(err) { + return u.forwardDirect(req, body, proxyURL, accountID, err.Error()) + } + return nil, err + } + return resp, nil +} + +func shouldFallbackDirect(err error) bool { + return errors.Is(err, errLSRouteDirect) || errors.Is(err, errLSTranscriptDrift) +} + +func (u *LSPoolUpstream) forwardDirectWithKeepalive(req *http.Request, body []byte, accountKey string, accountID int64, proxyURL string) (*http.Response, error) { + // Start/reuse LS instance — keeps heartbeat alive, authenticates with + // cloudcode-pa, and refreshes model mapping. The LS process itself is NOT + // used as a proxy; we forward the original HTTP request directly to + // cloudcode-pa, bypassing Cascade entirely. This avoids the IDE agent + // system prompt that Cascade injects. 
+ _, err := u.pool.GetOrCreate(accountKey, "", proxyURL) + if err != nil { + return nil, fmt.Errorf("get LS instance: %w", err) + } + + u.logger.Info("[LS-POOL] Forwarding via direct HTTP (LS keepalive active)", + "account", accountID, "path", req.URL.Path) + + return u.forwardDirect(req, body, proxyURL, accountID, "strategy=direct") +} + +func (u *LSPoolUpstream) forwardDirect(req *http.Request, body []byte, proxyURL string, accountID int64, reason string) (*http.Response, error) { + u.logger.Info("[LS-POOL] Forwarding via direct HTTP", + "account", accountID, + "path", req.URL.Path, + "reason", reason) + req.Header.Del(userNamespaceHeader) + req.Body = io.NopCloser(bytes.NewReader(body)) + return u.fallback.Do(req, proxyURL, accountID, 1) +} + +func (u *LSPoolUpstream) extractAndStripInternalHeaders(req *http.Request, accountKey string) { + if auth := req.Header.Get("Authorization"); strings.HasPrefix(auth, "Bearer ") { + accessToken := strings.TrimPrefix(auth, "Bearer ") + refreshToken := req.Header.Get("X-Antigravity-Refresh-Token") + var expiresAt time.Time + if raw := req.Header.Get("X-Antigravity-Token-Expiry"); raw != "" { + if parsed, err := time.Parse(time.RFC3339, raw); err == nil { + expiresAt = parsed + } + } + u.pool.SetAccountToken(accountKey, accessToken, refreshToken, expiresAt) + } + + useAICredits, hasUseAICredits := parseBoolHeader(req.Header.Get(useAICreditsHeader)) + availableCredits, hasAvailableCredits := parseOptionalInt32Header(req.Header.Get(availableCreditsHeader)) + minimumCreditAmount, hasMinimumCreditAmount := parseOptionalInt32Header(req.Header.Get(minimumCreditAmountHeader)) + if hasUseAICredits || hasAvailableCredits || hasMinimumCreditAmount { + u.pool.SetAccountModelCredits(accountKey, useAICredits, availableCredits, minimumCreditAmount) + } + + req.Header.Del("X-Antigravity-Refresh-Token") + req.Header.Del("X-Antigravity-Token-Expiry") + req.Header.Del(useAICreditsHeader) + req.Header.Del(availableCreditsHeader) + 
// parseBoolHeader interprets a header value as a boolean.
//
// The second result reports whether the trimmed value was a valid boolean
// according to strconv.ParseBool; when it is false (empty or malformed
// input), the first result is false as well.
func parseBoolHeader(raw string) (bool, bool) {
	// ParseBool rejects the empty string and returns false on any error, so
	// no separate emptiness check is needed.
	value, err := strconv.ParseBool(strings.TrimSpace(raw))
	return value, err == nil
}

// parseOptionalInt32Header interprets a header value as a base-10 int32.
//
// It returns a pointer to the parsed number and true on success, or nil and
// false when the trimmed value is empty, malformed, or outside the int32
// range.
func parseOptionalInt32Header(raw string) (*int32, bool) {
	wide, err := strconv.ParseInt(strings.TrimSpace(raw), 10, 32)
	if err != nil {
		return nil, false
	}
	narrow := int32(wide)
	return &narrow, true
}

// shortTraceID derives a short identifier for log output: the first four
// bytes of the SHA-256 of the trimmed input, as lowercase hex. Blank input
// yields "none".
func shortTraceID(raw string) string {
	if key := strings.TrimSpace(raw); key != "" {
		digest := sha256.Sum256([]byte(key))
		return fmt.Sprintf("%x", digest[:4])
	}
	return "none"
}

// durationMS converts d to whole milliseconds, reporting zero or negative
// durations as 0.
func durationMS(d time.Duration) int64 {
	return max(d, 0).Milliseconds()
}
+} + +func (u *LSPoolUpstream) forwardChatViaLS(req *http.Request, body []byte, parsed *geminiParsedRequest, accountKey string, accountID int64, proxyURL string) (*http.Response, error) { + trace := &lsRequestTrace{ + StartedAt: time.Now(), + AccountID: accountID, + Model: parsed.Model, + SessionIDHash: shortTraceID(parsed.SessionID), + TurnCount: len(parsed.Turns), + } + + getOrCreateStartedAt := time.Now() + sessionKey := buildSessionCacheKey(accountID, userNamespace(req), parsed.SessionID) + inst, err := u.pool.GetOrCreate(accountKey, sessionKey, proxyURL) + if err != nil { + trace.GetOrCreateDuration = time.Since(getOrCreateStartedAt) + u.logTraceSummary(slog.LevelWarn, "[LS-POOL] get instance failed", trace, "err", err) + return nil, fmt.Errorf("get LS instance: %w", err) + } + trace.GetOrCreateDuration = time.Since(getOrCreateStartedAt) + trace.Replica = inst.Replica + if !inst.HasModelMappingReady() { + u.logTraceSummary(slog.LevelInfo, "[LS-POOL] model mapping pending, routing direct", trace) + return nil, errLSModelMapPending + } + if !inst.AcquireConcurrency() { + u.logTraceSummary(slog.LevelWarn, "[LS-POOL] instance busy", trace, + "err", fmt.Sprintf("ls instance busy for account %d", accountID), + "current_inflight", inst.ConcurrentCount(), + "max_inflight", maxConcurrencyPerInstance) + return nil, fmt.Errorf("ls instance busy for account %d", accountID) + } + trace.InflightAtAcquire = inst.ConcurrentCount() + + state := u.getSessionState(sessionKey) + if state != nil && !systemTextCompatible(state.SystemText, parsed.SystemText) { + inst.ReleaseConcurrency() + u.logTraceSummary(slog.LevelInfo, "[LS-POOL] transcript drift, routing direct", trace) + return nil, errLSTranscriptDrift + } + + cascadeID := "" + newSession := false + sendTurn := geminiConversationTurn{} + contextPrefix := "" + + switch { + case state == nil: + if len(parsed.Turns) == 0 { + inst.ReleaseConcurrency() + return nil, errLSRouteDirect + } + lastTurn := 
parsed.Turns[len(parsed.Turns)-1] + if lastTurn.Role != "user" { + inst.ReleaseConcurrency() + u.logTraceSummary(slog.LevelInfo, "[LS-POOL] invalid first turn for LS, routing direct", trace) + return nil, errLSRouteDirect + } + sendTurn = lastTurn + contextPrefix = renderConversationContext(parsed.SystemText, parsed.Turns[:len(parsed.Turns)-1]) + startCascadeStartedAt := time.Now() + cascadeID, err = u.startCascade(inst) + trace.StartCascadeDuration = time.Since(startCascadeStartedAt) + if err != nil { + inst.ReleaseConcurrency() + u.logTraceSummary(slog.LevelWarn, "[LS-POOL] start cascade failed", trace, "err", err) + return nil, err + } + newSession = true + case !conversationPrefixEqual(parsed.Turns, state.History): + inst.ReleaseConcurrency() + u.logTraceSummary(slog.LevelInfo, "[LS-POOL] transcript prefix mismatch, routing direct", trace) + return nil, errLSTranscriptDrift + default: + delta := parsed.Turns[len(state.History):] + if len(delta) != 1 || delta[0].Role != "user" { + inst.ReleaseConcurrency() + u.logTraceSummary(slog.LevelInfo, "[LS-POOL] unsupported transcript delta, routing direct", trace) + return nil, errLSRouteDirect + } + sendTurn = delta[0] + cascadeID = state.CascadeID + } + trace.NewSession = newSession + trace.CascadeID = cascadeID + + buildInputStartedAt := time.Now() + items, media, err := buildLSInputFromTurn(sendTurn, contextPrefix) + trace.BuildInputDuration = time.Since(buildInputStartedAt) + if err != nil { + inst.ReleaseConcurrency() + u.logTraceSummary(slog.LevelWarn, "[LS-POOL] build input failed", trace, "err", err) + return nil, fmt.Errorf("build ls input: %w", err) + } + if len(items) == 0 && len(media) == 0 { + inst.ReleaseConcurrency() + u.logTraceSummary(slog.LevelInfo, "[LS-POOL] empty LS input, routing direct", trace) + return nil, errLSRouteDirect + } + sendReq := map[string]any{ + "metadata": buildLSRequestMetadata(), + "cascadeId": cascadeID, + "items": items, + "blocking": false, + } + if len(media) > 0 { + 
sendReq["media"] = media + } + if cfg := buildCascadeConfig(parsed.Model); cfg != nil { + sendReq["cascadeConfig"] = cfg + } + sendStartedAt := time.Now() + sendCtx, sendCancel := context.WithTimeout(req.Context(), lsSendMessageTimeout) + defer sendCancel() + if _, err := inst.CallUnaryJSON(sendCtx, LSService, "SendUserCascadeMessage", sendReq); err != nil { + trace.SendMessageDuration = time.Since(sendStartedAt) + if newSession { + u.cancelCascade(inst, cascadeID) + } + inst.ReleaseConcurrency() + u.logTraceSummary(slog.LevelWarn, "[LS-POOL] send user message failed", trace, "err", err) + return nil, fmt.Errorf("send user cascade message: %w", err) + } + trace.SendMessageDuration = time.Since(sendStartedAt) + + pr, pw := io.Pipe() + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"text/event-stream"}, + "Cache-Control": []string{"no-cache"}, + "X-Accel-Buffering": []string{"no"}, + }, + Body: pr, + Request: req, + } + + go func() { + defer inst.ReleaseConcurrency() + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + + u.streamCascadeResponse(ctx, inst, cascadeID, pw, trace, func(finalText string) { + u.putSessionState(sessionKey, &cascadeSessionState{ + CascadeID: cascadeID, + SystemText: parsed.SystemText, + History: appendModelTurn(cloneConversationTurns(parsed.Turns), finalText), + UpdatedAt: time.Now(), + }) + }) + }() + + return resp, nil +} + +func (u *LSPoolUpstream) startCascade(inst *Instance) (string, error) { + resp, err := inst.CallUnaryJSON(context.Background(), LSService, "StartCascade", map[string]any{ + "metadata": buildLSRequestMetadata(), + }) + if err != nil { + return "", fmt.Errorf("start cascade: %w", err) + } + var decoded struct { + CascadeID string `json:"cascadeId"` + } + if err := json.Unmarshal(resp, &decoded); err != nil { + return "", fmt.Errorf("decode start cascade: %w", err) + } + if decoded.CascadeID == "" { + return "", errors.New("start cascade returned 
empty cascadeId") + } + return decoded.CascadeID, nil +} + +// cancelCascade tells the LS to stop processing a cascade invocation. +// Uses a short timeout — best-effort, don't block shutdown. +func (u *LSPoolUpstream) cancelCascade(inst *Instance, cascadeID string) { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + _, err := inst.CallUnaryJSON(ctx, LSService, "CancelCascadeInvocation", map[string]any{ + "cascadeId": cascadeID, + }) + if err != nil { + // Try force stop as fallback + _, _ = inst.CallUnaryJSON(ctx, LSService, "ForceStopCascadeTree", map[string]any{ + "cascadeId": cascadeID, + }) + } +} + +// streamCascadeResponse polls GetCascadeTrajectory with adaptive interval. +// Fast (50ms) when model is generating, slow (150ms) when waiting. +// We also issue an immediate first poll so the first token is not delayed by +// the initial ticker interval. +func (u *LSPoolUpstream) streamCascadeResponse(ctx context.Context, inst *Instance, cascadeID string, w *io.PipeWriter, trace *lsRequestTrace, onDone func(string)) { + const ( + fastInterval = 50 * time.Millisecond + slowInterval = 150 * time.Millisecond + maxDuration = 5 * time.Minute + maxIdleTimeout = 30 * time.Second + ) + + ticker := time.NewTicker(slowInterval) + defer ticker.Stop() + + timeout := time.After(maxDuration) + lastText := "" + generating := false + lastProgressAt := time.Time{} + + pollOnce := func() bool { + if trace != nil { + trace.PollCount++ + } + trajResp, err := inst.CallUnaryJSON(ctx, LSService, "GetCascadeTrajectory", map[string]any{ + "cascadeId": cascadeID, + }) + if err != nil { + if ctx.Err() != nil { + u.logTraceSummary(slog.LevelInfo, "[LS-POOL] request context canceled", trace) + _ = w.Close() + return true + } + return false + } + if trace != nil && trace.FirstPollLatency == 0 { + trace.FirstPollLatency = time.Since(trace.StartedAt) + } + + state := extractPlannerResponseState(trajResp) + text, isGenerating, status := state.Text, 
state.Generating, state.Status + if state.ErrorMessage != "" { + u.logTraceSummary(slog.LevelWarn, "[LS-POOL] Cascade terminated with error", trace, "error", state.ErrorMessage) + if isQuotaExhaustedError(state.ErrorMessage) { + _ = w.CloseWithError(errLSQuotaExhausted) + } else { + _ = w.CloseWithError(errors.New(state.ErrorMessage)) + } + return true + } + + // Adaptive interval: fast when generating, slow when idle. + if isGenerating && !generating { + ticker.Reset(fastInterval) + generating = true + } else if !isGenerating && generating { + ticker.Reset(slowInterval) + generating = false + } + + // Emit new text as SSE. + if text != lastText && len(text) > len(lastText) { + newPart := text[len(lastText):] + sseEvent := buildGeminiSSEChunk(newPart) + if _, err := w.Write([]byte(sseEvent)); err != nil { + u.logTraceSummary(slog.LevelWarn, "[LS-POOL] write SSE failed", trace, "err", err) + _ = w.CloseWithError(err) + return true + } + lastText = text + lastProgressAt = time.Now() + if trace != nil && trace.FirstTextLatency == 0 { + trace.FirstTextLatency = time.Since(trace.StartedAt) + } + } + + // Check if done. 
// buildGeminiSSEChunk wraps a piece of generated text in one server-sent
// event.
//
// The payload follows the cloudcode-pa v1internal shape,
// {"response": {"candidates": [...]}}, with a single candidate whose content
// has role "model" and one text part. The result is "data: " + JSON + a blank
// line.
func buildGeminiSSEChunk(text string) string {
	// Struct fields are declared in alphabetical key order so the encoded
	// object lists its keys in the same order a map would.
	type textPart struct {
		Text string `json:"text"`
	}
	type modelContent struct {
		Parts []textPart `json:"parts"`
		Role  string     `json:"role"`
	}
	type candidate struct {
		Content modelContent `json:"content"`
	}
	type responseBody struct {
		Candidates []candidate `json:"candidates"`
	}
	type envelope struct {
		Response responseBody `json:"response"`
	}

	event := envelope{
		Response: responseBody{
			Candidates: []candidate{{
				Content: modelContent{
					Parts: []textPart{{Text: text}},
					Role:  "model",
				},
			}},
		},
	}
	// Marshal cannot fail for a value built solely from strings and slices.
	encoded, _ := json.Marshal(event)
	return "data: " + string(encoded) + "\n\n"
}
"finishReason": "STOP", + }, + }, + "usageMetadata": usage, + }, + } + data, _ := json.Marshal(chunk) + return "data: " + string(data) + "\n\n" +} + +// ============================================================ +// Trajectory parsing +// ============================================================ + +type cascadePlannerState struct { + Text string + Generating bool + Status string + ErrorMessage string +} + +func extractPlannerResponseState(trajResp []byte) cascadePlannerState { + var raw map[string]any + if err := json.Unmarshal(trajResp, &raw); err != nil { + return cascadePlannerState{} + } + state := cascadePlannerState{} + state.Status, _ = raw["status"].(string) + state.ErrorMessage = findCascadeErrorMessage(raw) + + traj, ok := raw["trajectory"].(map[string]any) + if !ok { + return state + } + steps, ok := traj["steps"].([]any) + if !ok { + return state + } + for _, s := range steps { + sm, ok := s.(map[string]any) + if !ok { + continue + } + if sm["type"] != "CORTEX_STEP_TYPE_PLANNER_RESPONSE" { + continue + } + if sm["status"] == "CORTEX_STEP_STATUS_GENERATING" { + state.Generating = true + } + if pr, ok := sm["plannerResponse"].(map[string]any); ok { + if r, ok := pr["response"].(string); ok { + state.Text = r + } + } + } + return state +} + +func extractPlannerResponseText(trajResp []byte) (text string, generating bool, status string) { + state := extractPlannerResponseState(trajResp) + return state.Text, state.Generating, state.Status +} + +func findCascadeErrorMessage(value any) string { + switch v := value.(type) { + case map[string]any: + if msg := summarizeCascadeErrorMap(v); msg != "" { + return msg + } + for _, child := range v { + if msg := findCascadeErrorMessage(child); msg != "" { + return msg + } + } + case []any: + for _, child := range v { + if msg := findCascadeErrorMessage(child); msg != "" { + return msg + } + } + } + return "" +} + +func summarizeCascadeErrorMap(m map[string]any) string { + if full := cascadeStringField(m, 
"fullError"); full != "" { + return full + } + if user := cascadeStringField(m, "userErrorMessage"); user != "" { + return user + } + + _, hasErrorCode := m["errorCode"] + short := cascadeStringField(m, "shortError") + details := cascadeStringField(m, "details") + message := cascadeStringField(m, "message") + + if hasErrorCode { + parts := make([]string, 0, 3) + if short != "" { + parts = append(parts, short) + } + if message != "" && message != short { + parts = append(parts, message) + } + if details != "" && details != short && details != message { + parts = append(parts, details) + } + if len(parts) > 0 { + return strings.Join(parts, ": ") + } + return fmt.Sprintf("cascade error: %v", m["errorCode"]) + } + + if reason := cascadeStringField(m, "terminationReason"); strings.Contains(strings.ToUpper(reason), "ERROR") { + if message != "" { + return message + } + if short != "" { + return short + } + } + + return "" +} + +func cascadeStringField(m map[string]any, key string) string { + raw, ok := m[key] + if !ok { + return "" + } + str, ok := raw.(string) + if !ok { + return "" + } + return strings.TrimSpace(str) +} + +func extractUsageFromTrajectory(trajResp []byte) map[string]any { + var raw map[string]any + if err := json.Unmarshal(trajResp, &raw); err != nil { + return nil + } + traj, ok := raw["trajectory"].(map[string]any) + if !ok { + return nil + } + steps, ok := traj["steps"].([]any) + if !ok { + return nil + } + for _, s := range steps { + sm, ok := s.(map[string]any) + if !ok { + continue + } + if sm["type"] != "CORTEX_STEP_TYPE_PLANNER_RESPONSE" { + continue + } + meta, ok := sm["metadata"].(map[string]any) + if !ok { + continue + } + mu, ok := meta["modelUsage"].(map[string]any) + if !ok { + continue + } + input, _ := mu["inputTokens"].(string) + output, _ := mu["outputTokens"].(string) + inputN, _ := strconv.Atoi(input) + outputN, _ := strconv.Atoi(output) + if inputN > 0 || outputN > 0 { + return map[string]any{ + "promptTokenCount": inputN, + 
"candidatesTokenCount": outputN, + "totalTokenCount": inputN + outputN, + } + } + } + return nil +} + +// ============================================================ +// Request parsing — dynamic model, no hardcoding +// ============================================================ + +func parseGeminiRequest(body []byte) (*geminiParsedRequest, error) { + var envelope geminiEnvelope + if err := json.Unmarshal(body, &envelope); err != nil { + return nil, err + } + + reqBody := body + if len(envelope.Request) > 0 { + reqBody = envelope.Request + } + + var payload geminiRequestPayload + if err := json.Unmarshal(reqBody, &payload); err != nil { + return nil, err + } + + parsed := &geminiParsedRequest{ + Model: envelope.Model, + SessionID: payload.SessionID, + ResponseModalities: append([]string(nil), payload.GenerationConfig.GetResponseModalities()...), + HasImageConfig: payload.GenerationConfig != nil && len(bytes.TrimSpace(payload.GenerationConfig.ImageConfig)) > 0 && string(bytes.TrimSpace(payload.GenerationConfig.ImageConfig)) != "null", + } + if parsed.Model == "" { + var top map[string]json.RawMessage + if err := json.Unmarshal(body, &top); err == nil { + _ = json.Unmarshal(top["model"], &parsed.Model) + } + } + if payload.SystemInstruction != nil { + parsed.SystemText = collectTextParts(payload.SystemInstruction.Parts) + } + for _, content := range payload.Contents { + turn := geminiConversationTurn{Role: normalizeTurnRole(content.Role)} + for _, part := range content.Parts { + switch { + case part.Thought || part.ThoughtSignature != "": + parsed.HasUnsupported = true + case len(part.FunctionCall) > 0 || len(part.FunctionResponse) > 0: + parsed.HasUnsupported = true + case part.InlineData != nil: + turn.Parts = append(turn.Parts, geminiConversationPart{ + Kind: "media", + MimeType: part.InlineData.MimeType, + Data: part.InlineData.Data, + }) + case part.Text != "": + turn.Parts = append(turn.Parts, geminiConversationPart{ + Kind: "text", + Text: part.Text, + }) + 
} + } + if len(turn.Parts) > 0 { + parsed.Turns = append(parsed.Turns, turn) + } + } + + return parsed, nil +} + +func collectTextParts(parts []geminiWirePart) string { + var texts []string + for _, part := range parts { + if part.Text != "" { + texts = append(texts, part.Text) + } + } + return strings.Join(texts, "\n") +} + +func (g *geminiWireGenerationConfig) GetResponseModalities() []string { + if g == nil { + return nil + } + return g.ResponseModalities +} + +func normalizeTurnRole(role string) string { + if strings.EqualFold(strings.TrimSpace(role), "model") { + return "model" + } + return "user" +} + +func decideJSParityRoute(parsed *geminiParsedRequest, body []byte) lsRouteDecision { + if parsed == nil { + return lsRouteDecision{Reason: "nil parsed request"} + } + if requestHasTools(body) { + return lsRouteDecision{Reason: "tools are not supported through cascade"} + } + if parsed.SessionID == "" { + return lsRouteDecision{Reason: "missing sessionId"} + } + if parsed.HasUnsupported { + return lsRouteDecision{Reason: "request contains unsupported Gemini parts"} + } + if isImageGenerationModelName(parsed.Model) { + return lsRouteDecision{Reason: "image generation model"} + } + if parsed.HasImageConfig { + return lsRouteDecision{Reason: "request has imageConfig"} + } + for _, modality := range parsed.ResponseModalities { + if strings.EqualFold(strings.TrimSpace(modality), "IMAGE") { + return lsRouteDecision{Reason: "responseModalities contains IMAGE"} + } + } + if len(parsed.Turns) == 0 { + return lsRouteDecision{Reason: "empty conversation"} + } + return lsRouteDecision{UseLS: true, Reason: "js-parity cascade chat"} +} + +func extractPromptAndModel(body []byte) (string, string) { + var outer map[string]json.RawMessage + if err := json.Unmarshal(body, &outer); err != nil { + return "", "" + } + var model string + if m, ok := outer["model"]; ok { + json.Unmarshal(m, &model) + } + if reqRaw, ok := outer["request"]; ok { + return 
extractPromptFromGeminiRequest(reqRaw), model + } + return extractPromptFromGeminiRequest(body), model +} + +func extractPromptFromGeminiRequest(data []byte) string { + var req struct { + Contents []struct { + Parts []struct { + Text string `json:"text"` + } `json:"parts"` + Role string `json:"role"` + } `json:"contents"` + SystemInstruction *struct { + Parts []struct { + Text string `json:"text"` + } `json:"parts"` + } `json:"systemInstruction"` + } + if err := json.Unmarshal(data, &req); err != nil { + return "" + } + + var parts []string + + // Include system instruction if present + if req.SystemInstruction != nil { + for _, p := range req.SystemInstruction.Parts { + if p.Text != "" { + parts = append(parts, "[System]\n"+p.Text) + } + } + } + + // Include full conversation history + for _, c := range req.Contents { + role := c.Role + if role == "" { + role = "user" + } + for _, p := range c.Parts { + if p.Text != "" { + if role == "model" { + parts = append(parts, "[Assistant]\n"+p.Text) + } else { + parts = append(parts, "[User]\n"+p.Text) + } + } + } + } + + if len(parts) == 0 { + return "" + } + + // If only one part and no system instruction, return raw text (simple case) + if len(parts) == 1 && req.SystemInstruction == nil { + text := parts[0] + // Strip the [User]\n prefix for simple single-message case + if strings.HasPrefix(text, "[User]\n") { + return strings.TrimPrefix(text, "[User]\n") + } + return text + } + + return strings.Join(parts, "\n\n") +} + +func buildLSInputFromTurn(turn geminiConversationTurn, contextPrefix string) ([]map[string]any, []map[string]any, error) { + items := make([]map[string]any, 0, len(turn.Parts)+1) + media := make([]map[string]any, 0) + if strings.TrimSpace(contextPrefix) != "" { + items = append(items, map[string]any{"text": contextPrefix}) + } + for _, part := range turn.Parts { + switch part.Kind { + case "text": + if part.Text != "" { + items = append(items, map[string]any{"text": part.Text}) + } + case "media": + 
decoded, err := base64.StdEncoding.DecodeString(part.Data) + if err != nil { + return nil, nil, fmt.Errorf("decode inlineData: %w", err) + } + media = append(media, map[string]any{ + "mimeType": part.MimeType, + "inlineData": decoded, + }) + } + } + return items, media, nil +} + +func renderConversationContext(systemText string, turns []geminiConversationTurn) string { + var parts []string + if strings.TrimSpace(systemText) != "" { + parts = append(parts, "[System]\n"+strings.TrimSpace(systemText)) + } + for _, turn := range turns { + var rendered []string + for _, part := range turn.Parts { + switch part.Kind { + case "text": + if strings.TrimSpace(part.Text) != "" { + rendered = append(rendered, part.Text) + } + case "media": + label := "attachment" + switch { + case strings.HasPrefix(part.MimeType, "image/"): + label = "image attachment" + case strings.HasPrefix(part.MimeType, "video/"): + label = "video attachment" + case strings.HasPrefix(part.MimeType, "audio/"): + label = "audio attachment" + } + rendered = append(rendered, fmt.Sprintf("[%s: %s]", label, part.MimeType)) + } + } + if len(rendered) == 0 { + continue + } + roleLabel := "User" + if turn.Role == "model" { + roleLabel = "Assistant" + } + parts = append(parts, fmt.Sprintf("[%s]\n%s", roleLabel, strings.Join(rendered, "\n"))) + } + return strings.Join(parts, "\n\n") +} + +func buildCascadeConfig(model string) map[string]any { + normalizedModel := normalizeRequestedModelName(model) + if normalizedModel == "" { + return nil + } + modelEnum := resolveModelEnum(normalizedModel) + + return map[string]any{ + "plannerConfig": map[string]any{ + "requestedModel": map[string]any{ + "model": modelEnum, + }, + }, + } +} + +func buildLSRequestMetadata() map[string]any { + return map[string]any{ + "ideName": "antigravity", + "ideVersion": "1.107.0", + } +} + +func appendModelTurn(turns []geminiConversationTurn, modelText string) []geminiConversationTurn { + if strings.TrimSpace(modelText) == "" { + return turns + 
} + return append(turns, geminiConversationTurn{ + Role: "model", + Parts: []geminiConversationPart{{ + Kind: "text", + Text: modelText, + }}, + }) +} + +func cloneConversationTurns(src []geminiConversationTurn) []geminiConversationTurn { + out := make([]geminiConversationTurn, 0, len(src)) + for _, turn := range src { + copied := geminiConversationTurn{ + Role: turn.Role, + Parts: append([]geminiConversationPart(nil), turn.Parts...), + } + out = append(out, copied) + } + return out +} + +func conversationPrefixEqual(full, prefix []geminiConversationTurn) bool { + if len(prefix) > len(full) { + return false + } + for i := range prefix { + if prefix[i].Role != full[i].Role { + return false + } + if len(prefix[i].Parts) != len(full[i].Parts) { + return false + } + for j := range prefix[i].Parts { + if prefix[i].Parts[j] != full[i].Parts[j] { + return false + } + } + } + return true +} + +// ResolveModelEnumPublic is the exported version of resolveModelEnum for testing. +func ResolveModelEnumPublic(model string) int { + return resolveModelEnum(model) +} + +// resolveModelEnum maps a Gemini/Claude model name to its proto enum number. +// Priority: dynamic mapping (from LS) > static fallback. +// The LS uses MODEL_PLACEHOLDER_Mn enum values (1000+n) that are dynamically +// assigned by the server — only these are guaranteed to work. +func resolveModelEnum(model string) int { + model = normalizeRequestedModelName(model) + + // 1. 
Try dynamic mapping first (populated by RefreshModelMapping from LS) + dynamicModelMapMu.RLock() + // Exact match + if v, ok := dynamicModelMap[model]; ok { + dynamicModelMapMu.RUnlock() + return v + } + // Fuzzy match: normalized label vs model name + for label, v := range dynamicModelMap { + if labelMatchesModel(label, model) { + dynamicModelMapMu.RUnlock() + return v + } + } + // Prefix match in dynamic map + for label, v := range dynamicModelMap { + normalized := normalizeLabel(label) + if strings.HasPrefix(model, normalized) || strings.HasPrefix(normalized, model) { + dynamicModelMapMu.RUnlock() + return v + } + } + dynamicModelMapMu.RUnlock() + + // 2. Known working placeholders (verified on Mac with LS v1.107.0) + // These map display labels to MODEL_PLACEHOLDER_Mn enum values + knownPlaceholders := map[string]int{ + "gemini-3-flash": 1047, + "gemini-3.1-pro-high": 1037, + "gemini-3.1-pro-low": 1036, + "claude-sonnet-4-6-thinking": 1035, + "claude-opus-4-6-thinking": 1026, + "gpt-oss-120b-medium": 342, + } + if v, ok := knownPlaceholders[model]; ok { + return v + } + // Fuzzy match known placeholders + modelLower := strings.ToLower(model) + for k, v := range knownPlaceholders { + if strings.Contains(modelLower, strings.ToLower(k)) || strings.Contains(strings.ToLower(k), modelLower) { + return v + } + } + + // 3. Family-based fallback from known placeholders + for k, v := range knownPlaceholders { + if strings.Contains(modelLower, "claude") && strings.Contains(k, "claude") { + return v + } + if strings.Contains(modelLower, "gemini") && strings.Contains(k, "gemini") { + return v + } + if strings.Contains(modelLower, "gpt") && strings.Contains(k, "gpt") { + return v + } + } + + // 4. 
Also check dynamic map if available + dynamicModelMapMu.RLock() + defer dynamicModelMapMu.RUnlock() + for label, v := range dynamicModelMap { + labelLower := strings.ToLower(normalizeLabel(label)) + // Same family: "claude" matches "claude-*", "gemini" matches "gemini-*" + if strings.Contains(modelLower, "claude") && strings.Contains(labelLower, "claude") { + return v + } + if strings.Contains(modelLower, "gemini") && strings.Contains(labelLower, "gemini") { + return v + } + if strings.Contains(modelLower, "gpt") && strings.Contains(labelLower, "gpt") { + return v + } + } + + // Last resort: return first available model from dynamic map + for _, v := range dynamicModelMap { + return v + } + + // No dynamic mapping at all (LS not started yet?) — use gemini-2.5-flash static + return 312 +} + +// labelMatchesModel does fuzzy matching between LS display label and sub2api model name. +// e.g. "Gemini 3 Flash" matches "gemini-3-flash", "Claude Sonnet 4.6 (Thinking)" matches "claude-sonnet-4-6-thinking" +func labelMatchesModel(label, model string) bool { + normalize := func(s string) string { + s = strings.ToLower(s) + s = strings.ReplaceAll(s, " ", "-") + s = strings.ReplaceAll(s, ".", "-") + s = strings.ReplaceAll(s, "(", "") + s = strings.ReplaceAll(s, ")", "") + s = strings.ReplaceAll(s, "--", "-") + return strings.TrimRight(s, "-") + } + return normalize(label) == normalize(model) +} + +// Dynamic model mapping — refreshed from LS at startup +var ( + dynamicModelMapMu sync.RWMutex + dynamicModelMap = map[string]int{} // label -> enum value +) + +// HasDynamicModelMappingPublic is exported for testing. +func HasDynamicModelMappingPublic() bool { + return hasDynamicModelMapping() +} + +// hasDynamicModelMapping returns true if at least one model has been loaded from the LS. 
+func hasDynamicModelMapping() bool { + dynamicModelMapMu.RLock() + defer dynamicModelMapMu.RUnlock() + return len(dynamicModelMap) > 0 +} + +// RefreshModelMapping queries the LS for available models and builds the mapping. +// Called automatically when an LS instance starts. +func RefreshModelMapping(inst *Instance) bool { + if inst == nil { + return false + } + startedAt := time.Now() + ctx, cancel := context.WithTimeout(context.Background(), lsModelConfigTimeout) + defer cancel() + + resp, err := inst.CallUnaryJSON(ctx, LSService, "GetCascadeModelConfigData", map[string]any{}) + if err != nil { + inst.SetModelMappingReady(false) + slog.Warn("[LS-POOL] Failed to get model config", + "account", inst.AccountID, + "replica", inst.Replica, + "address", inst.Address, + "elapsed", time.Since(startedAt).Truncate(time.Millisecond), + "err", err) + return false + } + + var data struct { + ClientModelConfigs []struct { + Label string `json:"label"` + ModelOrAlias map[string]any `json:"modelOrAlias"` + } `json:"clientModelConfigs"` + } + if err := json.Unmarshal(resp, &data); err != nil { + inst.SetModelMappingReady(false) + return false + } + + newMap := make(map[string]int) + for _, cfg := range data.ClientModelConfigs { + label := cfg.Label + if label == "" { + continue + } + // modelOrAlias is {"model": "MODEL_PLACEHOLDER_M37"} in JSON + modelStr, _ := cfg.ModelOrAlias["model"].(string) + if modelStr == "" { + continue + } + // Parse "MODEL_PLACEHOLDER_M37" → 1037 + enumVal := parseModelEnumString(modelStr) + if enumVal > 0 { + // Store both the display label and a normalized form + newMap[label] = enumVal + // Also store kebab-case version: "Gemini 3 Flash" → "gemini-3-flash" + normalized := normalizeLabel(label) + if normalized != "" { + newMap[normalized] = enumVal + } + } + } + + if len(newMap) > 0 { + dynamicModelMapMu.Lock() + dynamicModelMap = newMap + dynamicModelMapMu.Unlock() + inst.SetModelMappingReady(true) + slog.Info("[LS-POOL] Model mapping refreshed", + 
"account", inst.AccountID, + "replica", inst.Replica, + "address", inst.Address, + "count", len(newMap)/2, + "elapsed", time.Since(startedAt).Truncate(time.Millisecond)) + return true + } + inst.SetModelMappingReady(false) + return false +} + +func parseModelEnumString(s string) int { + // Named enums + named := map[string]int{ + "MODEL_CLAUDE_4_SONNET": 281, + "MODEL_CLAUDE_4_SONNET_THINKING": 282, + "MODEL_CLAUDE_4_OPUS": 290, + "MODEL_CLAUDE_4_OPUS_THINKING": 291, + "MODEL_CLAUDE_4_5_SONNET": 333, + "MODEL_CLAUDE_4_5_SONNET_THINKING": 334, + "MODEL_CLAUDE_4_5_HAIKU": 340, + "MODEL_CLAUDE_4_5_HAIKU_THINKING": 341, + "MODEL_OPENAI_GPT_OSS_120B_MEDIUM": 342, + "MODEL_GOOGLE_GEMINI_2_5_FLASH": 312, + "MODEL_GOOGLE_GEMINI_2_5_FLASH_THINKING": 313, + "MODEL_GOOGLE_GEMINI_2_5_FLASH_LITE": 330, + "MODEL_GOOGLE_GEMINI_2_5_PRO": 246, + } + if v, ok := named[s]; ok { + return v + } + // "MODEL_PLACEHOLDER_M37" → 1037 + if strings.HasPrefix(s, "MODEL_PLACEHOLDER_M") { + numStr := strings.TrimPrefix(s, "MODEL_PLACEHOLDER_M") + n, err := strconv.Atoi(numStr) + if err == nil { + return 1000 + n + } + } + return 0 +} + +func normalizeLabel(label string) string { + s := strings.ToLower(label) + s = strings.ReplaceAll(s, " ", "-") + s = strings.ReplaceAll(s, ".", "-") + s = strings.ReplaceAll(s, "(", "") + s = strings.ReplaceAll(s, ")", "") + s = strings.ReplaceAll(s, "--", "-") + return strings.TrimRight(s, "-") +} + +func normalizeRequestedModelName(model string) string { + normalized := strings.ToLower(strings.TrimSpace(model)) + normalized = strings.TrimPrefix(normalized, "models/") + return normalized +} + +func isGeminiPlannerModel(model string) bool { + return strings.Contains(normalizeRequestedModelName(model), "gemini") +} + +func systemTextCompatible(stored, current string) bool { + stored = strings.TrimSpace(stored) + current = strings.TrimSpace(current) + return current == "" || current == stored +} + +// ============================================================ 
+// Helpers +// ============================================================ + +func buildSessionCacheKey(accountID int64, namespace, sessionID string) string { + return fmt.Sprintf("%d:%s:%s", accountID, namespace, sessionID) +} + +func userNamespace(req *http.Request) string { + if req == nil { + return "anon" + } + for _, value := range []string{ + req.Header.Get(userNamespaceHeader), + req.Header.Get("X-Api-Key"), + req.Header.Get("X-Goog-Api-Key"), + req.Header.Get("Authorization"), + } { + if strings.TrimSpace(value) != "" { + sum := sha256.Sum256([]byte(value)) + return fmt.Sprintf("%x", sum[:8]) + } + } + return "anon" +} + +func (u *LSPoolUpstream) getSessionState(key string) *cascadeSessionState { + u.sessionMu.Lock() + defer u.sessionMu.Unlock() + u.pruneExpiredSessionsLocked() + state := u.sessions[key] + if state == nil { + return nil + } + cloned := &cascadeSessionState{ + CascadeID: state.CascadeID, + SystemText: state.SystemText, + History: cloneConversationTurns(state.History), + UpdatedAt: state.UpdatedAt, + } + return cloned +} + +func (u *LSPoolUpstream) putSessionState(key string, state *cascadeSessionState) { + if state == nil { + return + } + u.sessionMu.Lock() + defer u.sessionMu.Unlock() + u.pruneExpiredSessionsLocked() + u.sessions[key] = &cascadeSessionState{ + CascadeID: state.CascadeID, + SystemText: state.SystemText, + History: cloneConversationTurns(state.History), + UpdatedAt: state.UpdatedAt, + } +} + +func (u *LSPoolUpstream) pruneExpiredSessionsLocked() { + now := time.Now() + for key, state := range u.sessions { + if state == nil || now.Sub(state.UpdatedAt) > sessionStateTTL { + delete(u.sessions, key) + } + } +} + +func isStreamGenerate(path string) bool { + return strings.Contains(path, "streamGenerateContent") +} + +// isQuotaExhaustedError detects 429 QUOTA_EXHAUSTED errors from LS cascade trajectory. 
// When detected, the caller should fall back to direct HTTP so the gateway can
// inject enabledCreditTypes for AI Credits retry.
func isQuotaExhaustedError(msg string) bool {
	lowered := strings.ToLower(msg)

	// An exhaustion marker is mandatory...
	if !strings.Contains(lowered, "resource_exhausted") && !strings.Contains(lowered, "quota_exhausted") {
		return false
	}
	// ...and it must come with either the 429 status or the capacity phrase.
	return strings.Contains(lowered, "429") || strings.Contains(lowered, "exhausted your capacity")
}

// isImageGenerationModelName reports whether model names one of the known
// image-generation model families, either exactly or with a "-<suffix>"
// variant such as "-preview". Matching ignores case, surrounding whitespace
// and a leading "models/" prefix.
func isImageGenerationModelName(model string) bool {
	// Same normalization as normalizeRequestedModelName.
	name := strings.TrimPrefix(strings.ToLower(strings.TrimSpace(model)), "models/")

	families := [...]string{
		"gemini-3.1-flash-image",
		"gemini-3-pro-image",
		"gemini-2.5-flash-image",
	}
	for _, family := range families {
		if name == family || strings.HasPrefix(name, family+"-") {
			return true
		}
	}
	return false
}

// requestHasTools checks if the Gemini request body contains tools/function declarations.
// These are not supported through the Cascade path and must use direct HTTP.
+func requestHasTools(body []byte) bool { + // Check both the wrapped format {"request": {"tools": [...]}} and direct {"tools": [...]} + var outer map[string]json.RawMessage + if err := json.Unmarshal(body, &outer); err != nil { + return false + } + + // Check in wrapped request + if reqRaw, ok := outer["request"]; ok { + var inner map[string]json.RawMessage + if json.Unmarshal(reqRaw, &inner) == nil { + if tools, ok := inner["tools"]; ok && len(tools) > 4 { // > "[]" or "null" + return true + } + } + } + + // Check at top level + if tools, ok := outer["tools"]; ok && len(tools) > 4 { + return true + } + return false +} + +func snapshotRequestBody(req *http.Request) ([]byte, error) { + if req.Body == nil { + return nil, nil + } + body, err := io.ReadAll(req.Body) + if err != nil { + return nil, err + } + req.Body.Close() + req.Body = io.NopCloser(bytes.NewReader(body)) + return body, nil +} + +// unused but needed for compilation +var _ sync.Mutex diff --git a/backend/internal/pkg/lspool/worker_manager.go b/backend/internal/pkg/lspool/worker_manager.go new file mode 100644 index 00000000..3a33f085 --- /dev/null +++ b/backend/internal/pkg/lspool/worker_manager.go @@ -0,0 +1,649 @@ +package lspool + +import ( + "bytes" + "context" + "crypto/sha256" + "encoding/json" + "fmt" + "io" + "log/slog" + "net" + "net/http" + "net/url" + "os" + "strconv" + "strings" + "sync" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl" + "github.com/docker/docker/api/types/container" + "github.com/docker/docker/api/types/filters" + "github.com/docker/docker/api/types/network" + "github.com/docker/docker/client" + ocispec "github.com/opencontainers/image-spec/specs-go/v1" +) + +const ( + lsWorkerManagedByLabel = "managed-by" + lsWorkerManagedByValue = "sub2api" + lsWorkerAccountLabel = "account_id" + lsWorkerProxyHashLabel = "proxy_hash" + lsWorkerImageTagLabel = "image_tag" + lsWorkerControlPort = 18081 +) + +type 
workerManagerConfig struct { + Image string + Network string + DockerSocket string + IdleTTL time.Duration + MaxActive int + StartupTimeout time.Duration + RequestTimeout time.Duration +} + +type dockerClient interface { + ContainerList(ctx context.Context, options container.ListOptions) ([]container.Summary, error) + ContainerCreate(ctx context.Context, config *container.Config, hostConfig *container.HostConfig, networkingConfig *network.NetworkingConfig, platform *ocispec.Platform, containerName string) (container.CreateResponse, error) + ContainerStart(ctx context.Context, containerID string, options container.StartOptions) error + ContainerInspect(ctx context.Context, containerID string) (container.InspectResponse, error) + ContainerStop(ctx context.Context, containerID string, options container.StopOptions) error + ContainerRemove(ctx context.Context, containerID string, options container.RemoveOptions) error + Close() error +} + +type workerManager struct { + cfg workerManagerConfig + docker dockerClient + http *http.Client + + mu sync.Mutex + workers map[string]*workerHandle + state map[string]*workerAccountState + + ctx context.Context + cancel context.CancelFunc + logger *slog.Logger +} + +type workerHandle struct { + Key string + AccountID string + ProxyURL string + ProxyHash string + ContainerID string + Container string + Address string + AuthToken string + LastUsed time.Time + LastStateSHA string +} + +type workerAccountState struct { + HasToken bool `json:"has_token"` + AccessToken string `json:"access_token,omitempty"` + RefreshToken string `json:"refresh_token,omitempty"` + ExpiresAt *time.Time `json:"expires_at,omitempty"` + HasModelCredits bool `json:"has_model_credits"` + UseAICredits bool `json:"use_ai_credits"` + AvailableCredits *int32 `json:"available_credits,omitempty"` + MinimumCreditAmount *int32 `json:"minimum_credit_amount,omitempty"` +} + +func NewWorkerManagerFromConfig(cfg *config.Config) (Backend, error) { + if cfg == nil { + return 
nil, fmt.Errorf("config is nil") + } + + managerCfg := workerManagerConfig{ + Image: strings.TrimSpace(cfg.Gateway.AntigravityLSWorker.Image), + Network: strings.TrimSpace(cfg.Gateway.AntigravityLSWorker.Network), + DockerSocket: strings.TrimSpace(cfg.Gateway.AntigravityLSWorker.DockerSocket), + IdleTTL: cfg.Gateway.AntigravityLSWorker.IdleTTL, + MaxActive: cfg.Gateway.AntigravityLSWorker.MaxActive, + StartupTimeout: cfg.Gateway.AntigravityLSWorker.StartupTimeout, + RequestTimeout: cfg.Gateway.AntigravityLSWorker.RequestTimeout, + } + + if managerCfg.Image == "" { + managerCfg.Image = "weishaw/sub2api-lsworker:latest" + } + if managerCfg.Network == "" { + managerCfg.Network = "sub2api-network" + } + if managerCfg.DockerSocket == "" { + managerCfg.DockerSocket = "unix:///var/run/docker.sock" + } + if managerCfg.IdleTTL <= 0 { + managerCfg.IdleTTL = 15 * time.Minute + } + if managerCfg.MaxActive < 1 { + managerCfg.MaxActive = 50 + } + if managerCfg.StartupTimeout <= 0 { + managerCfg.StartupTimeout = 45 * time.Second + } + if managerCfg.RequestTimeout <= 0 { + managerCfg.RequestTimeout = 60 * time.Second + } + + opts := []client.Opt{client.WithAPIVersionNegotiation()} + if managerCfg.DockerSocket != "" { + opts = append(opts, client.WithHost(managerCfg.DockerSocket)) + } else { + opts = append(opts, client.FromEnv) + } + + dockerClient, err := client.NewClientWithOpts(opts...) 
+ if err != nil { + return nil, fmt.Errorf("create docker client: %w", err) + } + + return newWorkerManager(managerCfg, dockerClient) +} + +func newWorkerManager(cfg workerManagerConfig, docker dockerClient) (*workerManager, error) { + ctx, cancel := context.WithCancel(context.Background()) + mgr := &workerManager{ + cfg: cfg, + docker: docker, + http: &http.Client{ + Timeout: cfg.RequestTimeout, + Transport: &http.Transport{ + Proxy: nil, + DialContext: (&net.Dialer{ + Timeout: 5 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext, + MaxIdleConnsPerHost: 8, + }, + }, + workers: make(map[string]*workerHandle), + state: make(map[string]*workerAccountState), + ctx: ctx, + cancel: cancel, + logger: slog.Default().With("component", "lspool-worker-manager"), + } + if err := mgr.reconcileManagedContainers(ctx); err != nil { + cancel() + _ = docker.Close() + return nil, err + } + go mgr.cleanupLoop() + return mgr, nil +} + +func (m *workerManager) Close() { + m.cancel() + + m.mu.Lock() + workers := make([]*workerHandle, 0, len(m.workers)) + for _, handle := range m.workers { + workers = append(workers, handle) + } + m.workers = make(map[string]*workerHandle) + m.mu.Unlock() + + for _, handle := range workers { + m.removeWorkerContainer(context.Background(), handle) + } + if m.docker != nil { + _ = m.docker.Close() + } +} + +func (m *workerManager) Stats() map[string]any { + m.mu.Lock() + defer m.mu.Unlock() + return map[string]any{ + "accounts": len(m.state), + "total": len(m.workers), + "active": len(m.workers), + } +} + +func (m *workerManager) SetAccountToken(accountID, accessToken, refreshToken string, expiresAt time.Time) { + m.mu.Lock() + defer m.mu.Unlock() + state := m.ensureStateLocked(accountID) + state.HasToken = true + state.AccessToken = accessToken + state.RefreshToken = refreshToken + if expiresAt.IsZero() { + state.ExpiresAt = nil + } else { + ts := expiresAt.UTC() + state.ExpiresAt = &ts + } +} + +func (m *workerManager) 
SetAccountModelCredits(accountID string, useAICredits bool, availableCredits, minimumCreditAmountForUsage *int32) { + m.mu.Lock() + defer m.mu.Unlock() + state := m.ensureStateLocked(accountID) + state.HasModelCredits = true + state.UseAICredits = useAICredits + state.AvailableCredits = cloneInt32Ptr(availableCredits) + state.MinimumCreditAmount = cloneInt32Ptr(minimumCreditAmountForUsage) +} + +func (m *workerManager) GetOrCreate(accountID, routingKey string, proxyURL ...string) (*Instance, error) { + rawProxy := "" + if len(proxyURL) > 0 { + rawProxy = proxyURL[0] + } + normalizedProxy, parsedProxy, err := resolveWorkerProxy(rawProxy) + if err != nil { + return nil, err + } + if parsedProxy == nil { + return nil, fmt.Errorf("ls worker requires a socks5/socks5h proxy for account %s", accountID) + } + + replica := replicaSlotIndex(routingKey, parseLSReplicaCount()) + proxyHash := proxyHash(normalizedProxy) + workerKey := buildWorkerKey(accountID, proxyHash) + + m.mu.Lock() + state := cloneWorkerAccountState(m.state[accountID]) + if state == nil || !state.HasToken || strings.TrimSpace(state.AccessToken) == "" { + m.mu.Unlock() + return nil, fmt.Errorf("ls worker missing access token for account %s", accountID) + } + + handle := m.workers[workerKey] + if handle == nil { + if len(m.workers) >= m.cfg.MaxActive { + m.mu.Unlock() + return nil, fmt.Errorf("ls worker limit reached (%d active)", m.cfg.MaxActive) + } + handle, err = m.createWorkerLocked(accountID, normalizedProxy, proxyHash, parsedProxy) + if err != nil { + m.mu.Unlock() + return nil, err + } + m.workers[workerKey] = handle + } + handle.LastUsed = time.Now() + m.mu.Unlock() + + if err := m.waitForWorkerHealthy(handle); err != nil { + return nil, err + } + if err := m.syncWorkerState(handle, state); err != nil { + return nil, err + } + if err := m.waitForWorkerReady(handle, routingKey); err != nil { + return nil, err + } + + inst := &Instance{ + AccountID: accountID, + Replica: replica, + Address: 
handle.Address, + client: m.http, + healthy: true, + lastUsed: time.Now(), + modelMapReady: 1, + remote: true, + workerToken: handle.AuthToken, + routingKey: routingKey, + } + return inst, nil +} + +func (m *workerManager) cleanupLoop() { + ticker := time.NewTicker(time.Minute) + defer ticker.Stop() + + for { + select { + case <-m.ctx.Done(): + return + case <-ticker.C: + m.collectIdleWorkers() + } + } +} + +func (m *workerManager) collectIdleWorkers() { + now := time.Now() + var expired []*workerHandle + + m.mu.Lock() + for key, handle := range m.workers { + if handle == nil { + delete(m.workers, key) + continue + } + if now.Sub(handle.LastUsed) <= m.cfg.IdleTTL { + continue + } + expired = append(expired, handle) + delete(m.workers, key) + } + m.mu.Unlock() + + for _, handle := range expired { + m.removeWorkerContainer(context.Background(), handle) + } +} + +func (m *workerManager) reconcileManagedContainers(ctx context.Context) error { + args := filters.NewArgs() + args.Add("label", fmt.Sprintf("%s=%s", lsWorkerManagedByLabel, lsWorkerManagedByValue)) + + containers, err := m.docker.ContainerList(ctx, container.ListOptions{ + All: true, + Filters: args, + }) + if err != nil { + return fmt.Errorf("list managed ls workers: %w", err) + } + + for _, summary := range containers { + handle := &workerHandle{ + ContainerID: summary.ID, + Container: strings.TrimPrefix(firstContainerName(summary.Names), "/"), + } + if err := m.removeWorkerContainer(ctx, handle); err != nil { + return err + } + } + return nil +} + +func (m *workerManager) createWorkerLocked(accountID, proxyURL, proxyHash string, parsedProxy *url.URL) (*workerHandle, error) { + containerName := fmt.Sprintf("sub2api-ls-%s-%s", accountID, proxyHash[:8]) + authToken := generateUUID() + + proxyHost := parsedProxy.Hostname() + proxyPort := parsedProxy.Port() + if proxyPort == "" { + proxyPort = "1080" + } + proxyUser := parsedProxy.User.Username() + proxyPass, _ := parsedProxy.User.Password() + + labels := 
map[string]string{ + lsWorkerManagedByLabel: lsWorkerManagedByValue, + lsWorkerAccountLabel: accountID, + lsWorkerProxyHashLabel: proxyHash, + lsWorkerImageTagLabel: m.cfg.Image, + } + + env := []string{ + "ANTIGRAVITY_APP_ROOT=/app/ls", + fmt.Sprintf("LSWORKER_ACCOUNT_ID=%s", accountID), + fmt.Sprintf("LSWORKER_AUTH_TOKEN=%s", authToken), + fmt.Sprintf("LSWORKER_LISTEN_ADDR=0.0.0.0:%d", lsWorkerControlPort), + fmt.Sprintf("LSWORKER_NETWORK_READY_FILE=%s", "/run/lsworker/network-ready"), + fmt.Sprintf("LSWORKER_PROXY_URL=%s", proxyURL), + fmt.Sprintf("LSWORKER_PROXY_HOST=%s", proxyHost), + fmt.Sprintf("LSWORKER_PROXY_PORT=%s", proxyPort), + fmt.Sprintf("LSWORKER_PROXY_USER=%s", proxyUser), + fmt.Sprintf("LSWORKER_PROXY_PASS=%s", proxyPass), + fmt.Sprintf("LSWORKER_CONTROL_PORT=%d", lsWorkerControlPort), + fmt.Sprintf("ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT=%d", parseLSReplicaCount()), + } + if tz := strings.TrimSpace(os.Getenv("TZ")); tz != "" { + env = append(env, "TZ="+tz) + } + + createResp, err := m.docker.ContainerCreate(m.ctx, &container.Config{ + Image: m.cfg.Image, + Labels: labels, + Env: env, + }, &container.HostConfig{ + CapAdd: []string{"NET_ADMIN"}, + }, &network.NetworkingConfig{ + EndpointsConfig: map[string]*network.EndpointSettings{ + m.cfg.Network: {}, + }, + }, nil, containerName) + if err != nil { + return nil, fmt.Errorf("create ls worker container: %w", err) + } + + if err := m.docker.ContainerStart(m.ctx, createResp.ID, container.StartOptions{}); err != nil { + _ = m.docker.ContainerRemove(m.ctx, createResp.ID, container.RemoveOptions{Force: true}) + return nil, fmt.Errorf("start ls worker container: %w", err) + } + + inspect, err := m.docker.ContainerInspect(m.ctx, createResp.ID) + if err != nil { + _ = m.docker.ContainerRemove(m.ctx, createResp.ID, container.RemoveOptions{Force: true}) + return nil, fmt.Errorf("inspect ls worker container: %w", err) + } + + address, err := workerAddressFromInspect(inspect, m.cfg.Network) + if err != nil { + _ 
= m.docker.ContainerRemove(m.ctx, createResp.ID, container.RemoveOptions{Force: true}) + return nil, err + } + + m.logger.Info("created ls worker", + "account", shortAccountID(accountID), + "container", containerName, + "address", address, + "proxy_hash", proxyHash[:8]) + + return &workerHandle{ + Key: buildWorkerKey(accountID, proxyHash), + AccountID: accountID, + ProxyURL: proxyURL, + ProxyHash: proxyHash, + ContainerID: createResp.ID, + Container: containerName, + Address: address, + AuthToken: authToken, + LastUsed: time.Now(), + }, nil +} + +func workerAddressFromInspect(inspect container.InspectResponse, networkName string) (string, error) { + if inspect.NetworkSettings == nil { + return "", fmt.Errorf("ls worker inspect missing network settings") + } + if endpoint, ok := inspect.NetworkSettings.Networks[networkName]; ok && endpoint != nil && strings.TrimSpace(endpoint.IPAddress) != "" { + return net.JoinHostPort(strings.TrimSpace(endpoint.IPAddress), strconv.Itoa(lsWorkerControlPort)), nil + } + for _, endpoint := range inspect.NetworkSettings.Networks { + if endpoint != nil && strings.TrimSpace(endpoint.IPAddress) != "" { + return net.JoinHostPort(strings.TrimSpace(endpoint.IPAddress), strconv.Itoa(lsWorkerControlPort)), nil + } + } + return "", fmt.Errorf("ls worker missing IP address on network %s", networkName) +} + +func firstContainerName(names []string) string { + if len(names) == 0 { + return "" + } + return names[0] +} + +func (m *workerManager) waitForWorkerHealthy(handle *workerHandle) error { + ctx, cancel := context.WithTimeout(context.Background(), m.cfg.StartupTimeout) + defer cancel() + + for { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, workerURL(handle, "/healthz", nil), nil) + if err != nil { + return err + } + req.Header.Set("X-Worker-Token", handle.AuthToken) + resp, err := m.http.Do(req) + if err == nil { + _ = resp.Body.Close() + if resp.StatusCode == http.StatusOK { + return nil + } + } + + select { + case 
<-ctx.Done(): + return fmt.Errorf("worker %s failed health check: %w", handle.Container, ctx.Err()) + case <-time.After(500 * time.Millisecond): + } + } +} + +func (m *workerManager) waitForWorkerReady(handle *workerHandle, routingKey string) error { + ctx, cancel := context.WithTimeout(context.Background(), m.cfg.StartupTimeout) + defer cancel() + + values := url.Values{} + if strings.TrimSpace(routingKey) != "" { + values.Set("routing_key", routingKey) + } + + for { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, workerURL(handle, "/readyz", values), nil) + if err != nil { + return err + } + req.Header.Set("X-Worker-Token", handle.AuthToken) + resp, err := m.http.Do(req) + if err == nil { + body, _ := io.ReadAll(resp.Body) + _ = resp.Body.Close() + if resp.StatusCode == http.StatusOK { + return nil + } + if len(body) > 0 { + m.logger.Warn("ls worker not ready yet", "container", handle.Container, "status", resp.StatusCode, "body", truncate(string(body), 200)) + } + } + + select { + case <-ctx.Done(): + return fmt.Errorf("worker %s not ready for routing key %q: %w", handle.Container, routingKey, ctx.Err()) + case <-time.After(500 * time.Millisecond): + } + } +} + +func (m *workerManager) syncWorkerState(handle *workerHandle, state *workerAccountState) error { + if state == nil { + return fmt.Errorf("ls worker state is nil") + } + body, err := json.Marshal(state) + if err != nil { + return fmt.Errorf("marshal worker state: %w", err) + } + + sum := fmt.Sprintf("%x", sha256.Sum256(body)) + if handle.LastStateSHA == sum { + return nil + } + + ctx, cancel := context.WithTimeout(context.Background(), m.cfg.RequestTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, workerURL(handle, "/account/state", nil), bytes.NewReader(body)) + if err != nil { + return err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-Worker-Token", handle.AuthToken) + + resp, err := m.http.Do(req) + if err != nil { + 
return fmt.Errorf("sync worker state: %w", err) + } + defer resp.Body.Close() + respBody, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("sync worker state HTTP %d: %s", resp.StatusCode, truncate(string(respBody), 200)) + } + handle.LastStateSHA = sum + return nil +} + +func workerURL(handle *workerHandle, path string, values url.Values) string { + base := url.URL{ + Scheme: "http", + Host: handle.Address, + Path: path, + } + if values != nil { + base.RawQuery = values.Encode() + } + return base.String() +} + +func (m *workerManager) removeWorkerContainer(ctx context.Context, handle *workerHandle) error { + if handle == nil || strings.TrimSpace(handle.ContainerID) == "" { + return nil + } + timeout := 5 + _ = m.docker.ContainerStop(ctx, handle.ContainerID, container.StopOptions{Timeout: &timeout}) + if err := m.docker.ContainerRemove(ctx, handle.ContainerID, container.RemoveOptions{Force: true}); err != nil { + return fmt.Errorf("remove ls worker container %s: %w", handle.ContainerID, err) + } + return nil +} + +func (m *workerManager) ensureStateLocked(accountID string) *workerAccountState { + state := m.state[accountID] + if state == nil { + state = &workerAccountState{} + m.state[accountID] = state + } + return state +} + +func resolveWorkerProxy(proxyURL string) (string, *url.URL, error) { + resolved := resolveLSProxy(proxyURL) + normalized, parsed, err := proxyurl.Parse(resolved) + if err != nil { + return "", nil, err + } + if parsed == nil { + return "", nil, nil + } + switch strings.ToLower(parsed.Scheme) { + case "socks5", "socks5h": + return normalized, parsed, nil + default: + return "", nil, fmt.Errorf("ls worker only supports socks5/socks5h proxies, got %s", parsed.Scheme) + } +} + +func proxyHash(proxyURL string) string { + if strings.TrimSpace(proxyURL) == "" { + return "direct" + } + sum := sha256.Sum256([]byte(strings.TrimSpace(proxyURL))) + return fmt.Sprintf("%x", sum[:]) +} + +func 
buildWorkerKey(accountID, proxyHash string) string { + return accountID + ":" + proxyHash +} + +func cloneInt32Ptr(v *int32) *int32 { + if v == nil { + return nil + } + cp := *v + return &cp +} + +func cloneWorkerAccountState(state *workerAccountState) *workerAccountState { + if state == nil { + return nil + } + cp := *state + cp.AvailableCredits = cloneInt32Ptr(state.AvailableCredits) + cp.MinimumCreditAmount = cloneInt32Ptr(state.MinimumCreditAmount) + if state.ExpiresAt != nil { + ts := *state.ExpiresAt + cp.ExpiresAt = &ts + } + return &cp +} diff --git a/backend/internal/pkg/lspool/worker_manager_test.go b/backend/internal/pkg/lspool/worker_manager_test.go new file mode 100644 index 00000000..36008a15 --- /dev/null +++ b/backend/internal/pkg/lspool/worker_manager_test.go @@ -0,0 +1,266 @@ +package lspool + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" + + "github.com/docker/docker/api/types/container" + "github.com/docker/docker/api/types/filters" + "github.com/docker/docker/api/types/network" + ocispec "github.com/opencontainers/image-spec/specs-go/v1" + "github.com/stretchr/testify/require" +) + +type fakeDockerClient struct { + mu sync.Mutex + + listResp []container.Summary + listCalls int + createCalls int + startCalls int + stopCalls int + removeCalls int + inspectCalls int + removedIDs []string + createdConfigs []*container.Config + inspectResp container.InspectResponse +} + +func (f *fakeDockerClient) ContainerList(ctx context.Context, options container.ListOptions) ([]container.Summary, error) { + f.mu.Lock() + defer f.mu.Unlock() + f.listCalls++ + return append([]container.Summary(nil), f.listResp...), nil +} + +func (f *fakeDockerClient) ContainerCreate(ctx context.Context, cfg *container.Config, hostConfig *container.HostConfig, networkingConfig *network.NetworkingConfig, platform *ocispec.Platform, containerName string) (container.CreateResponse, error) { + f.mu.Lock() + 
defer f.mu.Unlock() + f.createCalls++ + f.createdConfigs = append(f.createdConfigs, cfg) + return container.CreateResponse{ID: "worker-created"}, nil +} + +func (f *fakeDockerClient) ContainerStart(ctx context.Context, containerID string, options container.StartOptions) error { + f.mu.Lock() + defer f.mu.Unlock() + f.startCalls++ + return nil +} + +func (f *fakeDockerClient) ContainerInspect(ctx context.Context, containerID string) (container.InspectResponse, error) { + f.mu.Lock() + defer f.mu.Unlock() + f.inspectCalls++ + return f.inspectResp, nil +} + +func (f *fakeDockerClient) ContainerStop(ctx context.Context, containerID string, options container.StopOptions) error { + f.mu.Lock() + defer f.mu.Unlock() + f.stopCalls++ + return nil +} + +func (f *fakeDockerClient) ContainerRemove(ctx context.Context, containerID string, options container.RemoveOptions) error { + f.mu.Lock() + defer f.mu.Unlock() + f.removeCalls++ + f.removedIDs = append(f.removedIDs, containerID) + return nil +} + +func (f *fakeDockerClient) Close() error { return nil } + +func TestResolveWorkerProxyRejectsHTTP(t *testing.T) { + _, _, err := resolveWorkerProxy("http://127.0.0.1:7890") + require.Error(t, err) + require.Contains(t, err.Error(), "only supports socks5/socks5h") +} + +func TestProxyHashUsesNormalizedProxy(t *testing.T) { + normalized, _, err := resolveWorkerProxy("socks5://user:pass@127.0.0.1:1080") + require.NoError(t, err) + require.Equal(t, "socks5h://user:pass@127.0.0.1:1080", normalized) + + hash1 := proxyHash(normalized) + hash2 := proxyHash("socks5h://user:pass@127.0.0.1:1080") + require.Equal(t, hash1, hash2) +} + +func TestWorkerManagerRequiresToken(t *testing.T) { + fakeDocker := &fakeDockerClient{} + manager, err := newWorkerManager(workerManagerConfig{ + Image: "worker:latest", + Network: "sub2api-network", + DockerSocket: "unix:///var/run/docker.sock", + IdleTTL: time.Minute, + MaxActive: 2, + StartupTimeout: time.Second, + RequestTimeout: time.Second, + }, 
fakeDocker) + require.NoError(t, err) + defer manager.Close() + + _, err = manager.GetOrCreate("9", "rk-1", "socks5h://user:pass@127.0.0.1:1080") + require.Error(t, err) + require.Contains(t, err.Error(), "missing access token") +} + +func TestWorkerManagerReusesExistingHandleAndDedupesStateSync(t *testing.T) { + var mu sync.Mutex + var healthCalls int + var readyCalls int + var stateCalls int + var stateBodies [][]byte + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/healthz": + mu.Lock() + healthCalls++ + mu.Unlock() + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + case "/readyz": + mu.Lock() + readyCalls++ + mu.Unlock() + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ready")) + case "/account/state": + body, _ := io.ReadAll(r.Body) + mu.Lock() + stateCalls++ + stateBodies = append(stateBodies, body) + mu.Unlock() + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + fakeDocker := &fakeDockerClient{} + manager, err := newWorkerManager(workerManagerConfig{ + Image: "worker:latest", + Network: "sub2api-network", + DockerSocket: "unix:///var/run/docker.sock", + IdleTTL: time.Minute, + MaxActive: 4, + StartupTimeout: time.Second, + RequestTimeout: time.Second, + }, fakeDocker) + require.NoError(t, err) + defer manager.Close() + + accountID := "9" + proxyURL := "socks5h://user:pass@127.0.0.1:1080" + hash := proxyHash(proxyURL) + key := buildWorkerKey(accountID, hash) + + manager.SetAccountToken(accountID, "ya29.test", "refresh", time.Now().Add(time.Hour)) + manager.mu.Lock() + manager.workers[key] = &workerHandle{ + Key: key, + AccountID: accountID, + ProxyURL: proxyURL, + ProxyHash: hash, + ContainerID: "existing-worker", + Container: "sub2api-ls-9-test", + Address: strings.TrimPrefix(server.URL, "http://"), + AuthToken: "worker-token", + LastUsed: time.Now(), + } + manager.mu.Unlock() + + 
inst1, err := manager.GetOrCreate(accountID, "rk-1", proxyURL) + require.NoError(t, err) + require.True(t, inst1.remote) + require.Equal(t, replicaSlotIndex("rk-1", parseLSReplicaCount()), inst1.Replica) + + inst2, err := manager.GetOrCreate(accountID, "rk-1", proxyURL) + require.NoError(t, err) + require.True(t, inst2.remote) + + mu.Lock() + defer mu.Unlock() + require.GreaterOrEqual(t, healthCalls, 2) + require.GreaterOrEqual(t, readyCalls, 2) + require.Equal(t, 1, stateCalls, "state sync should be skipped when the payload hash is unchanged") + require.Len(t, stateBodies, 1) + + var synced workerAccountState + require.NoError(t, json.Unmarshal(stateBodies[0], &synced)) + require.True(t, synced.HasToken) + require.Equal(t, "ya29.test", synced.AccessToken) +} + +func TestWorkerManagerMaxActiveStopsNewWorkerCreation(t *testing.T) { + fakeDocker := &fakeDockerClient{} + manager, err := newWorkerManager(workerManagerConfig{ + Image: "worker:latest", + Network: "sub2api-network", + DockerSocket: "unix:///var/run/docker.sock", + IdleTTL: time.Minute, + MaxActive: 1, + StartupTimeout: time.Second, + RequestTimeout: time.Second, + }, fakeDocker) + require.NoError(t, err) + defer manager.Close() + + manager.SetAccountToken("9", "ya29.test", "refresh", time.Now().Add(time.Hour)) + manager.mu.Lock() + manager.workers["existing"] = &workerHandle{ContainerID: "existing", Container: "existing", LastUsed: time.Now()} + manager.mu.Unlock() + + _, err = manager.GetOrCreate("9", "rk-new", "socks5h://user:pass@127.0.0.1:1080") + require.Error(t, err) + require.Contains(t, err.Error(), "limit reached") + require.Equal(t, 0, fakeDocker.createCalls) +} + +func TestWorkerManagerReconcileRemovesManagedContainers(t *testing.T) { + fakeDocker := &fakeDockerClient{ + listResp: []container.Summary{ + { + ID: "old-worker-1", + Names: []string{"/sub2api-ls-9-deadbeef"}, + }, + { + ID: "old-worker-2", + Names: []string{"/sub2api-ls-10-beadfeed"}, + }, + }, + } + + manager, err := 
newWorkerManager(workerManagerConfig{ + Image: "worker:latest", + Network: "sub2api-network", + DockerSocket: "unix:///var/run/docker.sock", + IdleTTL: time.Minute, + MaxActive: 4, + StartupTimeout: time.Second, + RequestTimeout: time.Second, + }, fakeDocker) + require.NoError(t, err) + defer manager.Close() + + require.Equal(t, 1, fakeDocker.listCalls) + require.ElementsMatch(t, []string{"old-worker-1", "old-worker-2"}, fakeDocker.removedIDs) +} + +func TestFakeDockerClientImplementsFilterAwareList(t *testing.T) { + fakeDocker := &fakeDockerClient{} + _, err := fakeDocker.ContainerList(context.Background(), container.ListOptions{Filters: filters.NewArgs()}) + require.NoError(t, err) +} diff --git a/backend/internal/pkg/lspool/worker_server.go b/backend/internal/pkg/lspool/worker_server.go new file mode 100644 index 00000000..9de98eab --- /dev/null +++ b/backend/internal/pkg/lspool/worker_server.go @@ -0,0 +1,368 @@ +package lspool + +import ( + "context" + "encoding/json" + "fmt" + "io" + "log/slog" + "net/http" + "os" + "strconv" + "strings" + "sync" + "time" +) + +type WorkerServerConfig struct { + AccountID string + AuthToken string + ListenAddr string + AppRoot string + NetworkReadyFile string + MaxIdleTime time.Duration + HealthInterval time.Duration +} + +type WorkerServer struct { + cfg WorkerServerConfig + pool *Pool + logger *slog.Logger + + mu sync.RWMutex + state workerAccountState +} + +func NewWorkerServer(cfg WorkerServerConfig) (*WorkerServer, error) { + if strings.TrimSpace(cfg.AccountID) == "" { + return nil, fmt.Errorf("worker account id is required") + } + if strings.TrimSpace(cfg.AuthToken) == "" { + return nil, fmt.Errorf("worker auth token is required") + } + if strings.TrimSpace(cfg.ListenAddr) == "" { + cfg.ListenAddr = fmt.Sprintf("0.0.0.0:%d", lsWorkerControlPort) + } + if strings.TrimSpace(cfg.AppRoot) == "" { + cfg.AppRoot = "/app/ls" + } + if cfg.MaxIdleTime <= 0 { + cfg.MaxIdleTime = 15 * time.Minute + } + if cfg.HealthInterval <= 0 { 
+ cfg.HealthInterval = 30 * time.Second + } + + poolCfg := DefaultConfig() + poolCfg.AppRoot = cfg.AppRoot + poolCfg.MaxIdleTime = cfg.MaxIdleTime + poolCfg.HealthCheckInterval = cfg.HealthInterval + + return &WorkerServer{ + cfg: cfg, + pool: NewPool(poolCfg), + logger: slog.Default().With("component", "lsworker"), + }, nil +} + +func NewWorkerServerFromEnv() (*WorkerServer, error) { + maxIdleTime := 15 * time.Minute + if raw := strings.TrimSpace(os.Getenv("LSWORKER_POOL_MAX_IDLE_TIME")); raw != "" { + if parsed, err := time.ParseDuration(raw); err == nil { + maxIdleTime = parsed + } + } + healthInterval := 30 * time.Second + if raw := strings.TrimSpace(os.Getenv("LSWORKER_POOL_HEALTH_INTERVAL")); raw != "" { + if parsed, err := time.ParseDuration(raw); err == nil { + healthInterval = parsed + } + } + + return NewWorkerServer(WorkerServerConfig{ + AccountID: strings.TrimSpace(os.Getenv("LSWORKER_ACCOUNT_ID")), + AuthToken: strings.TrimSpace(os.Getenv("LSWORKER_AUTH_TOKEN")), + ListenAddr: strings.TrimSpace(os.Getenv("LSWORKER_LISTEN_ADDR")), + AppRoot: strings.TrimSpace(os.Getenv("ANTIGRAVITY_APP_ROOT")), + NetworkReadyFile: strings.TrimSpace(os.Getenv("LSWORKER_NETWORK_READY_FILE")), + MaxIdleTime: maxIdleTime, + HealthInterval: healthInterval, + }) +} + +func (s *WorkerServer) Close() { + if s.pool != nil { + s.pool.Close() + } +} + +func (s *WorkerServer) Handler() http.Handler { + mux := http.NewServeMux() + mux.HandleFunc("/healthz", s.handleHealthz) + mux.HandleFunc("/readyz", s.handleReadyz) + mux.HandleFunc("/account/state", s.handleAccountState) + mux.HandleFunc("/rpc/unary", s.handleRPCUnary) + mux.HandleFunc("/rpc/stream", s.handleRPCStream) + return mux +} + +func (s *WorkerServer) handleHealthz(w http.ResponseWriter, r *http.Request) { + if !s.authorize(w, r) { + return + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) +} + +func (s *WorkerServer) handleReadyz(w http.ResponseWriter, r *http.Request) { + if !s.authorize(w, r) { + return 
+ } + routingKey := strings.TrimSpace(r.URL.Query().Get("routing_key")) + inst, err := s.ensureReady(r.Context(), routingKey) + if err != nil { + http.Error(w, err.Error(), http.StatusServiceUnavailable) + return + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(fmt.Sprintf("ready replica=%d", inst.Replica))) +} + +func (s *WorkerServer) handleAccountState(w http.ResponseWriter, r *http.Request) { + if !s.authorize(w, r) { + return + } + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + defer r.Body.Close() + var payload workerAccountState + if err := json.NewDecoder(r.Body).Decode(&payload); err != nil { + http.Error(w, "invalid account state payload", http.StatusBadRequest) + return + } + + s.mu.Lock() + s.state = *cloneWorkerAccountState(&payload) + s.mu.Unlock() + s.applyState() + + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) +} + +func (s *WorkerServer) handleRPCUnary(w http.ResponseWriter, r *http.Request) { + if !s.authorize(w, r) { + return + } + service, method, mode, routingKey, ok := parseRPCRequest(w, r) + if !ok { + return + } + + inst, err := s.ensureReady(r.Context(), routingKey) + if err != nil { + http.Error(w, err.Error(), http.StatusServiceUnavailable) + return + } + + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "read request body failed", http.StatusBadRequest) + return + } + if len(body) == 0 { + body = []byte("{}") + } + + var respBody []byte + switch mode { + case "json": + var input any + if err := json.Unmarshal(body, &input); err != nil { + http.Error(w, "invalid json rpc body", http.StatusBadRequest) + return + } + respBody, err = inst.CallUnaryJSON(r.Context(), service, method, input) + case "proto": + respBody, err = inst.CallRPC(r.Context(), service, method, body) + default: + http.Error(w, "unsupported rpc mode", http.StatusBadRequest) + return + } + if err != nil { + http.Error(w, err.Error(), http.StatusBadGateway) + 
return + } + + w.WriteHeader(http.StatusOK) + _, _ = w.Write(respBody) +} + +func (s *WorkerServer) handleRPCStream(w http.ResponseWriter, r *http.Request) { + if !s.authorize(w, r) { + return + } + service, method, mode, routingKey, ok := parseRPCRequest(w, r) + if !ok { + return + } + + inst, err := s.ensureReady(r.Context(), routingKey) + if err != nil { + http.Error(w, err.Error(), http.StatusServiceUnavailable) + return + } + + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "read request body failed", http.StatusBadRequest) + return + } + + var resp *http.Response + switch mode { + case "json": + var input any + if len(body) == 0 { + body = []byte("{}") + } + if err := json.Unmarshal(body, &input); err != nil { + http.Error(w, "invalid json rpc body", http.StatusBadRequest) + return + } + resp, err = inst.StreamRPCJSON(r.Context(), service, method, input) + case "proto": + resp, err = inst.StreamRPC(r.Context(), service, method, body) + default: + http.Error(w, "unsupported rpc mode", http.StatusBadRequest) + return + } + if err != nil { + http.Error(w, err.Error(), http.StatusBadGateway) + return + } + defer resp.Body.Close() + + for key, values := range resp.Header { + for _, value := range values { + w.Header().Add(key, value) + } + } + w.WriteHeader(resp.StatusCode) + _, _ = io.Copy(w, resp.Body) +} + +func (s *WorkerServer) authorize(w http.ResponseWriter, r *http.Request) bool { + if subtleHeaderEqual(r.Header.Get("X-Worker-Token"), s.cfg.AuthToken) { + return true + } + http.Error(w, "unauthorized", http.StatusUnauthorized) + return false +} + +func subtleHeaderEqual(left, right string) bool { + if left == "" || right == "" { + return false + } + return left == right +} + +func parseRPCRequest(w http.ResponseWriter, r *http.Request) (service, method, mode, routingKey string, ok bool) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return "", "", "", "", false + } + query := 
r.URL.Query() + service = strings.TrimSpace(query.Get("service")) + method = strings.TrimSpace(query.Get("method")) + mode = strings.ToLower(strings.TrimSpace(query.Get("mode"))) + routingKey = strings.TrimSpace(query.Get("routing_key")) + if service == "" || method == "" { + http.Error(w, "missing rpc target", http.StatusBadRequest) + return "", "", "", "", false + } + if mode == "" { + mode = "proto" + } + return service, method, mode, routingKey, true +} + +func (s *WorkerServer) ensureReady(ctx context.Context, routingKey string) (*Instance, error) { + if path := strings.TrimSpace(s.cfg.NetworkReadyFile); path != "" { + if _, err := os.Stat(path); err != nil { + return nil, fmt.Errorf("worker network not ready: %w", err) + } + } + + s.applyState() + s.mu.RLock() + state := cloneWorkerAccountState(&s.state) + s.mu.RUnlock() + if state == nil || !state.HasToken || strings.TrimSpace(state.AccessToken) == "" { + return nil, fmt.Errorf("worker access token not configured") + } + + inst, err := s.pool.GetOrCreate(s.cfg.AccountID, routingKey, "") + if err != nil { + return nil, err + } + if inst.HasModelMappingReady() { + return inst, nil + } + + modelCtx, cancel := context.WithTimeout(ctx, lsModelConfigTimeout) + defer cancel() + _ = modelCtx + if !RefreshModelMapping(inst) { + return nil, fmt.Errorf("worker model mapping not ready for replica %d", inst.Replica) + } + return inst, nil +} + +func (s *WorkerServer) applyState() { + s.mu.RLock() + state := cloneWorkerAccountState(&s.state) + s.mu.RUnlock() + if state == nil { + return + } + if state.HasToken { + expiresAt := time.Time{} + if state.ExpiresAt != nil { + expiresAt = state.ExpiresAt.UTC() + } + s.pool.SetAccountToken(s.cfg.AccountID, state.AccessToken, state.RefreshToken, expiresAt) + } + if state.HasModelCredits { + s.pool.SetAccountModelCredits(s.cfg.AccountID, state.UseAICredits, state.AvailableCredits, state.MinimumCreditAmount) + } +} + +func workerHTTPServer(listenAddr string, handler http.Handler) 
*http.Server { + return &http.Server{ + Addr: listenAddr, + Handler: handler, + ReadHeaderTimeout: 10 * time.Second, + } +} + +func workerExitCode(err error) int { + if err == nil { + return 0 + } + return 1 +} + +func parseWorkerControlPort() int { + raw := strings.TrimSpace(os.Getenv("LSWORKER_CONTROL_PORT")) + if raw == "" { + return lsWorkerControlPort + } + port, err := strconv.Atoi(raw) + if err != nil || port < 1 { + return lsWorkerControlPort + } + return port +} diff --git a/backend/internal/repository/http_upstream.go b/backend/internal/repository/http_upstream.go index 4309e997..20e0c421 100644 --- a/backend/internal/repository/http_upstream.go +++ b/backend/internal/repository/http_upstream.go @@ -13,6 +13,8 @@ import ( "strings" "sync" "sync/atomic" + + "github.com/Wei-Shaw/sub2api/internal/pkg/lspool" "time" "github.com/andybalholm/brotli" @@ -99,16 +101,34 @@ type httpUpstreamService struct { // NewHTTPUpstream 创建通用 HTTP 上游服务 // 使用配置中的连接池参数构建 Transport // +// 当环境变量 ANTIGRAVITY_LS_MODE=true 时,自动包装 LS 池拦截层: +// - 仅对已知兼容的 LS 请求形态启用转发 +// - 对普通 streamGenerateContent 请求保留原有直连路径,避免误送到不兼容的 LS RPC +// // 参数: // - cfg: 全局配置,包含连接池参数和隔离策略 // // 返回: // - service.HTTPUpstream 接口实现 func NewHTTPUpstream(cfg *config.Config) service.HTTPUpstream { - return &httpUpstreamService{ + base := &httpUpstreamService{ cfg: cfg, clients: make(map[string]*upstreamClientEntry), } + + // LS 池模式: 包装一层拦截, streamGenerateContent 走 LS + if lspool.IsLSModeEnabled() { + pool := lspool.GlobalPool(cfg) + if pool != nil { + slog.Info("LS pool mode enabled — streamGenerateContent will route through Language Server", + "component", "http_upstream") + return lspool.NewLSPoolUpstream(pool, base) + } + slog.Warn("LS pool mode enabled but pool is nil — falling back to direct mode", + "component", "http_upstream") + } + + return base } // Do 执行 HTTP 请求 diff --git a/backend/internal/service/antigravity_credits_overages.go b/backend/internal/service/antigravity_credits_overages.go index 
ec365085..00f6f414 100644 --- a/backend/internal/service/antigravity_credits_overages.go +++ b/backend/internal/service/antigravity_credits_overages.go @@ -104,6 +104,10 @@ func classifyAntigravity429(body []byte) antigravity429Category { return antigravity429QuotaExhausted } } + if strings.Contains(lowerBody, "exhausted your capacity on this model") && + strings.Contains(lowerBody, "quota will reset after") { + return antigravity429QuotaExhausted + } if info := parseAntigravitySmartRetryInfo(body); info != nil && !info.IsModelCapacityExhausted { return antigravity429RateLimited } diff --git a/backend/internal/service/antigravity_credits_overages_test.go b/backend/internal/service/antigravity_credits_overages_test.go index 7a5224da..7793f38a 100644 --- a/backend/internal/service/antigravity_credits_overages_test.go +++ b/backend/internal/service/antigravity_credits_overages_test.go @@ -21,6 +21,16 @@ func TestClassifyAntigravity429(t *testing.T) { require.Equal(t, antigravity429QuotaExhausted, classifyAntigravity429(body)) }) + t.Run("模型配额耗尽文案也视为可切 AI Credits", func(t *testing.T) { + body := []byte(`{ + "error": { + "status": "RESOURCE_EXHAUSTED", + "message": "You have exhausted your capacity on this model. Your quota will reset after 1h59m40s." 
+ } + }`) + require.Equal(t, antigravity429QuotaExhausted, classifyAntigravity429(body)) + }) + t.Run("结构化限流", func(t *testing.T) { body := []byte(`{ "error": { @@ -146,6 +156,68 @@ func TestHandleSmartRetry_QuotaExhausted_UsesCreditsAndStoresIndependentState(t require.Empty(t, repo.modelRateLimitCalls, "overages 成功后不应写入普通 model_rate_limits") } +func TestHandleSmartRetry_ModelQuotaMessage_UsesCredits(t *testing.T) { + successResp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader(`{"ok":true}`)), + } + upstream := &mockSmartRetryUpstream{ + responses: []*http.Response{successResp}, + errors: []error{nil}, + } + repo := &stubAntigravityAccountRepo{} + account := &Account{ + ID: 151, + Name: "acc-151", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Extra: map[string]any{ + "allow_overages": true, + }, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "claude-opus-4-6": "claude-opus-4-6", + }, + }, + } + + respBody := []byte(`{ + "error": { + "status": "RESOURCE_EXHAUSTED", + "message": "You have exhausted your capacity on this model. Your quota will reset after 1h59m40s." 
+ } + }`) + resp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader(respBody)), + } + params := antigravityRetryLoopParams{ + ctx: context.Background(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"model":"claude-opus-4-6","request":{}}`), + httpUpstream: upstream, + accountRepo: repo, + requestedModel: "claude-opus-4-6", + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + }, + } + + svc := &AntigravityGatewayService{} + result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, []string{"https://ag-1.test"}) + + require.NotNil(t, result) + require.Equal(t, smartRetryActionBreakWithResp, result.action) + require.NotNil(t, result.resp) + require.Len(t, upstream.requestBodies, 1) + require.Contains(t, string(upstream.requestBodies[0]), "enabledCreditTypes") +} + func TestHandleSmartRetry_RateLimited_DoesNotUseCredits(t *testing.T) { successResp := &http.Response{ StatusCode: http.StatusOK, diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index a76e59fb..10a84dca 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -112,6 +112,144 @@ func IsAntigravityAccountSwitchError(err error) (*AntigravityAccountSwitchError, return nil, false } +// injectLSPoolHeaders adds internal headers carrying OAuth credentials for the +// LS pool layer. These headers are consumed and stripped by LSPoolUpstream +// before the request reaches the Language Server. 
When LS mode is disabled, +// these headers are harmless — the direct upstream never sees them because +// they are stripped inside LSPoolUpstream.Do(). In direct mode the request +// goes straight through httpUpstreamService.Do() which doesn't inspect them. +func injectLSPoolHeaders(req *http.Request, account *Account) { + if req == nil || account == nil { + return + } + if rt, ok := account.Credentials["refresh_token"].(string); ok && rt != "" { + req.Header.Set("X-Antigravity-Refresh-Token", rt) + } + if ea, ok := account.Credentials["expires_at"].(string); ok && ea != "" { + req.Header.Set("X-Antigravity-Token-Expiry", ea) + } + req.Header.Set("X-Antigravity-Use-AI-Credits", strconv.FormatBool(account.IsOveragesEnabled())) + + availableCredits, minimumCreditAmount := resolveLSPoolModelCreditsState(account) + if availableCredits != nil { + req.Header.Set("X-Antigravity-Available-Credits", strconv.FormatInt(int64(*availableCredits), 10)) + } + if minimumCreditAmount != nil { + req.Header.Set("X-Antigravity-Minimum-Credit-Amount", strconv.FormatInt(int64(*minimumCreditAmount), 10)) + } +} + +func resolveLSPoolModelCreditsState(account *Account) (*int32, *int32) { + if account == nil || account.Extra == nil { + minimum := int32(50) + return nil, &minimum + } + + var availableCredits *int32 + var minimumCreditAmount *int32 + + collect := func(entry map[string]any) { + if entry == nil { + return + } + if !isGoogleOneAICreditsEntry(entry) { + return + } + if availableCredits == nil { + if parsed, ok := parseAICreditsInt32(firstPresent(entry, "Amount", "amount", "creditAmount")); ok { + availableCredits = &parsed + } + } + if minimumCreditAmount == nil { + if parsed, ok := parseAICreditsInt32(firstPresent(entry, "MinimumBalance", "minimum_balance", "minimumCreditAmountForUsage")); ok { + minimumCreditAmount = &parsed + } + } + } + + if rawCredits, ok := account.Extra["ai_credits"].([]any); ok { + for _, item := range rawCredits { + if entry, ok := 
item.(map[string]any); ok { + collect(entry) + } + } + } + + if loadCodeAssist, ok := account.Extra["load_code_assist"].(map[string]any); ok { + if paidTier, ok := loadCodeAssist["paidTier"].(map[string]any); ok { + if credits, ok := paidTier["availableCredits"].([]any); ok { + for _, item := range credits { + if entry, ok := item.(map[string]any); ok { + collect(entry) + } + } + } + } + } + + if minimumCreditAmount == nil { + defaultMinimum := int32(50) + minimumCreditAmount = &defaultMinimum + } + return availableCredits, minimumCreditAmount +} + +func isGoogleOneAICreditsEntry(entry map[string]any) bool { + creditType, _ := firstPresent(entry, "CreditType", "credit_type", "creditType").(string) + creditType = strings.TrimSpace(strings.ToUpper(creditType)) + return creditType == "" || creditType == "GOOGLE_ONE_AI" +} + +func firstPresent(entry map[string]any, keys ...string) any { + for _, key := range keys { + if value, ok := entry[key]; ok { + return value + } + } + return nil +} + +func parseAICreditsInt32(raw any) (int32, bool) { + switch v := raw.(type) { + case int: + return int32(v), true + case int32: + return v, true + case int64: + return int32(v), true + case float32: + return int32(v), true + case float64: + return int32(v), true + case json.Number: + parsed, err := v.Int64() + if err != nil { + floatVal, floatErr := strconv.ParseFloat(v.String(), 64) + if floatErr != nil { + return 0, false + } + return int32(floatVal), true + } + return int32(parsed), true + case string: + trimmed := strings.TrimSpace(v) + if trimmed == "" { + return 0, false + } + parsed, err := strconv.ParseInt(trimmed, 10, 32) + if err == nil { + return int32(parsed), true + } + floatVal, floatErr := strconv.ParseFloat(trimmed, 64) + if floatErr != nil { + return 0, false + } + return int32(floatVal), true + default: + return 0, false + } +} + // PromptTooLongError 表示上游明确返回 prompt too long type PromptTooLongError struct { StatusCode int @@ -305,6 +443,7 @@ func (s 
*AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam } } + injectLSPoolHeaders(retryReq, p.account) retryResp, retryErr := p.httpUpstream.Do(retryReq, p.proxyURL, p.account.ID, p.account.Concurrency) if retryErr == nil && retryResp != nil && retryResp.StatusCode != http.StatusTooManyRequests && retryResp.StatusCode != http.StatusServiceUnavailable { log.Printf("%s status=%d smart_retry_success attempt=%d/%d", p.prefix, retryResp.StatusCode, attempt, maxAttempts) @@ -489,6 +628,7 @@ func (s *AntigravityGatewayService) handleSingleAccountRetryInPlace( logger.LegacyPrintf("service.antigravity_gateway", "%s single_account_503_retry: request_build_failed error=%v", p.prefix, err) break } + injectLSPoolHeaders(retryReq, p.account) retryResp, retryErr := p.httpUpstream.Do(retryReq, p.proxyURL, p.account.ID, p.account.Concurrency) if retryErr == nil && retryResp != nil && retryResp.StatusCode != http.StatusTooManyRequests && retryResp.StatusCode != http.StatusServiceUnavailable { @@ -627,6 +767,7 @@ urlFallbackLoop: if err != nil { return nil, err } + injectLSPoolHeaders(upstreamReq, p.account) // Capture upstream request body for ops retry of this attempt. 
if p.c != nil && len(p.body) > 0 { @@ -1289,9 +1430,19 @@ func injectIdentityPatchToGeminiRequest(body []byte) ([]byte, error) { } // wrapV1InternalRequest 包装请求为 v1internal 格式 -func (s *AntigravityGatewayService) wrapV1InternalRequest(projectID, model string, originalBody []byte) ([]byte, error) { +func (s *AntigravityGatewayService) wrapV1InternalRequest(projectID, model string, originalBody []byte, preferredSessionID ...string) ([]byte, error) { + sessionID := "" + if len(preferredSessionID) > 0 { + sessionID = preferredSessionID[0] + } + + bodyWithSessionID, err := antigravity.EnsureGeminiRequestSessionID(originalBody, sessionID) + if err != nil { + return nil, fmt.Errorf("补全 sessionId 失败: %w", err) + } + var request any - if err := json.Unmarshal(originalBody, &request); err != nil { + if err := json.Unmarshal(bodyWithSessionID, &request); err != nil { return nil, fmt.Errorf("解析请求体失败: %w", err) } @@ -2156,7 +2307,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co } // 包装请求 - wrappedBody, err := s.wrapV1InternalRequest(projectID, mappedModel, injectedBody) + wrappedBody, err := s.wrapV1InternalRequest(projectID, mappedModel, injectedBody, sessionID) if err != nil { return nil, s.writeGoogleError(c, http.StatusInternalServerError, "Failed to build upstream request") } @@ -2220,10 +2371,11 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co if fallbackModel != "" && fallbackModel != mappedModel { logger.LegacyPrintf("service.antigravity_gateway", "[Antigravity] Model not found (%s), retrying with fallback model %s (account: %s)", mappedModel, fallbackModel, account.Name) - fallbackWrapped, err := s.wrapV1InternalRequest(projectID, fallbackModel, injectedBody) + fallbackWrapped, err := s.wrapV1InternalRequest(projectID, fallbackModel, injectedBody, sessionID) if err == nil { fallbackReq, err := antigravity.NewAPIRequest(ctx, upstreamAction, accessToken, fallbackWrapped) if err == nil { + 
injectLSPoolHeaders(fallbackReq, account) fallbackResp, err := s.httpUpstream.Do(fallbackReq, proxyURL, account.ID, account.Concurrency) if err == nil && fallbackResp.StatusCode < 400 { _ = resp.Body.Close() @@ -2263,7 +2415,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co logger.LegacyPrintf("service.antigravity_gateway", "Antigravity Gemini account %d: detected signature-related 400, retrying with cleaned thought signatures", account.ID) cleanedInjectedBody := CleanGeminiNativeThoughtSignatures(injectedBody) - retryWrappedBody, wrapErr := s.wrapV1InternalRequest(projectID, mappedModel, cleanedInjectedBody) + retryWrappedBody, wrapErr := s.wrapV1InternalRequest(projectID, mappedModel, cleanedInjectedBody, sessionID) if wrapErr == nil { retryResult, retryErr := s.antigravityRetryLoop(antigravityRetryLoopParams{ ctx: ctx, diff --git a/backend/internal/service/antigravity_gateway_service_test.go b/backend/internal/service/antigravity_gateway_service_test.go index 1eb1451e..8cb49d24 100644 --- a/backend/internal/service/antigravity_gateway_service_test.go +++ b/backend/internal/service/antigravity_gateway_service_test.go @@ -600,6 +600,63 @@ func TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel(t *testing require.Equal(t, mappedModel, result.UpstreamModel) } +func TestAntigravityGatewayService_ForwardGemini_InjectsSessionIDIntoWrappedRequest(t *testing.T) { + gin.SetMode(gin.TestMode) + writer := httptest.NewRecorder() + c, _ := gin.CreateTestContext(writer) + + body, err := json.Marshal(map[string]any{ + "contents": []map[string]any{ + {"role": "user", "parts": []map[string]any{{"text": "hello"}}}, + }, + }) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body)) + req.Header.Set("session_id", "session-header-1") + c.Request = req + + upstreamBody := []byte("data: 
{\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"}]},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":8,\"candidatesTokenCount\":3}}}\n\n") + upstream := &queuedHTTPUpstreamStub{ + responses: []*http.Response{ + { + StatusCode: http.StatusOK, + Header: http.Header{"X-Request-Id": []string{"req-session-1"}}, + Body: io.NopCloser(bytes.NewReader(upstreamBody)), + }, + }, + } + + svc := &AntigravityGatewayService{ + settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}), + tokenProvider: &AntigravityTokenProvider{}, + httpUpstream: upstream, + } + + account := &Account{ + ID: 16, + Name: "acc-gemini-session", + Platform: PlatformAntigravity, + Type: AccountTypeOAuth, + Status: StatusActive, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "token", + }, + } + + result, err := svc.ForwardGemini(context.Background(), c, account, "gemini-2.5-flash", "generateContent", false, body, false) + require.NoError(t, err) + require.NotNil(t, result) + require.Len(t, upstream.requestBodies, 1) + + var wrapped map[string]any + require.NoError(t, json.Unmarshal(upstream.requestBodies[0], &wrapped)) + requestNode, ok := wrapped["request"].(map[string]any) + require.True(t, ok) + require.Equal(t, "session-header-1", requestNode["sessionId"]) +} + func TestAntigravityGatewayService_ForwardGemini_RetriesCorruptedThoughtSignature(t *testing.T) { gin.SetMode(gin.TestMode) writer := httptest.NewRecorder() diff --git a/backend/resources/model-pricing/model_prices_and_context_window.json b/backend/resources/model-pricing/model_prices_and_context_window.json index 0a096257..413c3065 100644 --- a/backend/resources/model-pricing/model_prices_and_context_window.json +++ b/backend/resources/model-pricing/model_prices_and_context_window.json @@ -2500,6 +2500,57 @@ "supports_vision": true, "supports_web_search": true }, + "gemini-3-flash": { + 
"cache_read_input_token_cost": 5e-08, + "cache_read_input_token_cost_priority": 9e-08, + "input_cost_per_audio_token": 1e-06, + "input_cost_per_audio_token_priority": 1.8e-06, + "input_cost_per_token": 5e-07, + "input_cost_per_token_priority": 9e-07, + "litellm_provider": "vertex_ai-language-models", + "max_audio_length_hours": 8.4, + "max_audio_per_prompt": 1, + "max_images_per_prompt": 3000, + "max_input_tokens": 1048576, + "max_output_tokens": 65535, + "max_pdf_size_mb": 30, + "max_tokens": 65535, + "max_video_length": 1, + "max_videos_per_prompt": 10, + "mode": "chat", + "output_cost_per_reasoning_token": 3e-06, + "output_cost_per_token": 3e-06, + "output_cost_per_token_priority": 5.4e-06, + "source": "https://ai.google.dev/pricing/gemini-3", + "supported_endpoints": [ + "/v1/chat/completions", + "/v1/completions", + "/v1/batch" + ], + "supported_modalities": [ + "text", + "image", + "audio", + "video" + ], + "supported_output_modalities": [ + "text" + ], + "supports_audio_output": false, + "supports_function_calling": true, + "supports_native_streaming": true, + "supports_parallel_function_calling": true, + "supports_pdf_input": true, + "supports_prompt_caching": true, + "supports_reasoning": true, + "supports_response_schema": true, + "supports_service_tier": true, + "supports_system_messages": true, + "supports_tool_choice": true, + "supports_url_context": true, + "supports_vision": true, + "supports_web_search": true + }, "gemini-3-flash-preview": { "cache_read_input_token_cost": 5e-08, "cache_read_input_token_cost_priority": 9e-08, diff --git a/deploy/.env.example b/deploy/.env.example index e1eb8256..e0126bcb 100644 --- a/deploy/.env.example +++ b/deploy/.env.example @@ -20,6 +20,9 @@ SERVER_PORT=8080 # Server mode: release or debug SERVER_MODE=release +# Main application image override +SUB2API_IMAGE=zfc931912343/sub2api:latest + # ----------------------------------------------------------------------------- # Logging Configuration # 日志配置 @@ -389,3 
+392,32 @@ OPS_ENABLED=true # Leave empty for direct connection (recommended for overseas servers) # 留空表示直连(适用于海外服务器) UPDATE_PROXY_URL= + +# ----------------------------------------------------------------------------- +# Language Server Pool Mode (Enhanced Security) +# ----------------------------------------------------------------------------- +# Enable to route requests through real AntiGravity LS binary +# Makes upstream traffic indistinguishable from real IDE +# ANTIGRAVITY_LS_MODE=true +# LS replicas per account. Default is 5. +# Increase for higher concurrency, but each replica is an extra LS process. +# ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT=5 +# Optional global fallback proxy for accounts without dedicated LS proxy. +# Must be socks5/socks5h in worker mode. +ANTIGRAVITY_LS_PROXY= +# LS routing strategy (default js-parity) +ANTIGRAVITY_LS_STRATEGY=js-parity +# Dynamic LS worker container image. Build/pull this image before enabling LS mode. +GATEWAY_ANTIGRAVITY_LS_WORKER_IMAGE=zfc931912343/sub2api-lsworker:latest +# Docker network name shared by sub2api and dynamic ls-worker containers. +GATEWAY_ANTIGRAVITY_LS_WORKER_NETWORK=sub2api-network +# Docker socket used by sub2api to create dynamic ls-worker containers. +GATEWAY_ANTIGRAVITY_LS_WORKER_DOCKER_SOCKET=unix:///var/run/docker.sock +# Idle TTL before worker container is reaped. +GATEWAY_ANTIGRAVITY_LS_WORKER_IDLE_TTL=15m +# Maximum number of active worker containers on this node. +GATEWAY_ANTIGRAVITY_LS_WORKER_MAX_ACTIVE=50 +# Maximum time allowed for worker cold start and readiness. +GATEWAY_ANTIGRAVITY_LS_WORKER_STARTUP_TIMEOUT=45s +# Per-request timeout when sub2api talks to worker control API. 
+GATEWAY_ANTIGRAVITY_LS_WORKER_REQUEST_TIMEOUT=60s diff --git a/deploy/README.md b/deploy/README.md index dd311721..479df3f7 100644 --- a/deploy/README.md +++ b/deploy/README.md @@ -65,6 +65,20 @@ docker compose -f docker-compose.local.yml logs sub2api | grep "admin password" # http://localhost:8080 ``` +### LS Worker Image + +When `ANTIGRAVITY_LS_MODE=true`, Sub2API creates dynamic `ls-worker` +containers through the Docker socket. Build or pull the worker image before +enabling LS mode: + +```bash +cd /path/to/sub2api +docker build -f deploy/lsworker.Dockerfile -t weishaw/sub2api-lsworker:latest . +``` + +The `sub2api` container must also be able to access `/var/run/docker.sock`, +and the shared Docker network name must remain fixed at `sub2api-network`. + ### Method 2: Manual Deployment If you prefer manual control: diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index 8f60acd5..cd6e704e 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -283,6 +283,30 @@ gateway: queue: 0.7 error_rate: 0.8 ttft: 0.5 + # Antigravity LS worker container configuration + # Antigravity LS worker 容器控制平面配置 + antigravity_ls_worker: + # Worker image used by sub2api to create dynamic LS containers + # sub2api 用于创建动态 LS worker 的镜像 + image: "weishaw/sub2api-lsworker:latest" + # Docker network name shared by sub2api and workers + # sub2api 与 worker 共享的 Docker network 名称 + network: "sub2api-network" + # Docker socket path or host used by sub2api control plane + # sub2api 控制面访问的 Docker socket / host + docker_socket: "unix:///var/run/docker.sock" + # Idle TTL before a worker container is recycled + # worker 容器空闲回收时间 + idle_ttl: 15m + # Max active worker containers per node + # 单节点最大 worker 容器数量 + max_active: 50 + # Worker cold-start timeout + # worker 冷启动超时 + startup_timeout: 45s + # Timeout for control-plane calls from sub2api to worker + # sub2api 调用 worker 控制接口的超时 + request_timeout: 60s # HTTP upstream connection pool settings (HTTP/2 + 
multi-proxy scenario defaults) # HTTP 上游连接池配置(HTTP/2 + 多代理场景默认值) # Max idle connections across all hosts diff --git a/deploy/docker-compose.local.yml b/deploy/docker-compose.local.yml index 5aea78fb..902740cb 100644 --- a/deploy/docker-compose.local.yml +++ b/deploy/docker-compose.local.yml @@ -36,6 +36,7 @@ services: volumes: # Local directory mapping for easy migration - ./data:/app/data + - /var/run/docker.sock:/var/run/docker.sock # Optional: Mount custom config.yaml (uncomment and create the file first) # Copy config.example.yaml to config.yaml, modify it, then uncomment: # - ./config.yaml:/app/data/config.yaml @@ -128,6 +129,22 @@ services: - GEMINI_CLI_OAUTH_CLIENT_SECRET=${GEMINI_CLI_OAUTH_CLIENT_SECRET:-} - ANTIGRAVITY_OAUTH_CLIENT_SECRET=${ANTIGRAVITY_OAUTH_CLIENT_SECRET:-} + # ======================================================================= + # Language Server Worker Mode + # ======================================================================= + - ANTIGRAVITY_LS_MODE=${ANTIGRAVITY_LS_MODE:-false} + - ANTIGRAVITY_APP_ROOT=/app/ls + - ANTIGRAVITY_LS_PROXY=${ANTIGRAVITY_LS_PROXY:-} + - ANTIGRAVITY_LS_STRATEGY=${ANTIGRAVITY_LS_STRATEGY:-js-parity} + - ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT=${ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT:-5} + - GATEWAY_ANTIGRAVITY_LS_WORKER_IMAGE=${GATEWAY_ANTIGRAVITY_LS_WORKER_IMAGE:-weishaw/sub2api-lsworker:latest} + - GATEWAY_ANTIGRAVITY_LS_WORKER_NETWORK=${GATEWAY_ANTIGRAVITY_LS_WORKER_NETWORK:-sub2api-network} + - GATEWAY_ANTIGRAVITY_LS_WORKER_DOCKER_SOCKET=${GATEWAY_ANTIGRAVITY_LS_WORKER_DOCKER_SOCKET:-unix:///var/run/docker.sock} + - GATEWAY_ANTIGRAVITY_LS_WORKER_IDLE_TTL=${GATEWAY_ANTIGRAVITY_LS_WORKER_IDLE_TTL:-15m} + - GATEWAY_ANTIGRAVITY_LS_WORKER_MAX_ACTIVE=${GATEWAY_ANTIGRAVITY_LS_WORKER_MAX_ACTIVE:-50} + - GATEWAY_ANTIGRAVITY_LS_WORKER_STARTUP_TIMEOUT=${GATEWAY_ANTIGRAVITY_LS_WORKER_STARTUP_TIMEOUT:-45s} + - GATEWAY_ANTIGRAVITY_LS_WORKER_REQUEST_TIMEOUT=${GATEWAY_ANTIGRAVITY_LS_WORKER_REQUEST_TIMEOUT:-60s} + # 
======================================================================= # Security Configuration (URL Allowlist) # ======================================================================= @@ -230,4 +247,5 @@ services: # ============================================================================= networks: sub2api-network: + name: sub2api-network driver: bridge diff --git a/deploy/docker-compose.yml b/deploy/docker-compose.yml index a0bc1a60..bb213c76 100644 --- a/deploy/docker-compose.yml +++ b/deploy/docker-compose.yml @@ -16,7 +16,8 @@ services: # Sub2API Application # =========================================================================== sub2api: - image: weishaw/sub2api:latest + # Override with SUB2API_IMAGE to use a private registry or pinned tag. + image: ${SUB2API_IMAGE:-weishaw/sub2api:latest} container_name: sub2api restart: unless-stopped ulimits: @@ -28,6 +29,7 @@ services: volumes: # Data persistence (config.yaml will be auto-generated here) - sub2api_data:/app/data + - /var/run/docker.sock:/var/run/docker.sock # Optional: Mount custom config.yaml (uncomment and create the file first) # Copy config.example.yaml to config.yaml, modify it, then uncomment: # - ./config.yaml:/app/data/config.yaml @@ -120,6 +122,26 @@ services: - GEMINI_CLI_OAUTH_CLIENT_SECRET=${GEMINI_CLI_OAUTH_CLIENT_SECRET:-} - ANTIGRAVITY_OAUTH_CLIENT_SECRET=${ANTIGRAVITY_OAUTH_CLIENT_SECRET:-} + # ======================================================================= + # Language Server Pool Mode (Enhanced Security) + # ======================================================================= + # Enable to route requests through real LS binary (Google's own code) + # This makes upstream traffic indistinguishable from real IDE + - ANTIGRAVITY_LS_MODE=${ANTIGRAVITY_LS_MODE:-false} + - ANTIGRAVITY_APP_ROOT=/app/ls + # SOCKS5/HTTP proxy fallback used when account has no dedicated LS proxy + - ANTIGRAVITY_LS_PROXY=${ANTIGRAVITY_LS_PROXY:-} + - 
ANTIGRAVITY_LS_STRATEGY=${ANTIGRAVITY_LS_STRATEGY:-js-parity} + - ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT=${ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT:-5} + # Keep the worker image aligned with the main image release when overriding. + - GATEWAY_ANTIGRAVITY_LS_WORKER_IMAGE=${GATEWAY_ANTIGRAVITY_LS_WORKER_IMAGE:-weishaw/sub2api-lsworker:latest} + - GATEWAY_ANTIGRAVITY_LS_WORKER_NETWORK=${GATEWAY_ANTIGRAVITY_LS_WORKER_NETWORK:-sub2api-network} + - GATEWAY_ANTIGRAVITY_LS_WORKER_DOCKER_SOCKET=${GATEWAY_ANTIGRAVITY_LS_WORKER_DOCKER_SOCKET:-unix:///var/run/docker.sock} + - GATEWAY_ANTIGRAVITY_LS_WORKER_IDLE_TTL=${GATEWAY_ANTIGRAVITY_LS_WORKER_IDLE_TTL:-15m} + - GATEWAY_ANTIGRAVITY_LS_WORKER_MAX_ACTIVE=${GATEWAY_ANTIGRAVITY_LS_WORKER_MAX_ACTIVE:-50} + - GATEWAY_ANTIGRAVITY_LS_WORKER_STARTUP_TIMEOUT=${GATEWAY_ANTIGRAVITY_LS_WORKER_STARTUP_TIMEOUT:-45s} + - GATEWAY_ANTIGRAVITY_LS_WORKER_REQUEST_TIMEOUT=${GATEWAY_ANTIGRAVITY_LS_WORKER_REQUEST_TIMEOUT:-60s} + # ======================================================================= # Security Configuration (URL Allowlist) # ======================================================================= @@ -234,4 +256,5 @@ volumes: # ============================================================================= networks: sub2api-network: + name: sub2api-network driver: bridge diff --git a/deploy/docker-entrypoint.sh b/deploy/docker-entrypoint.sh index 47ab6bf1..c93c27ac 100644 --- a/deploy/docker-entrypoint.sh +++ b/deploy/docker-entrypoint.sh @@ -8,9 +8,27 @@ if [ "$(id -u)" = "0" ]; then mkdir -p /app/data # Use || true to avoid failure on read-only mounted files (e.g. 
config.yaml:ro) chown -R sub2api:sub2api /app/data 2>/dev/null || true + if [ -S /var/run/docker.sock ]; then + DOCKER_GID="$(stat -c '%g' /var/run/docker.sock 2>/dev/null || true)" + if [ -n "${DOCKER_GID}" ]; then + DOCKER_GROUP="$(getent group "${DOCKER_GID}" | cut -d: -f1 || true)" + if [ -z "${DOCKER_GROUP}" ]; then + DOCKER_GROUP="dockersock" + groupadd -for -g "${DOCKER_GID}" "${DOCKER_GROUP}" 2>/dev/null || true + fi + usermod -aG "${DOCKER_GROUP}" sub2api 2>/dev/null || true + fi + fi # Re-invoke this script as sub2api so the flag-detection below # also runs under the correct user. - exec su-exec sub2api "$0" "$@" + # Use gosu if available (Debian), fall back to su-exec (Alpine) + if command -v gosu >/dev/null 2>&1; then + exec gosu sub2api "$0" "$@" + elif command -v su-exec >/dev/null 2>&1; then + exec su-exec sub2api "$0" "$@" + else + exec su -s /bin/sh sub2api -c "exec $0 $*" + fi fi # Compatibility: if the first arg looks like a flag (e.g. --help), diff --git a/deploy/ls-bin/cert.pem b/deploy/ls-bin/cert.pem new file mode 100644 index 00000000..de9c01de --- /dev/null +++ b/deploy/ls-bin/cert.pem @@ -0,0 +1,21 @@ +-----BEGIN CERTIFICATE----- +MIIDXTCCAkWgAwIBAgIUVoRddTlTFh3+shRe6g4kSLo2n0MwDQYJKoZIhvcNAQEL +BQAwSTESMBAGA1UEAwwJbG9jYWxob3N0MRYwFAYDVQQKDA1FTkFCTEVTIEhUVFAy +MRswGQYDVQQLDBJidW5kbGVkIG9uIHB1cnBvc2UwHhcNMjUwOTA0MjA1NTA0WhcN +MjYwOTA0MjA1NTA0WjBJMRIwEAYDVQQDDAlsb2NhbGhvc3QxFjAUBgNVBAoMDUVO +QUJMRVMgSFRUUDIxGzAZBgNVBAsMEmJ1bmRsZWQgb24gcHVycG9zZTCCASIwDQYJ +KoZIhvcNAQEBBQADggEPADCCAQoCggEBAJVpU6IyIMgwB6CJHkOeEAgYtzvyH6fM +lkZSbemTrD9RCWZ4Fati1/6vbbMyWsM2XNJQMhJo0JTEoLDddN1iV/xGJCO/3dgw +4+wLqqEeck4R1pHygCkb40TycmyygSWsidkEUH0xp51nCapIdPr/WL6O+Gbpl6DA +onerUmWIO39VG2SpV7x3iXZOSbIGMsOiNZBmGwBZcL8ZejBIDjwvNjnX/d2tejH5 +/Mo4KVEXl5jsqaNbDIkhSs5BXtCMhoi1dqt75M8FyuNZd50AGFSa9Lj6pHTpwepD +k2x4h+czPcvscF7TQG31TK1VYFPUThDim+by0+LQKkpy/UGVWnbC4dsCAwEAAaM9 +MDswGgYDVR0RBBMwEYIJbG9jYWxob3N0hwR/AAABMB0GA1UdDgQWBBSonSKmHCVt 
+yBoVH1xEb3vtCng80DANBgkqhkiG9w0BAQsFAAOCAQEAinBO/uYe8ExHeiskt2P/ +Oxkd5sHSY9deLVuyX/TFnUEfktMfYKM2Juy+MfH4vfrcEhYkYJJcm25UGrtiT0Jh +bUooDkR53549Xzg/70HU/ls1eNIe0zYqmS12H5W4Q1LAWTVpePscB4dgOrps6xIk +Q4nlF7dst93E3swAe81rgCEd7VZEZy5VQcE9K+CIZXaAUJwUAsAtJbrP+5JMe9pt +q52Zq5ZVkBS+4xeaMrasN0iTgsS4Lxo2a0GFDIJ84V66oeX7a5SXfSNn7rMVIDai +KNZ2Cf2xNXUwq25Z6tjpQCqwYn3SE8b/Yi6fFZmy5D8kmY7dMh8ghVOc7rD+Vsk6 +/Q== +-----END CERTIFICATE----- diff --git a/deploy/ls-bin/language_server_linux_arm b/deploy/ls-bin/language_server_linux_arm new file mode 100755 index 00000000..5463ca16 Binary files /dev/null and b/deploy/ls-bin/language_server_linux_arm differ diff --git a/deploy/ls-bin/language_server_linux_x64 b/deploy/ls-bin/language_server_linux_x64 new file mode 100755 index 00000000..79131863 Binary files /dev/null and b/deploy/ls-bin/language_server_linux_x64 differ diff --git a/deploy/lsworker-entrypoint.sh b/deploy/lsworker-entrypoint.sh new file mode 100644 index 00000000..215058c3 --- /dev/null +++ b/deploy/lsworker-entrypoint.sh @@ -0,0 +1,70 @@ +#!/bin/sh +set -eu + +PROXY_HOST="${LSWORKER_PROXY_HOST:-}" +PROXY_PORT="${LSWORKER_PROXY_PORT:-1080}" +PROXY_USER="${LSWORKER_PROXY_USER:-}" +PROXY_PASS="${LSWORKER_PROXY_PASS:-}" +CONTROL_PORT="${LSWORKER_CONTROL_PORT:-18081}" +REDSOCKS_PORT="${LSWORKER_REDSOCKS_PORT:-12345}" +NETWORK_READY_FILE="${LSWORKER_NETWORK_READY_FILE:-/run/lsworker/network-ready}" + +mkdir -p "$(dirname "${NETWORK_READY_FILE}")" + +if [ -z "${PROXY_HOST}" ]; then + echo "LSWORKER_PROXY_HOST is required" >&2 + exit 1 +fi + +PROXY_IP="$(getent ahostsv4 "${PROXY_HOST}" | awk 'NR==1 {print $1}')" +if [ -z "${PROXY_IP}" ]; then + echo "failed to resolve proxy host: ${PROXY_HOST}" >&2 + exit 1 +fi + +cat >/tmp/redsocks.conf <>/tmp/redsocks.conf +fi +if [ -n "${PROXY_PASS}" ]; then + printf ' password = "%s";\n' "${PROXY_PASS}" >>/tmp/redsocks.conf +fi + +cat >>/tmp/redsocks.conf </tmp/redsocks.log 2>&1 & +REDSOCKS_PID="$!" 
+trap 'kill "${REDSOCKS_PID}" >/dev/null 2>&1 || true' EXIT + +sleep 1 + +iptables -t nat -N REDSOCKS 2>/dev/null || true +iptables -t nat -F REDSOCKS +iptables -t nat -A REDSOCKS -d 127.0.0.0/8 -j RETURN +iptables -t nat -A REDSOCKS -d 127.0.0.11/32 -j RETURN +iptables -t nat -A REDSOCKS -d "${PROXY_IP}/32" -j RETURN +iptables -t nat -A REDSOCKS -p tcp --dport "${CONTROL_PORT}" -j RETURN +iptables -t nat -A REDSOCKS -p tcp -j REDIRECT --to-ports "${REDSOCKS_PORT}" +iptables -t nat -D OUTPUT -p tcp -j REDSOCKS 2>/dev/null || true +iptables -t nat -A OUTPUT -p tcp -j REDSOCKS + +touch "${NETWORK_READY_FILE}" + +exec gosu sub2api /app/lsworker diff --git a/deploy/lsworker.Dockerfile b/deploy/lsworker.Dockerfile new file mode 100644 index 00000000..49178874 --- /dev/null +++ b/deploy/lsworker.Dockerfile @@ -0,0 +1,52 @@ +ARG GOLANG_IMAGE=golang:1.26.1-alpine +ARG DEBIAN_IMAGE=debian:bookworm-slim + +FROM ${GOLANG_IMAGE} AS builder + +WORKDIR /app/backend +RUN apk add --no-cache git ca-certificates tzdata + +COPY backend/go.mod backend/go.sum ./ +RUN go mod download + +COPY backend/ ./ +RUN CGO_ENABLED=0 GOOS=linux go build -trimpath -ldflags="-s -w" -o /app/lsworker ./cmd/lsworker + +FROM ${DEBIAN_IMAGE} + +RUN apt-get update && apt-get install -y --no-install-recommends \ + ca-certificates \ + curl \ + gosu \ + iproute2 \ + iptables \ + redsocks \ + tzdata \ + && rm -rf /var/lib/apt/lists/* + +RUN groupadd -g 1000 sub2api && \ + useradd -u 1000 -g sub2api -m -s /bin/sh sub2api + +WORKDIR /app + +COPY --from=builder /app/lsworker /app/lsworker +COPY deploy/ls-bin/language_server_linux_* /tmp/ls-bin/ +COPY deploy/ls-bin/cert.pem /app/ls/extensions/antigravity/dist/languageServer/ + +ARG TARGETARCH +RUN mkdir -p /app/ls/extensions/antigravity/bin /run/lsworker && \ + if [ "${TARGETARCH:-amd64}" = "arm64" ]; then \ + cp /tmp/ls-bin/language_server_linux_arm /app/ls/extensions/antigravity/bin/language_server_linux_arm; \ + else \ + cp /tmp/ls-bin/language_server_linux_x64 
/app/ls/extensions/antigravity/bin/language_server_linux_x64; \ + fi && \ + chmod +x /app/lsworker /app/ls/extensions/antigravity/bin/language_server_linux_* && \ + chown -R sub2api:sub2api /app /run/lsworker && \ + rm -rf /tmp/ls-bin + +COPY deploy/lsworker-entrypoint.sh /app/lsworker-entrypoint.sh +RUN chmod +x /app/lsworker-entrypoint.sh + +EXPOSE 18081 + +ENTRYPOINT ["/app/lsworker-entrypoint.sh"] diff --git a/frontend/src/components/account/AccountTestModal.vue b/frontend/src/components/account/AccountTestModal.vue index e731a7b1..dcaef15d 100644 --- a/frontend/src/components/account/AccountTestModal.vue +++ b/frontend/src/components/account/AccountTestModal.vue @@ -510,6 +510,7 @@ const handleEvent = (event: { addLine(streamingContent.value, 'text-green-300') streamingContent.value = '' } + addLine(`Error: ${errorMessage.value}`, 'text-red-400') break } } diff --git a/frontend/src/components/admin/account/AccountTestModal.vue b/frontend/src/components/admin/account/AccountTestModal.vue index e731a7b1..dcaef15d 100644 --- a/frontend/src/components/admin/account/AccountTestModal.vue +++ b/frontend/src/components/admin/account/AccountTestModal.vue @@ -510,6 +510,7 @@ const handleEvent = (event: { addLine(streamingContent.value, 'text-green-300') streamingContent.value = '' } + addLine(`Error: ${errorMessage.value}`, 'text-red-400') break } } diff --git a/frontend/src/components/admin/account/__tests__/AccountTestModal.spec.ts b/frontend/src/components/admin/account/__tests__/AccountTestModal.spec.ts index 429a905c..801eab02 100644 --- a/frontend/src/components/admin/account/__tests__/AccountTestModal.spec.ts +++ b/frontend/src/components/admin/account/__tests__/AccountTestModal.spec.ts @@ -144,4 +144,28 @@ describe('AccountTestModal', () => { expect(preview.exists()).toBe(true) expect(preview.attributes('src')).toBe('data:image/png;base64,QUJD') }) + + it('收到 error 事件时会把错误内容显示在终端输出里', async () => { + ;(global.fetch as any).mockResolvedValueOnce( + 
createStreamResponse([ + 'data: {"type":"test_start","model":"claude-opus-4-6"}\n', + 'data: {"type":"error","error":"API returned 429: You have exhausted your capacity on this model."}\n' + ]) + ) + + const wrapper = mountModal() + await wrapper.setProps({ show: true }) + await flushPromises() + + const buttons = wrapper.findAll('button') + const startButton = buttons.find((button) => button.text().includes('admin.accounts.startTest')) + expect(startButton).toBeTruthy() + + await startButton!.trigger('click') + await flushPromises() + await flushPromises() + + expect(wrapper.text()).toContain('API returned 429: You have exhausted your capacity on this model.') + expect(wrapper.text()).toContain('Error: API returned 429: You have exhausted your capacity on this model.') + }) })