feat: add dockerized antigravity ls worker mode

This commit is contained in:
win 2026-03-30 23:57:25 +08:00
parent 648e617f4e
commit 6694dcad14
42 changed files with 7753 additions and 28 deletions

View File

@ -9,6 +9,7 @@
ARG NODE_IMAGE=node:24-alpine
ARG GOLANG_IMAGE=golang:1.26.1-alpine
ARG ALPINE_IMAGE=alpine:3.21
ARG DEBIAN_IMAGE=debian:bookworm-slim
ARG POSTGRES_IMAGE=postgres:18-alpine
ARG GOPROXY=https://goproxy.cn,direct
ARG GOSUMDB=sum.golang.google.cn
@ -63,10 +64,12 @@ COPY --from=frontend-builder /app/backend/internal/web/dist ./internal/web/dist
# Build the binary (BuildType=release for CI builds, embed frontend)
# Version precedence: build arg VERSION > cmd/server/VERSION
ARG TARGETARCH
ARG TARGETOS=linux
RUN VERSION_VALUE="${VERSION}" && \
if [ -z "${VERSION_VALUE}" ]; then VERSION_VALUE="$(tr -d '\r\n' < ./cmd/server/VERSION)"; fi && \
DATE_VALUE="${DATE:-$(date -u +%Y-%m-%dT%H:%M:%SZ)}" && \
CGO_ENABLED=0 GOOS=linux go build \
CGO_ENABLED=0 GOOS=${TARGETOS} GOARCH=${TARGETARCH} go build \
-tags embed \
-ldflags="-s -w -X main.Version=${VERSION_VALUE} -X main.Commit=${COMMIT} -X main.Date=${DATE_VALUE} -X main.BuildType=release" \
-trimpath \
@ -79,9 +82,9 @@ RUN VERSION_VALUE="${VERSION}" && \
FROM ${POSTGRES_IMAGE} AS pg-client
# -----------------------------------------------------------------------------
# Stage 4: Final Runtime Image
# Stage 4: Final Runtime Image (Debian for glibc — LS binary requires it)
# -----------------------------------------------------------------------------
FROM ${ALPINE_IMAGE}
FROM ${DEBIAN_IMAGE}
# Labels
LABEL maintainer="Wei-Shaw <github.com/Wei-Shaw>"
@ -89,27 +92,25 @@ LABEL description="Sub2API - AI API Gateway Platform"
LABEL org.opencontainers.image.source="https://github.com/Wei-Shaw/sub2api"
# Install runtime dependencies
RUN apk add --no-cache \
RUN apt-get update && apt-get install -y --no-install-recommends \
ca-certificates \
curl \
wget \
gosu \
proxychains4 \
tzdata \
su-exec \
libpq \
zstd-libs \
lz4-libs \
krb5-libs \
libldap \
libedit \
&& rm -rf /var/cache/apk/*
libpq5 \
&& rm -rf /var/lib/apt/lists/*
# Copy pg_dump and psql from the same postgres image used in docker-compose
# This ensures version consistency between backup tools and the database server
COPY --from=pg-client /usr/local/bin/pg_dump /usr/local/bin/pg_dump
COPY --from=pg-client /usr/local/bin/psql /usr/local/bin/psql
COPY --from=pg-client /usr/local/lib/libpq.so.5* /usr/local/lib/
RUN ldconfig
# Create non-root user
RUN addgroup -g 1000 sub2api && \
adduser -u 1000 -G sub2api -s /bin/sh -D sub2api
RUN groupadd -g 1000 sub2api && \
useradd -u 1000 -g sub2api -m -s /bin/sh sub2api
# Set working directory
WORKDIR /app
@ -118,6 +119,21 @@ WORKDIR /app
COPY --from=backend-builder --chown=sub2api:sub2api /app/sub2api /app/sub2api
COPY --from=backend-builder --chown=sub2api:sub2api /app/backend/resources /app/resources
# Install the Language Server binary and its certificate (used by LS pool mode).
# Enable at runtime with: ANTIGRAVITY_LS_MODE=true ANTIGRAVITY_APP_ROOT=/app/ls
# TARGETARCH is set automatically by buildx (amd64 or arm64); it is re-declared
# here because an ARG declared in an earlier stage is not visible in this one.
ARG TARGETARCH
# Stage every language_server_linux_* binary in /tmp; the RUN step below keeps
# only the one that matches TARGETARCH.
COPY --chown=sub2api:sub2api deploy/ls-bin/language_server_linux_* /tmp/ls-bin/
# The certificate is placed inside the app-root layout rooted at /app/ls.
COPY --chown=sub2api:sub2api deploy/ls-bin/cert.pem /app/ls/extensions/antigravity/dist/languageServer/
# arm64 builds install language_server_linux_arm; every other TARGETARCH value
# falls back to language_server_linux_x64.
# NOTE(review): /tmp/ls-bin is deleted in a later layer than the COPY that
# created it, so both binaries probably still count toward image size — confirm.
# NOTE(review): cp runs as the build user here, so the installed binary likely
# does not keep the sub2api ownership set by COPY --chown — confirm if it matters.
RUN mkdir -p /app/ls/extensions/antigravity/bin && \
if [ "$TARGETARCH" = "arm64" ]; then \
cp /tmp/ls-bin/language_server_linux_arm /app/ls/extensions/antigravity/bin/language_server_linux_arm; \
else \
cp /tmp/ls-bin/language_server_linux_x64 /app/ls/extensions/antigravity/bin/language_server_linux_x64; \
fi && \
chmod +x /app/ls/extensions/antigravity/bin/language_server_linux_* && \
rm -rf /tmp/ls-bin
# Create data directory
RUN mkdir -p /app/data && chown sub2api:sub2api /app/data

View File

@ -0,0 +1,49 @@
package main
import (
	"context"
	"errors"
	"log/slog"
	"net/http"
	"os"
	"os/signal"
	"syscall"
	"time"

	"github.com/Wei-Shaw/sub2api/internal/pkg/lspool"
)
// main starts the LS worker HTTP server and blocks until it stops.
//
// The listen address is read from LSWORKER_LISTEN_ADDR (default
// "0.0.0.0:18081"). On SIGINT or SIGTERM the server is shut down gracefully.
// ListenAndServe returns http.ErrServerClosed as soon as Shutdown is called,
// not when it completes, so main waits for the shutdown goroutine before
// returning. The wait is bounded by shutdownTimeout so a long-lived stream
// cannot keep the process alive forever.
//
// The process exits with status 1 when the worker cannot be initialized or
// the listener fails.
func main() {
	// Upper bound on draining in-flight requests after a termination signal.
	const shutdownTimeout = 15 * time.Second

	server, err := lspool.NewWorkerServerFromEnv()
	if err != nil {
		slog.Error("failed to initialize lsworker", "err", err)
		os.Exit(1)
	}
	defer server.Close()

	httpServer := &http.Server{
		Addr:              envOrDefault("LSWORKER_LISTEN_ADDR", "0.0.0.0:18081"),
		Handler:           server.Handler(),
		ReadHeaderTimeout: 10 * time.Second,
	}

	ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
	defer stop()

	shutdownDone := make(chan struct{})
	go func() {
		defer close(shutdownDone)
		<-ctx.Done()

		shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
		defer cancel()
		if err := httpServer.Shutdown(shutdownCtx); err != nil {
			slog.Warn("lsworker graceful shutdown incomplete", "err", err)
		}
	}()

	slog.Info("lsworker listening", "addr", httpServer.Addr)
	if err := httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
		slog.Error("lsworker exited with error", "err", err)
		// os.Exit skips deferred calls, so release the worker explicitly.
		server.Close()
		os.Exit(1)
	}

	// ErrServerClosed can only come from the Shutdown call above; wait for it
	// to finish draining before the deferred cleanup runs.
	<-shutdownDone
}
// envOrDefault returns the value of the environment variable key, or fallback
// when that variable is unset or set to the empty string.
func envOrDefault(key, fallback string) string {
	value := os.Getenv(key)
	if value == "" {
		return fallback
	}
	return value
}

View File

@ -379,6 +379,8 @@ type GatewayConfig struct {
OpenAIPassthroughAllowTimeoutHeaders bool `mapstructure:"openai_passthrough_allow_timeout_headers"`
// OpenAIWS: OpenAI Responses WebSocket 配置(默认开启,可按需回滚到 HTTP
OpenAIWS GatewayOpenAIWSConfig `mapstructure:"openai_ws"`
// AntigravityLSWorker: LS worker 容器控制平面配置
AntigravityLSWorker GatewayAntigravityLSWorkerConfig `mapstructure:"antigravity_ls_worker"`
// HTTP 上游连接池配置(性能优化:支持高并发场景调优)
// MaxIdleConns: 所有主机的最大空闲连接总数
@ -481,6 +483,16 @@ type GatewayConfig struct {
UserMessageQueue UserMessageQueueConfig `mapstructure:"user_message_queue"`
}
// GatewayAntigravityLSWorkerConfig holds the settings of the Antigravity LS
// worker container control plane (config key gateway.antigravity_ls_worker).
// The defaults quoted below are the ones registered in setDefaults.
type GatewayAntigravityLSWorkerConfig struct {
// Image is the worker container image (default "weishaw/sub2api-lsworker:latest").
Image string `mapstructure:"image"`
// Network is the container network name (default "sub2api-network").
// NOTE(review): presumably the Docker network workers are attached to — confirm.
Network string `mapstructure:"network"`
// DockerSocket is the Docker endpoint (default "unix:///var/run/docker.sock").
DockerSocket string `mapstructure:"docker_socket"`
// IdleTTL defaults to 15 minutes.
// NOTE(review): presumably how long an idle worker is kept before removal — confirm.
IdleTTL time.Duration `mapstructure:"idle_ttl"`
// MaxActive defaults to 50.
// NOTE(review): presumably the cap on concurrently running workers — confirm.
MaxActive int `mapstructure:"max_active"`
// StartupTimeout defaults to 45 seconds.
// NOTE(review): presumably the wait for a new worker to become ready — confirm.
StartupTimeout time.Duration `mapstructure:"startup_timeout"`
// RequestTimeout defaults to 60 seconds.
// NOTE(review): presumably the per-request timeout towards a worker — confirm.
RequestTimeout time.Duration `mapstructure:"request_timeout"`
}
// UserMessageQueueConfig 用户消息串行队列配置
// 用于 Anthropic OAuth/SetupToken 账号的用户消息串行化发送
type UserMessageQueueConfig struct {
@ -1307,6 +1319,15 @@ func setDefaults() {
// RateLimit
viper.SetDefault("rate_limit.overload_cooldown_minutes", 10)
// Gateway LS worker
viper.SetDefault("gateway.antigravity_ls_worker.image", "weishaw/sub2api-lsworker:latest")
viper.SetDefault("gateway.antigravity_ls_worker.network", "sub2api-network")
viper.SetDefault("gateway.antigravity_ls_worker.docker_socket", "unix:///var/run/docker.sock")
viper.SetDefault("gateway.antigravity_ls_worker.idle_ttl", 15*time.Minute)
viper.SetDefault("gateway.antigravity_ls_worker.max_active", 50)
viper.SetDefault("gateway.antigravity_ls_worker.startup_timeout", 45*time.Second)
viper.SetDefault("gateway.antigravity_ls_worker.request_timeout", 60*time.Second)
viper.SetDefault("rate_limit.oauth_401_cooldown_minutes", 10)
// Pricing - 从 model-price-repo 同步模型定价和上下文窗口数据(固定到 commit避免分支漂移

View File

@ -440,7 +440,7 @@ func (c *Client) GetUserInfo(ctx context.Context, accessToken string) (*UserInfo
func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadCodeAssistResponse, map[string]any, error) {
reqBody := LoadCodeAssistRequest{}
reqBody.Metadata.IDEType = "ANTIGRAVITY"
reqBody.Metadata.IDEVersion = "1.20.6"
reqBody.Metadata.IDEVersion = "1.107.0"
reqBody.Metadata.IDEName = "antigravity"
bodyBytes, err := json.Marshal(reqBody)

View File

@ -9,6 +9,7 @@ import (
"net/http"
"net/url"
"os"
"runtime"
"strings"
"sync"
"time"
@ -49,8 +50,8 @@ const (
antigravityDailyBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com"
)
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.20.5
var defaultUserAgentVersion = "1.20.5"
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.107.0
var defaultUserAgentVersion = "1.107.0"
// defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置
var defaultClientSecret = "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"
@ -66,9 +67,9 @@ func init() {
}
}
// GetUserAgent 返回当前配置的 User-Agent
// GetUserAgent 返回当前配置的 User-Agent(自动检测平台,匹配真实 IDE 行为)
func GetUserAgent() string {
return fmt.Sprintf("antigravity/%s windows/amd64", defaultUserAgentVersion)
return fmt.Sprintf("antigravity/%s %s/%s", defaultUserAgentVersion, runtime.GOOS, runtime.GOARCH)
}
func getClientSecret() (string, error) {

View File

@ -39,6 +39,34 @@ func generateStableSessionID(contents []GeminiContent) string {
return "-" + strconv.FormatInt(n, 10)
}
// EnsureGeminiRequestSessionID fills the top-level "sessionId" field of a
// Gemini request body when the caller omitted it.
//
// Resolution order:
//  1. A non-blank string "sessionId" already present in body wins and body is
//     returned untouched.
//  2. Otherwise preferredSessionID (trimmed) is used when it is non-blank.
//  3. Otherwise a stable value is derived from the request contents via
//     generateStableSessionID.
//
// When no session ID can be determined, or when body is the JSON literal
// null (there is no object to add the field to), body is returned unchanged.
//
// All other top-level fields are carried over as raw JSON, so their values
// (including integers too large for float64) are not re-interpreted.
//
// An error is returned when body is not valid JSON, when its top level is
// neither an object nor null, or when it cannot be decoded as a GeminiRequest.
func EnsureGeminiRequestSessionID(body []byte, preferredSessionID string) ([]byte, error) {
	// Decode only the top level and keep every value as raw JSON. Decoding
	// into map[string]any would turn all numbers into float64 and silently
	// corrupt large integers when the payload is marshaled again.
	var fields map[string]json.RawMessage
	if err := json.Unmarshal(body, &fields); err != nil {
		return nil, err
	}
	if fields == nil {
		// JSON null decodes to a nil map; writing to it below would panic.
		return body, nil
	}

	if raw, ok := fields["sessionId"]; ok {
		var existing string
		// A non-string sessionId fails to decode and is treated as missing.
		if err := json.Unmarshal(raw, &existing); err == nil && strings.TrimSpace(existing) != "" {
			return body, nil
		}
	}

	sessionID := strings.TrimSpace(preferredSessionID)
	if sessionID == "" {
		var req GeminiRequest
		if err := json.Unmarshal(body, &req); err != nil {
			return nil, err
		}
		sessionID = generateStableSessionID(req.Contents)
	}
	if sessionID == "" {
		return body, nil
	}

	encoded, err := json.Marshal(sessionID)
	if err != nil {
		return nil, err
	}
	fields["sessionId"] = encoded
	return json.Marshal(fields)
}
type TransformOptions struct {
EnableIdentityPatch bool
// IdentityPatch 可选:自定义注入到 systemInstruction 开头的身份防护提示词;

View File

@ -8,6 +8,43 @@ import (
"github.com/stretchr/testify/require"
)
// TestEnsureGeminiRequestSessionID covers the three ways the session ID is
// resolved: from the preferred value, from the body itself, and from a stable
// fallback derived from the contents.
func TestEnsureGeminiRequestSessionID(t *testing.T) {
	const (
		bodyWithoutSession = `{"contents":[{"role":"user","parts":[{"text":"hello"}]}]}`
		bodyWithSession    = `{"sessionId":"session-in-body","contents":[{"role":"user","parts":[{"text":"hello"}]}]}`
	)

	// resolve runs the function under test and returns the decoded sessionId.
	resolve := func(t *testing.T, body, preferred string) any {
		t.Helper()
		updated, err := EnsureGeminiRequestSessionID([]byte(body), preferred)
		require.NoError(t, err)

		var payload map[string]any
		require.NoError(t, json.Unmarshal(updated, &payload))
		return payload["sessionId"]
	}

	t.Run("prefers provided session id", func(t *testing.T) {
		got := resolve(t, bodyWithoutSession, "session-from-header")
		require.Equal(t, "session-from-header", got)
	})

	t.Run("keeps existing session id", func(t *testing.T) {
		got := resolve(t, bodyWithSession, "session-from-header")
		require.Equal(t, "session-in-body", got)
	})

	t.Run("derives stable fallback from contents", func(t *testing.T) {
		first := resolve(t, bodyWithoutSession, "")
		second := resolve(t, bodyWithoutSession, "")
		require.NotEmpty(t, first)
		require.Equal(t, first, second)
	})
}
// TestBuildParts_ThinkingBlockWithoutSignature 测试thinking block无signature时的处理
func TestBuildParts_ThinkingBlockWithoutSignature(t *testing.T) {
tests := []struct {

View File

@ -0,0 +1,13 @@
package lspool
import "time"
// Backend is the control-plane abstraction used by the HTTP upstream wrapper.
// It may be backed by a local in-process Pool or by remote LS workers.
type Backend interface {
// GetOrCreate returns the LS instance for accountID and routingKey.
// proxyURL is optional.
// NOTE(review): presumably an instance is started when none exists yet and
// only the first proxyURL is honored — confirm against the implementations.
GetOrCreate(accountID, routingKey string, proxyURL ...string) (*Instance, error)
// SetAccountToken hands the account's access token, refresh token and
// token expiry to the backend.
SetAccountToken(accountID, accessToken, refreshToken string, expiresAt time.Time)
// SetAccountModelCredits hands the account's AI-credit settings to the
// backend. Both pointer arguments may be nil.
// NOTE(review): nil presumably means "value unknown" — confirm.
SetAccountModelCredits(accountID string, useAICredits bool, availableCredits, minimumCreditAmountForUsage *int32)
// Stats returns diagnostic values; StatusInfo in this package reads the
// "total" and "active" keys.
Stats() map[string]any
// Close shuts the backend down.
// NOTE(review): whether Close may be called more than once is not visible here.
Close()
}

View File

@ -0,0 +1,94 @@
// Package lspool provides LS-mode integration for the antigravity gateway.
//
// When LS mode is enabled (via ANTIGRAVITY_LS_MODE=true), requests to
// streamGenerateContent are routed through a real Language Server instance
// instead of directly to cloudcode-pa. This provides:
//
// - Authentic TLS fingerprint (Google's own Go binary)
// - Real session management and Heartbeat
// - Indistinguishable from a real IDE instance
//
// To enable: set environment variable ANTIGRAVITY_LS_MODE=true
// To configure: set ANTIGRAVITY_APP_ROOT to the AntiGravity.app path
package lspool
import (
"log/slog"
"os"
"strings"
"sync"
"github.com/Wei-Shaw/sub2api/internal/config"
)
var (
// globalBackend is the process-wide backend. In this file it is assigned
// only inside GlobalPool, under globalPoolOnce; it stays nil when LS mode is
// disabled or initialization failed.
// NOTE(review): Shutdown and StatusInfo read it without synchronization —
// confirm they are never called concurrently with the first GlobalPool call.
globalBackend Backend
// globalPoolOnce guards the one-time construction of globalBackend.
globalPoolOnce sync.Once
// lsModeEnabled records whether ANTIGRAVITY_LS_MODE was exactly "true" when
// the package was initialized.
lsModeEnabled bool
)
// init snapshots ANTIGRAVITY_LS_MODE once at package load. Only the exact
// value "true" enables LS mode; later changes to the variable are not seen.
func init() {
	mode := os.Getenv("ANTIGRAVITY_LS_MODE")
	lsModeEnabled = mode == "true"
}
// IsLSModeEnabled reports whether LS mode is active. The value is fixed at
// package initialization from the ANTIGRAVITY_LS_MODE environment variable.
func IsLSModeEnabled() bool {
return lsModeEnabled
}
// Values accepted by the ANTIGRAVITY_LS_STRATEGY environment variable.
const (
	LSStrategyDirect   = "direct"
	LSStrategyJSParity = "js-parity"
)

// CurrentLSStrategy reports the active LS routing strategy, read from
// ANTIGRAVITY_LS_STRATEGY on every call. The value is matched after trimming
// surrounding whitespace and lower-casing. Anything other than "js-parity"
// (including an unset, empty or unknown value) yields "direct".
func CurrentLSStrategy() string {
	configured := strings.TrimSpace(os.Getenv("ANTIGRAVITY_LS_STRATEGY"))
	if strings.ToLower(configured) == LSStrategyJSParity {
		return LSStrategyJSParity
	}
	return LSStrategyDirect
}
// GlobalPool returns the process-wide LS backend, building it from cfg on the
// first call. It returns nil when LS mode is disabled, and also when the
// one-time initialization failed: the failure is logged and never retried,
// because the sync.Once has already fired.
func GlobalPool(cfg *config.Config) Backend {
	if lsModeEnabled {
		globalPoolOnce.Do(func() {
			manager, err := NewWorkerManagerFromConfig(cfg)
			if err == nil {
				globalBackend = manager
				return
			}
			slog.Default().Error("failed to initialize LS worker manager", "err", err)
		})
		return globalBackend
	}
	return nil
}
// Shutdown closes the global backend if one was created; otherwise it does
// nothing.
func Shutdown() {
	if globalBackend == nil {
		return
	}
	globalBackend.Close()
}
// StatusInfo returns a diagnostic snapshot of the LS integration. The
// "ls_mode_enabled", "build" and "user_agent" keys are always present;
// "pool_total" and "pool_active" are added from the backend's Stats only when
// LS mode is enabled and a backend exists.
func StatusInfo() map[string]any {
	info := make(map[string]any, 5)
	info["ls_mode_enabled"] = lsModeEnabled
	info["build"] = "enhanced"
	info["user_agent"] = "antigravity/1.107.0"

	if !lsModeEnabled || globalBackend == nil {
		return info
	}

	stats := globalBackend.Stats()
	info["pool_total"] = stats["total"]
	info["pool_active"] = stats["active"]
	return info
}

View File

@ -0,0 +1,794 @@
package lspool
import (
"bytes"
"context"
"encoding/binary"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// readConnectFrame reads one length-prefixed frame from r and returns its
// payload. The 5-byte prefix is one flags byte (ignored here) followed by the
// payload length as a big-endian uint32. Any read error, including a
// truncated prefix or payload, is returned as-is from io.ReadFull.
func readConnectFrame(r io.Reader) ([]byte, error) {
	var prefix [5]byte
	_, err := io.ReadFull(r, prefix[:])
	if err != nil {
		return nil, err
	}

	size := binary.BigEndian.Uint32(prefix[1:])
	frame := make([]byte, size)
	if _, err = io.ReadFull(r, frame); err != nil {
		return nil, err
	}
	return frame, nil
}
// decodeProtoBytesField scans a protobuf-encoded message and returns the
// payload of the first length-delimited (wire type 2) field whose number is
// targetField. The returned slice aliases data.
//
// Varint (0), 64-bit (1) and 32-bit (5) fields are skipped. nil is returned
// when the field is absent, when the input is truncated or malformed, or when
// an unsupported wire type is encountered.
func decodeProtoBytesField(data []byte, targetField int) []byte {
	for i := 0; i < len(data); {
		tag, n := binary.Uvarint(data[i:])
		if n <= 0 {
			return nil
		}
		i += n
		fieldNum := int(tag >> 3)

		switch tag & 0x7 {
		case 0: // varint: skip the value
			_, n = binary.Uvarint(data[i:])
			if n <= 0 {
				return nil
			}
			i += n
		case 1: // fixed 64-bit
			i += 8
		case 2: // length-delimited
			length, n := binary.Uvarint(data[i:])
			if n <= 0 {
				return nil
			}
			i += n
			// Compare in uint64 space: converting an oversized length to int
			// first could go negative, pass the bounds check and then panic
			// in the slice expression below.
			if length > uint64(len(data)-i) {
				return nil
			}
			end := i + int(length)
			if fieldNum == targetField {
				return data[i:end]
			}
			i = end
		case 5: // fixed 32-bit
			i += 4
		default:
			return nil
		}
	}
	return nil
}
// TestMockExtensionServerTokenInjection verifies the token injection flow:
// a token is stored on the mock extension server, then a subscriber to the
// "uss-oauth" topic connects and must receive an initial data frame.
//
// Every step is asserted: a failed request or a short read fails the test
// instead of letting it pass without having checked anything.
func TestMockExtensionServerTokenInjection(t *testing.T) {
	csrf := "test-csrf-token"
	srv, err := NewMockExtensionServer(csrf)
	require.NoError(t, err)
	defer srv.Close()

	// 1. Set token for an account
	srv.SetToken("account-1", &TokenInfo{
		AccessToken:  "ya29.test-access-token",
		RefreshToken: "1//test-refresh-token",
		ExpiresAt:    time.Now().Add(1 * time.Hour),
	})

	// 2. Verify token is stored
	srv.mu.RLock()
	info, ok := srv.tokens["account-1"]
	srv.mu.RUnlock()
	require.True(t, ok)
	require.Equal(t, "ya29.test-access-token", info.AccessToken)
	require.Equal(t, "1//test-refresh-token", info.RefreshToken)
	require.False(t, info.ExpiresAt.IsZero())

	// 3. Simulate LS subscribing to uss-oauth (HTTP request to mock server)
	req, err := http.NewRequest("POST",
		fmt.Sprintf("http://127.0.0.1:%d/exa.extension_server_pb.ExtensionServerService/SubscribeToUnifiedStateSyncTopic", srv.Port()),
		bytes.NewReader(frameConnectMessage(encodeProtoString(1, "uss-oauth"))))
	require.NoError(t, err)
	req.Header.Set("x-codeium-csrf-token", csrf)
	req.Header.Set("Content-Type", "application/connect+proto")

	// The stream handler keeps the response open, so bound the whole exchange
	// with a timeout; cancel tears the stream down when the test returns.
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	req = req.WithContext(ctx)

	client := &http.Client{}
	resp, err := client.Do(req)
	require.NoError(t, err)
	defer resp.Body.Close()
	require.Equal(t, 200, resp.StatusCode)
	require.Equal(t, "application/connect+proto", resp.Header.Get("Content-Type"))

	// Read the envelope header of the first frame (initial state). ReadFull is
	// required because a single Read may legally return fewer than 5 bytes.
	header := make([]byte, 5)
	_, err = io.ReadFull(resp.Body, header)
	require.NoError(t, err)
	require.Equal(t, byte(0x00), header[0], "first byte should be 0x00 (data frame)")
	t.Logf("Received initial state frame: flags=%d, payload_len=%d", header[0], binary.BigEndian.Uint32(header[1:5]))
}
// TestMockExtensionServerCSRF checks that the mock extension server rejects a
// Heartbeat call carrying the wrong CSRF token with 403 and accepts one
// carrying the configured token with 200.
func TestMockExtensionServerCSRF(t *testing.T) {
	const validToken = "correct-csrf"
	srv, err := NewMockExtensionServer(validToken)
	require.NoError(t, err)
	defer srv.Close()

	endpoint := fmt.Sprintf("http://127.0.0.1:%d/exa.extension_server_pb.ExtensionServerService/Heartbeat", srv.Port())

	// heartbeat builds a body-less Heartbeat POST carrying the given token.
	heartbeat := func(token string) *http.Request {
		req, _ := http.NewRequest("POST", endpoint, nil)
		req.Header.Set("x-codeium-csrf-token", token)
		req.Header.Set("Content-Type", "application/proto")
		return req
	}

	// Wrong CSRF → 403
	rejected, err := http.DefaultClient.Do(heartbeat("wrong-csrf"))
	require.NoError(t, err)
	defer rejected.Body.Close()
	require.Equal(t, 403, rejected.StatusCode)

	// Correct CSRF → 200
	accepted, err := http.DefaultClient.Do(heartbeat(validToken))
	require.NoError(t, err)
	defer accepted.Body.Close()
	require.Equal(t, 200, accepted.StatusCode)
}
// TestMockExtensionServerGetSecretValue exercises the GetSecretValue endpoint
// (the fallback token path) after a token has been stored. Only the HTTP
// status is asserted; the response body is not inspected.
func TestMockExtensionServerGetSecretValue(t *testing.T) {
	const csrf = "test-csrf"
	srv, err := NewMockExtensionServer(csrf)
	require.NoError(t, err)
	defer srv.Close()

	srv.SetToken("acc", &TokenInfo{AccessToken: "ya29.secret-token"})

	endpoint := fmt.Sprintf("http://127.0.0.1:%d/exa.extension_server_pb.ExtensionServerService/GetSecretValue", srv.Port())
	req, _ := http.NewRequest("POST", endpoint, nil)
	req.Header.Set("x-codeium-csrf-token", csrf)
	req.Header.Set("Content-Type", "application/proto")

	resp, err := http.DefaultClient.Do(req)
	require.NoError(t, err)
	defer resp.Body.Close()
	require.Equal(t, 200, resp.StatusCode)
}
// TestOAuthTokenInfoProto checks that the encoded OAuthTokenInfo carries the
// access token, the "Bearer" token type and the refresh token as plain
// substrings, and that an empty refresh token is left out of the encoding.
func TestOAuthTokenInfoProto(t *testing.T) {
	expiry := time.Date(2026, 3, 29, 19, 0, 0, 0, time.UTC)

	withRefresh := string(buildOAuthTokenInfoBinary("ya29.test", "1//refresh", expiry))
	require.True(t, len(withRefresh) > 0, "proto should not be empty")

	// access_token, token_type and refresh_token values, in that order.
	for _, want := range []string{"ya29.test", "Bearer", "1//refresh"} {
		require.Contains(t, withRefresh, want)
	}

	// Without refresh_token
	withoutRefresh := string(buildOAuthTokenInfoBinary("ya29.test", "", expiry))
	require.NotContains(t, withoutRefresh, "1//refresh")
}
// TestOAuthTokenInfoWithRealExpiry checks that encoding succeeds (produces a
// non-empty result) both for an explicit future expiry and for the zero time.
// The two encodings are not compared with each other.
func TestOAuthTokenInfoWithRealExpiry(t *testing.T) {
	explicitExpiry := time.Now().Add(2 * time.Hour)

	withExpiry := buildOAuthTokenInfoBinary("token", "refresh", explicitExpiry)
	withZeroExpiry := buildOAuthTokenInfoBinary("token", "refresh", time.Time{})

	require.True(t, len(withExpiry) > 0)
	require.True(t, len(withZeroExpiry) > 0)
}
// TestUSSTopicWithOAuth checks that the OAuth USS topic is non-empty and
// contains the OAuth sentinel key.
func TestUSSTopicWithOAuth(t *testing.T) {
	encoded := buildUSSTopicWithOAuth("ya29.access", "1//refresh", time.Now().Add(1*time.Hour))

	require.True(t, len(encoded) > 0)
	require.Contains(t, string(encoded), "oauthTokenInfoSentinelKey")
}
func TestUSSTopicWithModelCredits(t *testing.T) {
available := int32(123)
minimum := int32(50)
topic := buildUSSTopicWithModelCredits(&ModelCreditsInfo{
UseAICredits: true,
AvailableCredits: &available,
MinimumCreditAmountForUsage: &minimum,
})
require.True(t, len(topic) > 0)
require.Contains(t, string(topic), useAICreditsSentinelKey)
require.Contains(t, string(topic), availableCreditsSentinelKey)
require.Contains(t, string(topic), minimumCreditAmountForUsageKey)
}
// TestMockExtensionServerModelCreditsDynamicUpdate subscribes to the
// "uss-modelCredits" topic, then changes the account's credits and expects
// three follow-up frames, one per model-credit key.
func TestMockExtensionServerModelCreditsDynamicUpdate(t *testing.T) {
	const csrf = "test-csrf-token"
	srv, err := NewMockExtensionServer(csrf)
	require.NoError(t, err)
	defer srv.Close()

	// Register empty credits before subscribing.
	srv.SetModelCredits("account-1", &ModelCreditsInfo{})

	subscribeURL := fmt.Sprintf("http://127.0.0.1:%d/exa.extension_server_pb.ExtensionServerService/SubscribeToUnifiedStateSyncTopic", srv.Port())
	subscribeBody := frameConnectMessage(encodeProtoString(1, "uss-modelCredits"))
	req, _ := http.NewRequest("POST", subscribeURL, bytes.NewReader(subscribeBody))
	req.Header.Set("x-codeium-csrf-token", csrf)
	req.Header.Set("Content-Type", "application/connect+proto")

	// Bound the streaming exchange so a missing frame cannot hang the test.
	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	resp, err := http.DefaultClient.Do(req.WithContext(ctx))
	require.NoError(t, err)
	defer resp.Body.Close()
	require.Equal(t, 200, resp.StatusCode)

	// Drain the initial_state frame first.
	_, err = readConnectFrame(resp.Body)
	require.NoError(t, err)

	available, minimum := int32(77), int32(25)
	srv.SetModelCredits("account-1", &ModelCreditsInfo{
		UseAICredits:                true,
		AvailableCredits:            &available,
		MinimumCreditAmountForUsage: &minimum,
	})

	// Each update frame carries its payload in field 2; the key is field 1 of
	// that payload.
	const wantUpdates = 3
	keys := make([]string, 0, wantUpdates)
	for i := 0; i < wantUpdates; i++ {
		frame, readErr := readConnectFrame(resp.Body)
		require.NoError(t, readErr)

		applied := decodeProtoBytesField(frame, 2)
		require.NotEmpty(t, applied)
		keys = append(keys, decodeProtoString(applied, 1))
	}

	require.Contains(t, keys, useAICreditsSentinelKey)
	require.Contains(t, keys, availableCreditsSentinelKey)
	require.Contains(t, keys, minimumCreditAmountForUsageKey)
}
// TestBuildInitialStateUpdate verifies the USS update wrapper
func TestBuildInitialStateUpdate(t *testing.T) {
topicData := buildEmptyTopic()
update := buildInitialStateUpdate(topicData)
// Should be a valid proto bytes field (field 1 = initial_state)
require.True(t, len(update) >= 0) // empty topic is valid
topicData2 := buildUSSTopicWithOAuth("token", "refresh", time.Now().Add(1*time.Hour))
update2 := buildInitialStateUpdate(topicData2)
require.True(t, len(update2) > len(update), "non-empty topic should produce larger update")
}
// TestPoolSetAccountTokenComplete checks that Pool.SetAccountToken records the
// access token, refresh token and expiry on the pool's extension server.
func TestPoolSetAccountTokenComplete(t *testing.T) {
	extServer, err := NewMockExtensionServer("pool-csrf")
	require.NoError(t, err)
	defer extServer.Close()

	poolCtx, poolCancel := context.WithCancel(context.Background())
	defer poolCancel()

	pool := &Pool{
		config:    DefaultConfig(),
		instances: map[string][]*Instance{},
		extServer: extServer,
		ctx:       poolCtx,
		cancel:    poolCancel,
	}

	wantExpiry := time.Now().Add(1 * time.Hour)
	pool.SetAccountToken("acc-1", "ya29.full-token", "1//full-refresh", wantExpiry)

	extServer.mu.RLock()
	stored := extServer.tokens["acc-1"]
	extServer.mu.RUnlock()

	require.NotNil(t, stored)
	require.Equal(t, "ya29.full-token", stored.AccessToken)
	require.Equal(t, "1//full-refresh", stored.RefreshToken)
	require.False(t, stored.ExpiresAt.IsZero())
	require.WithinDuration(t, wantExpiry, stored.ExpiresAt, time.Second)
}
// TestPoolSetAccountModelCreditsComplete checks that
// Pool.SetAccountModelCredits records the credits flag and both amounts on the
// pool's extension server.
func TestPoolSetAccountModelCreditsComplete(t *testing.T) {
	extServer, err := NewMockExtensionServer("pool-csrf")
	require.NoError(t, err)
	defer extServer.Close()

	poolCtx, poolCancel := context.WithCancel(context.Background())
	defer poolCancel()

	pool := &Pool{
		config:    DefaultConfig(),
		instances: map[string][]*Instance{},
		extServer: extServer,
		ctx:       poolCtx,
		cancel:    poolCancel,
	}

	wantAvailable, wantMinimum := int32(77), int32(25)
	pool.SetAccountModelCredits("acc-1", true, &wantAvailable, &wantMinimum)

	extServer.mu.RLock()
	stored := extServer.credits["acc-1"]
	extServer.mu.RUnlock()

	require.NotNil(t, stored)
	require.True(t, stored.UseAICredits)
	require.NotNil(t, stored.AvailableCredits)
	require.Equal(t, wantAvailable, *stored.AvailableCredits)
	require.NotNil(t, stored.MinimumCreditAmountForUsage)
	require.Equal(t, wantMinimum, *stored.MinimumCreditAmountForUsage)
}
// TestUpstreamAdapterExtractsCredentials sends a non-streamGenerateContent
// request through the LS pool upstream and checks two things: the internal
// X-Antigravity-* and credits headers never reach the fallback upstream, and
// the credentials they carried end up stored on the extension server under
// account "1".
func TestUpstreamAdapterExtractsCredentials(t *testing.T) {
	// Record the headers the fallback upstream actually receives.
	var (
		mu              sync.Mutex
		receivedHeaders http.Header
	)
	fallback := &recordingUpstreamWithCallback{}
	fallback.onDo = func(req *http.Request) {
		mu.Lock()
		defer mu.Unlock()
		receivedHeaders = req.Header.Clone()
	}

	srv, err := NewMockExtensionServer("test-csrf")
	require.NoError(t, err)
	defer srv.Close()

	pool := &Pool{
		config:    DefaultConfig(),
		instances: map[string][]*Instance{},
		extServer: srv,
	}
	upstream := NewLSPoolUpstream(pool, fallback)

	// Non-streamGenerateContent request → should pass through to fallback
	req, _ := http.NewRequest("POST", "https://example.com/v1beta/models/gemini:generateContent", nil)
	requestHeaders := [][2]string{
		{"Authorization", "Bearer ya29.test"},
		{"X-Antigravity-Refresh-Token", "1//secret-refresh"},
		{"X-Antigravity-Token-Expiry", "2026-03-29T19:00:00Z"},
		{useAICreditsHeader, "true"},
		{availableCreditsHeader, "42"},
		{minimumCreditAmountHeader, "50"},
	}
	for _, header := range requestHeaders {
		req.Header.Set(header[0], header[1])
	}

	resp, err := upstream.Do(req, "", 1, 1)
	require.NoError(t, err)
	require.NotNil(t, resp)

	// Internal headers should never leak to the direct upstream.
	internalHeaders := []string{
		"X-Antigravity-Refresh-Token",
		"X-Antigravity-Token-Expiry",
		useAICreditsHeader,
		availableCreditsHeader,
		minimumCreditAmountHeader,
	}
	mu.Lock()
	for _, name := range internalHeaders {
		require.Empty(t, receivedHeaders.Get(name))
	}
	mu.Unlock()

	srv.mu.RLock()
	storedToken := srv.tokens["1"]
	storedCredits := srv.credits["1"]
	srv.mu.RUnlock()

	require.NotNil(t, storedToken)
	require.Equal(t, "ya29.test", storedToken.AccessToken)
	require.NotNil(t, storedCredits)
	require.True(t, storedCredits.UseAICredits)
	require.NotNil(t, storedCredits.AvailableCredits)
	require.Equal(t, int32(42), *storedCredits.AvailableCredits)
	require.NotNil(t, storedCredits.MinimumCreditAmountForUsage)
	require.Equal(t, int32(50), *storedCredits.MinimumCreditAmountForUsage)
}
// TestExtractPromptAndModelMultiTurn checks that extractPromptAndModel returns
// the model name and a prompt containing the system instruction and the text
// of every turn of a multi-turn conversation.
func TestExtractPromptAndModelMultiTurn(t *testing.T) {
	const requestJSON = `{
"model": "claude-sonnet-4-6",
"request": {
"systemInstruction": {"parts": [{"text": "You are helpful"}]},
"contents": [
{"role": "user", "parts": [{"text": "Hello"}]},
{"role": "model", "parts": [{"text": "Hi there!"}]},
{"role": "user", "parts": [{"text": "How are you?"}]}
]
}
}`

	prompt, model := extractPromptAndModel([]byte(requestJSON))
	require.Equal(t, "claude-sonnet-4-6", model)

	// System instruction first, then each turn in conversation order.
	for _, fragment := range []string{"You are helpful", "Hello", "Hi there!", "How are you?"} {
		require.Contains(t, prompt, fragment)
	}
}
// TestExtractUsageFromTrajectory checks that the string token counts in a
// trajectory step's modelUsage are returned as integer prompt, candidates and
// total token counts.
func TestExtractUsageFromTrajectory(t *testing.T) {
	const trajectoryJSON = `{
"trajectory": {
"steps": [{
"type": "CORTEX_STEP_TYPE_PLANNER_RESPONSE",
"status": "CORTEX_STEP_STATUS_DONE",
"plannerResponse": {"response": "OK"},
"metadata": {
"modelUsage": {
"inputTokens": "150",
"outputTokens": "5"
}
}
}]
}
}`

	usage := extractUsageFromTrajectory([]byte(trajectoryJSON))
	require.NotNil(t, usage)

	wantCounts := []struct {
		key   string
		count int
	}{
		{"promptTokenCount", 150},
		{"candidatesTokenCount", 5},
		{"totalTokenCount", 155},
	}
	for _, want := range wantCounts {
		require.Equal(t, want.count, usage[want.key])
	}
}
// TestSSEChunkFormat checks the shape of a Gemini SSE text chunk: a "data: "
// line terminated by a blank line, whose JSON payload holds exactly one
// candidate under "response".
func TestSSEChunkFormat(t *testing.T) {
	const (
		ssePrefix     = "data: "
		sseTerminator = "\n\n"
	)

	chunk := buildGeminiSSEChunk("Hello world")
	require.True(t, len(chunk) > 0)
	require.Contains(t, chunk, ssePrefix)
	require.Contains(t, chunk, `"text":"Hello world"`)
	require.Contains(t, chunk, `"role":"model"`)
	require.True(t, strings.HasSuffix(chunk, sseTerminator))

	// Verify it's valid JSON after stripping "data: " prefix
	payload := chunk[len(ssePrefix) : len(chunk)-len(sseTerminator)]
	var decoded map[string]any
	require.NoError(t, json.Unmarshal([]byte(payload), &decoded))

	response := decoded["response"].(map[string]any)
	candidates := response["candidates"].([]any)
	require.Len(t, candidates, 1)
}
// TestSSEFinalChunkFormat checks that the final SSE chunk is a "data: " line
// that reports finishReason STOP and includes usageMetadata.
func TestSSEFinalChunkFormat(t *testing.T) {
	finalChunk := buildGeminiSSEFinalChunk(map[string]any{
		"promptTokenCount":     100,
		"candidatesTokenCount": 50,
		"totalTokenCount":      150,
	})

	for _, fragment := range []string{"data: ", `"finishReason":"STOP"`, `"usageMetadata"`} {
		require.Contains(t, finalChunk, fragment)
	}
}
func TestStreamCascadeResponsePollsImmediately(t *testing.T) {
var getCalls atomic.Int32
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "test-csrf", r.Header.Get("x-codeium-csrf-token"))
if strings.HasSuffix(r.URL.Path, "/GetCascadeTrajectory") {
getCalls.Add(1)
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"status":"CASCADE_RUN_STATUS_IDLE","trajectory":{"steps":[{"type":"CORTEX_STEP_TYPE_PLANNER_RESPONSE","status":"CORTEX_STEP_STATUS_DONE","plannerResponse":{"response":"hello from ls"}}]}}`))
return
}
http.NotFound(w, r)
}))
defer server.Close()
inst := &Instance{
AccountID: "42",
CSRF: "test-csrf",
Address: strings.TrimPrefix(server.URL, "https://"),
client: server.Client(),
healthy: true,
lastUsed: time.Now(),
}
upstream := NewLSPoolUpstream(&Pool{}, &recordingUpstream{})
pr, pw := io.Pipe()
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
done := make(chan struct{})
go func() {
upstream.streamCascadeResponse(ctx, inst, "cid-1", pw, nil, nil)
close(done)
}()
body, err := io.ReadAll(pr)
require.NoError(t, err)
<-done
require.GreaterOrEqual(t, getCalls.Load(), int32(1))
require.Contains(t, string(body), "hello from ls")
}
// TestRequestHasToolsEdgeCases verifies tool detection edge cases
func TestRequestHasToolsEdgeCases(t *testing.T) {
// null tools
require.False(t, requestHasTools([]byte(`{"contents":[],"tools":null}`)))
// tools with empty function declarations
require.True(t, requestHasTools([]byte(`{"contents":[],"tools":[{"functionDeclarations":[]}]}`)))
// deeply nested wrapped format
require.True(t, requestHasTools([]byte(`{"model":"m","project":"p","request":{"contents":[],"tools":[{"codeExecution":{}}]}}`)))
}
// TestJSParityRouteReusesCascadeSession sends two requests that share the same
// sessionId through the JS-parity strategy, where the second transcript is the
// first one plus one model turn and one new user turn. It expects a single
// StartCascade call, two SendUserCascadeMessage calls, and a minimal first
// send payload.
func TestJSParityRouteReusesCascadeSession(t *testing.T) {
t.Setenv("ANTIGRAVITY_LS_STRATEGY", LSStrategyJSParity)
var startCalls atomic.Int32
var sendCalls atomic.Int32
var getCalls atomic.Int32
var sendBodiesMu sync.Mutex
var sendBodies []map[string]any
// Fake language server: counts each RPC and records every decoded
// SendUserCascadeMessage payload for inspection below.
// NOTE(review): require.* inside the handler runs on the HTTP server's
// goroutine; the testing package documents FailNow as test-goroutine only —
// confirm whether assert/t.Errorf would be more appropriate here.
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "test-csrf", r.Header.Get("x-codeium-csrf-token"))
switch {
case strings.HasSuffix(r.URL.Path, "/StartCascade"):
startCalls.Add(1)
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"cascadeId":"cid-1"}`))
case strings.HasSuffix(r.URL.Path, "/SendUserCascadeMessage"):
sendCalls.Add(1)
var payload map[string]any
err := json.NewDecoder(r.Body).Decode(&payload)
require.NoError(t, err)
sendBodiesMu.Lock()
sendBodies = append(sendBodies, payload)
sendBodiesMu.Unlock()
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"queued":false}`))
case strings.HasSuffix(r.URL.Path, "/GetCascadeTrajectory"):
call := getCalls.Add(1)
w.Header().Set("Content-Type", "application/json")
// The first trajectory poll answers the first turn; every later poll
// returns the follow-up text.
text := "hello from ls"
if call > 1 {
text = "follow up from ls"
}
_, _ = w.Write([]byte(fmt.Sprintf(`{"status":"CASCADE_RUN_STATUS_IDLE","trajectory":{"steps":[{"type":"CORTEX_STEP_TYPE_PLANNER_RESPONSE","status":"CORTEX_STEP_STATUS_DONE","plannerResponse":{"response":"%s"}}]}}`, text)))
default:
http.NotFound(w, r)
}
}))
defer server.Close()
// Instance pointing at the fake server, marked healthy and with model
// mapping ready so the request is not rejected up front.
inst := &Instance{
AccountID: "42",
CSRF: "test-csrf",
Address: strings.TrimPrefix(server.URL, "https://"),
client: server.Client(),
healthy: true,
lastUsed: time.Now(),
}
inst.SetModelMappingReady(true)
pool := &Pool{
config: Config{ReplicasPerAccount: 1},
instances: map[string][]*Instance{"42": []*Instance{inst}},
}
upstream := NewLSPoolUpstream(pool, &recordingUpstream{})
// First turn: a single user message.
req1Body := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"session-a","contents":[{"role":"user","parts":[{"text":"hello"}]}]}}`)
req1, err := http.NewRequest(http.MethodPost, "https://example.com/v1internal:streamGenerateContent?alt=sse", bytes.NewReader(req1Body))
require.NoError(t, err)
req1.Header.Set("Authorization", "Bearer downstream-a")
resp1, err := upstream.Do(req1, "", 42, 1)
require.NoError(t, err)
body1, err := io.ReadAll(resp1.Body)
require.NoError(t, err)
require.Contains(t, string(body1), `"text":"hello from ls"`)
// Second turn: same session, transcript extended with the model reply and a
// new user message.
req2Body := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"session-a","contents":[{"role":"user","parts":[{"text":"hello"}]},{"role":"model","parts":[{"text":"hello from ls"}]},{"role":"user","parts":[{"text":"follow up"}]}]}}`)
req2, err := http.NewRequest(http.MethodPost, "https://example.com/v1internal:streamGenerateContent?alt=sse", bytes.NewReader(req2Body))
require.NoError(t, err)
req2.Header.Set("Authorization", "Bearer downstream-a")
resp2, err := upstream.Do(req2, "", 42, 1)
require.NoError(t, err)
body2, err := io.ReadAll(resp2.Body)
require.NoError(t, err)
require.Contains(t, string(body2), `"text":"follow up from ls"`)
require.Equal(t, int32(1), startCalls.Load(), "cascade should be reused for append-only transcript")
require.Equal(t, int32(2), sendCalls.Load())
// Inspect the first recorded send payload: it must carry the cascade id,
// be non-blocking, identify the IDE in metadata, omit clientType and
// messageOrigin, and contain nothing in cascadeConfig beyond
// plannerConfig.requestedModel.
sendBodiesMu.Lock()
require.Len(t, sendBodies, 2)
firstSend := sendBodies[0]
sendBodiesMu.Unlock()
require.Equal(t, "cid-1", firstSend["cascadeId"])
require.Equal(t, false, firstSend["blocking"])
metadata, ok := firstSend["metadata"].(map[string]any)
require.True(t, ok)
require.Equal(t, "antigravity", metadata["ideName"])
require.Equal(t, "1.107.0", metadata["ideVersion"])
require.NotContains(t, firstSend, "clientType")
require.NotContains(t, firstSend, "messageOrigin")
cascadeConfig, ok := firstSend["cascadeConfig"].(map[string]any)
require.True(t, ok)
plannerConfig, ok := cascadeConfig["plannerConfig"].(map[string]any)
require.True(t, ok)
requestedModel, ok := plannerConfig["requestedModel"].(map[string]any)
require.True(t, ok)
require.NotEmpty(t, requestedModel["model"])
require.Len(t, plannerConfig, 1)
require.Len(t, cascadeConfig, 1)
}
// TestJSParityRouteFallsBackOnSystemInstructionDrift sends two requests for
// the same sessionId whose systemInstruction text differs. It expects the
// second request to be answered by the fallback upstream (one Do call, body
// "ok") while the language server sees only the first turn's StartCascade and
// SendUserCascadeMessage.
func TestJSParityRouteFallsBackOnSystemInstructionDrift(t *testing.T) {
t.Setenv("ANTIGRAVITY_LS_STRATEGY", LSStrategyJSParity)
var startCalls atomic.Int32
var sendCalls atomic.Int32
// Fake language server that counts StartCascade / SendUserCascadeMessage
// calls and always returns the same finished trajectory.
// NOTE(review): require.* inside the handler runs on the HTTP server's
// goroutine; the testing package documents FailNow as test-goroutine only —
// confirm whether assert/t.Errorf would be more appropriate here.
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "test-csrf", r.Header.Get("x-codeium-csrf-token"))
switch {
case strings.HasSuffix(r.URL.Path, "/StartCascade"):
startCalls.Add(1)
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"cascadeId":"cid-1"}`))
case strings.HasSuffix(r.URL.Path, "/SendUserCascadeMessage"):
sendCalls.Add(1)
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"queued":false}`))
case strings.HasSuffix(r.URL.Path, "/GetCascadeTrajectory"):
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"status":"CASCADE_RUN_STATUS_IDLE","trajectory":{"steps":[{"type":"CORTEX_STEP_TYPE_PLANNER_RESPONSE","status":"CORTEX_STEP_STATUS_DONE","plannerResponse":{"response":"hello from ls"}}]}}`))
default:
http.NotFound(w, r)
}
}))
defer server.Close()
inst := &Instance{
AccountID: "42",
CSRF: "test-csrf",
Address: strings.TrimPrefix(server.URL, "https://"),
client: server.Client(),
healthy: true,
lastUsed: time.Now(),
}
inst.SetModelMappingReady(true)
// The fallback upstream is kept in a variable so its call count can be
// asserted at the end.
fallback := &recordingUpstream{}
pool := &Pool{
config: Config{ReplicasPerAccount: 1},
instances: map[string][]*Instance{"42": []*Instance{inst}},
}
upstream := NewLSPoolUpstream(pool, fallback)
// First turn: system instruction "You are helpful".
req1Body := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"session-a","systemInstruction":{"parts":[{"text":"You are helpful"}]},"contents":[{"role":"user","parts":[{"text":"hello"}]}]}}`)
req1, err := http.NewRequest(http.MethodPost, "https://example.com/v1internal:streamGenerateContent?alt=sse", bytes.NewReader(req1Body))
require.NoError(t, err)
req1.Header.Set("Authorization", "Bearer downstream-a")
resp1, err := upstream.Do(req1, "", 42, 1)
require.NoError(t, err)
body1, err := io.ReadAll(resp1.Body)
require.NoError(t, err)
require.Contains(t, string(body1), `"text":"hello from ls"`)
// Second turn: same session and an append-only transcript, but the system
// instruction text has changed to "You are different".
req2Body := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"session-a","systemInstruction":{"parts":[{"text":"You are different"}]},"contents":[{"role":"user","parts":[{"text":"hello"}]},{"role":"model","parts":[{"text":"hello from ls"}]},{"role":"user","parts":[{"text":"follow up"}]}]}}`)
req2, err := http.NewRequest(http.MethodPost, "https://example.com/v1internal:streamGenerateContent?alt=sse", bytes.NewReader(req2Body))
require.NoError(t, err)
req2.Header.Set("Authorization", "Bearer downstream-a")
resp2, err := upstream.Do(req2, "", 42, 1)
require.NoError(t, err)
body2, err := io.ReadAll(resp2.Body)
require.NoError(t, err)
// "ok" is presumably the canned body produced by recordingUpstream, which is
// defined outside this view — verify against its definition.
require.Equal(t, "ok", string(body2))
require.Equal(t, 1, fallback.doCalls)
require.Equal(t, int32(1), startCalls.Load())
require.Equal(t, int32(1), sendCalls.Load())
}
// TestJSParityRouteErrorsWhenModelMappingPending builds an instance without
// calling SetModelMappingReady(true) and expects Do to return a nil response
// and errLSModelMapPending, with no StartCascade or SendUserCascadeMessage
// calls reaching the language server and no call to the fallback upstream.
func TestJSParityRouteErrorsWhenModelMappingPending(t *testing.T) {
t.Setenv("ANTIGRAVITY_LS_STRATEGY", LSStrategyJSParity)
var startCalls atomic.Int32
var sendCalls atomic.Int32
// Fake language server that only counts calls; the assertions below expect
// both counters to stay at zero.
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case strings.HasSuffix(r.URL.Path, "/StartCascade"):
startCalls.Add(1)
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"cascadeId":"cid-1"}`))
case strings.HasSuffix(r.URL.Path, "/SendUserCascadeMessage"):
sendCalls.Add(1)
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"queued":false}`))
default:
http.NotFound(w, r)
}
}))
defer server.Close()
// Unlike the sibling JS-parity tests, SetModelMappingReady is deliberately
// not called on this instance.
inst := &Instance{
AccountID: "42",
CSRF: "test-csrf",
Address: strings.TrimPrefix(server.URL, "https://"),
client: server.Client(),
healthy: true,
lastUsed: time.Now(),
}
fallback := &recordingUpstream{}
pool := &Pool{
config: Config{ReplicasPerAccount: 1},
instances: map[string][]*Instance{"42": []*Instance{inst}},
}
upstream := NewLSPoolUpstream(pool, fallback)
reqBody := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"session-a","contents":[{"role":"user","parts":[{"text":"hello"}]}]}}`)
req, err := http.NewRequest(http.MethodPost, "https://example.com/v1internal:streamGenerateContent?alt=sse", bytes.NewReader(reqBody))
require.NoError(t, err)
req.Header.Set("Authorization", "Bearer downstream-a")
resp, err := upstream.Do(req, "", 42, 1)
require.Nil(t, resp)
require.ErrorIs(t, err, errLSModelMapPending)
require.Equal(t, int32(0), startCalls.Load())
require.Equal(t, int32(0), sendCalls.Load())
require.Equal(t, 0, fallback.doCalls)
}
// recordingUpstreamWithCallback embeds recordingUpstream and adds an optional
// hook that its Do method invokes with each request before delegating to the
// embedded upstream.
type recordingUpstreamWithCallback struct {
recordingUpstream
// onDo, when non-nil, is called with every request passed to Do.
onDo func(req *http.Request)
}
// Do invokes the optional onDo hook with req, then forwards the call and all
// of its arguments unchanged to the embedded recordingUpstream.
func (r *recordingUpstreamWithCallback) Do(req *http.Request, proxyURL string, accountID int64, c int) (*http.Response, error) {
	if hook := r.onDo; hook != nil {
		hook(req)
	}
	return r.recordingUpstream.Do(req, proxyURL, accountID, c)
}

View File

@ -0,0 +1,908 @@
// Package lspool provides a mock Extension Server that the LS binary connects
// to at startup. The real IDE's extension.js runs a ConnectRPC HTTP/1.1 server
// using connectNodeAdapter. We replicate that protocol here.
//
// Protocol details (from extension.js source):
// - Transport: HTTP/1.1 on 127.0.0.1 (no TLS)
// - Auth: x-codeium-csrf-token header on every request
// - Unary request Content-Type: application/proto (binary protobuf, no envelope)
// OR application/connect+proto (with 5-byte envelope)
// - Unary response Content-Type: application/proto (raw binary protobuf, no envelope)
// - Stream request Content-Type: application/connect+proto (with 5-byte envelope)
// - Stream response Content-Type: application/connect+proto (envelope-framed messages)
//
// The LS may send requests with content-type "application/connect+proto" for
// BOTH unary and streaming RPCs, even though a Connect protocol client is
// expected to send "application/proto" for unary and
// "application/connect+proto" for streaming. ConnectRPC's content-type regex:
//
// /^application\/(connect\+)?(?:(json)(?:; ?charset=utf-?8)?|(proto))$/i
//
// treats a "connect+" prefix as stream mode and its absence as unary mode.
//
// Because the request content-type alone is therefore not a reliable signal,
// we detect the RPC kind from the URL path and respond accordingly.
package lspool
import (
"encoding/base64"
"encoding/binary"
"fmt"
"io"
"log/slog"
"net"
"net/http"
"strings"
"sync"
"time"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"
)
// ============================================================
// Proto helpers — hand-encode minimal proto messages so we don't
// need to import the full generated proto package.
// ============================================================
// encodeProtoString serializes val as a length-delimited (wire type 2) proto
// field numbered fieldNum: the tag varint, the byte-length varint, then the
// raw bytes of the string.
func encodeProtoString(fieldNum int, val string) []byte {
	key := encodeVarint(uint64(fieldNum<<3 | 2))
	size := encodeVarint(uint64(len(val)))
	encoded := make([]byte, len(key)+len(size)+len(val))
	n := copy(encoded, key)
	n += copy(encoded[n:], size)
	copy(encoded[n:], val)
	return encoded
}
// encodeProtoBytes serializes val as a length-delimited (wire type 2) proto
// field numbered fieldNum. The same encoding is used for bytes fields and for
// embedded messages.
func encodeProtoBytes(fieldNum int, val []byte) []byte {
	key := encodeVarint(uint64(fieldNum<<3 | 2))
	size := encodeVarint(uint64(len(val)))
	encoded := make([]byte, 0, len(key)+len(size)+len(val))
	return append(append(append(encoded, key...), size...), val...)
}
// encodeProtoVarint serializes val as a varint (wire type 0) proto field
// numbered fieldNum: the tag varint followed by the value varint.
func encodeProtoVarint(fieldNum int, val uint64) []byte {
	// Wire type 0 contributes no bits to the tag, so the tag is just the
	// shifted field number.
	key := encodeVarint(uint64(fieldNum << 3))
	value := encodeVarint(val)
	encoded := make([]byte, len(key)+len(value))
	copy(encoded[copy(encoded, key):], value)
	return encoded
}
// encodeProtoBool serializes val as a varint proto field numbered fieldNum,
// using 1 for true and 0 for false. A false value is still written explicitly.
func encodeProtoBool(fieldNum int, val bool) []byte {
	if val {
		return encodeProtoVarint(fieldNum, 1)
	}
	return encodeProtoVarint(fieldNum, 0)
}
// encodeVarint returns v in unsigned base-128 varint form, as produced by the
// encoding/binary package.
func encodeVarint(v uint64) []byte {
	return binary.AppendUvarint(make([]byte, 0, binary.MaxVarintLen64), v)
}
// decodeProtoString scans raw protobuf wire bytes and returns the payload of
// the first length-delimited (wire type 2) field whose number equals
// targetField, converted to a string.
//
// Varint, 64-bit and 32-bit fields are skipped. The function returns "" when
// the field is absent, when the input is truncated or malformed, or when any
// other wire type is encountered.
//
// Every length prefix is checked against the bytes that actually remain before
// it is used, so an oversized prefix (including values that would become
// negative when converted to int) yields "" instead of a slice-bounds panic.
func decodeProtoString(data []byte, targetField int) string {
	i := 0
	for i < len(data) {
		tag, n := binary.Uvarint(data[i:])
		if n <= 0 {
			break
		}
		i += n
		fieldNum := int(tag >> 3)
		switch tag & 0x7 {
		case 0: // varint
			_, n = binary.Uvarint(data[i:])
			if n <= 0 {
				return ""
			}
			i += n
		case 2: // length-delimited
			length, n := binary.Uvarint(data[i:])
			if n <= 0 {
				return ""
			}
			i += n
			// Compare in uint64 space: converting an attacker-controlled
			// length to int first could wrap negative.
			if length > uint64(len(data)-i) {
				return ""
			}
			end := i + int(length)
			if fieldNum == targetField {
				return string(data[i:end])
			}
			i = end
		case 1: // 64-bit
			i += 8
		case 5: // 32-bit
			i += 4
		default:
			return ""
		}
	}
	return ""
}
// ============================================================
// ConnectRPC envelope helpers
// ============================================================
// connectEnvelope frames payload for a ConnectRPC stream: one flags byte, the
// payload length as a 4-byte big-endian integer, then the payload itself.
func connectEnvelope(flags byte, payload []byte) []byte {
	frame := make([]byte, 0, 5+len(payload))
	frame = append(frame, flags)
	frame = binary.BigEndian.AppendUint32(frame, uint32(len(payload)))
	return append(frame, payload...)
}
// connectEndOfStream builds the frame that terminates a ConnectRPC stream:
// an envelope with flags 0x02 whose payload is the JSON object "{}", i.e.
// empty trailer metadata.
func connectEndOfStream() []byte {
	const endStreamFlag = 0x02
	return connectEnvelope(endStreamFlag, []byte("{}"))
}
// unwrapConnectEnvelope strips the 5-byte ConnectRPC envelope header from body
// and returns the enclosed payload.
//
// body is returned unchanged when it cannot be an envelope: it is shorter than
// the header, its flags byte is greater than 0x02, or the declared payload
// length exceeds the bytes that follow the header. Bytes after the declared
// payload are ignored.
//
// The length comparison is done in uint64 so that a large declared length
// cannot overflow int on 32-bit platforms and slip past the bounds check.
func unwrapConnectEnvelope(body []byte) []byte {
	const headerLen = 5
	if len(body) < headerLen {
		return body
	}
	// Accepted flag bytes are 0x00, 0x01 and 0x02; anything larger is treated
	// as an unframed message.
	if body[0] > 0x02 {
		return body
	}
	plen := uint64(binary.BigEndian.Uint32(body[1:headerLen]))
	if plen > uint64(len(body)-headerLen) {
		return body // Length mismatch, return raw
	}
	return body[headerLen : headerLen+int(plen)]
}
// ============================================================
// OAuthTokenInfo proto builder
// ============================================================
// buildOAuthTokenInfoBinary creates binary-encoded OAuthTokenInfo proto.
//
// message OAuthTokenInfo {
// string access_token = 1;
// string token_type = 2;
// string refresh_token = 3;
// google.protobuf.Timestamp expiry = 4;
// bool is_gcp_tos = 6;
// }
func buildOAuthTokenInfoBinary(accessToken, refreshToken string, expiresAt time.Time) []byte {
var buf []byte
buf = append(buf, encodeProtoString(1, accessToken)...)
buf = append(buf, encodeProtoString(2, "Bearer")...)
if refreshToken != "" {
buf = append(buf, encodeProtoString(3, refreshToken)...)
}
// Use real expiry if provided, otherwise default to 1 hour from now
expiry := expiresAt
if expiry.IsZero() {
expiry = time.Now().Add(1 * time.Hour)
}
ts := &timestamppb.Timestamp{
Seconds: expiry.Unix(),
}
tsBytes, _ := proto.Marshal(ts)
buf = append(buf, encodeProtoBytes(4, tsBytes)...)
buf = append(buf, encodeProtoBool(6, true)...)
return buf
}
// buildUSSTopicWithOAuth encodes a USS Topic proto holding a single OAuth
// token entry.
//
//	message Topic { map<string, Row> data = 1; }
//	message Row { string value = 1; int64 e_tag = 2; }
//
// The map key is "oauthTokenInfoSentinelKey"; the Row value is the standard
// base64 encoding of the binary OAuthTokenInfo message and its e_tag is 1.
func buildUSSTopicWithOAuth(accessToken, refreshToken string, expiresAt time.Time) []byte {
	tokenB64 := base64.StdEncoding.EncodeToString(
		buildOAuthTokenInfoBinary(accessToken, refreshToken, expiresAt),
	)
	// buildUSSTopicRow yields the map entry (key = field 1, Row = field 2, with
	// e_tag 1); Topic stores each map entry under field 1.
	return encodeProtoBytes(1, buildUSSTopicRow("oauthTokenInfoSentinelKey", tokenB64))
}
// buildPrimitiveBoolBinary encodes a Primitive proto message whose bool_value
// field is set to val.
func buildPrimitiveBoolBinary(val bool) []byte {
	// Primitive.bool_value is field 13 in the proto definition
	const boolValueField = 13
	return encodeProtoBool(boolValueField, val)
}
// buildPrimitiveInt32Binary encodes a Primitive proto message whose
// int32_value field is set to val. The value is written as the varint of its
// 32-bit two's-complement bit pattern.
func buildPrimitiveInt32Binary(val int32) []byte {
	// Primitive.int32_value is field 3 in the proto definition
	const int32ValueField = 3
	bits := uint32(val)
	return encodeProtoVarint(int32ValueField, uint64(bits))
}
func buildUSSTopicRow(key string, value string) []byte {
row := buildUSSRowBinary(value)
var entry []byte
entry = append(entry, encodeProtoString(1, key)...)
entry = append(entry, encodeProtoBytes(2, row)...)
return entry
}
// buildUSSRowBinary encodes a Row message with value as field 1 and a fixed
// e_tag of 1 as field 2.
func buildUSSRowBinary(value string) []byte {
	valueField := encodeProtoString(1, value)
	etagField := encodeProtoVarint(2, 1)
	row := make([]byte, 0, len(valueField)+len(etagField))
	row = append(row, valueField...)
	return append(row, etagField...)
}
// buildUSSTopicWithModelCredits encodes the Topic proto for the model-credits
// state. A nil info is treated as the zero ModelCreditsInfo.
//
// Entries are written in this order: the useAICredits toggle, the available
// credits (only when the toggle is on; 9999 if unset), and the minimum credit
// amount for usage (defaultMinimumCreditAmountForUsage if unset). Each Row
// value holds the raw bytes of the encoded Primitive message.
func buildUSSTopicWithModelCredits(info *ModelCreditsInfo) []byte {
	if info == nil {
		info = &ModelCreditsInfo{}
	}

	var topic []byte
	addEntry := func(key string, primitive []byte) {
		topic = append(topic, encodeProtoBytes(1, buildUSSTopicRow(key, string(primitive)))...)
	}

	// JS protocol: useAICreditsSentinelKey carries the toggle state.
	addEntry(useAICreditsSentinelKey, buildPrimitiveBoolBinary(info.UseAICredits))

	// availableCreditsSentinelKey is only present when credits are enabled.
	if info.UseAICredits {
		credits := int32(9999)
		if info.AvailableCredits != nil {
			credits = *info.AvailableCredits
		}
		addEntry(availableCreditsSentinelKey, buildPrimitiveInt32Binary(credits))
	}

	minimum := defaultMinimumCreditAmountForUsage
	if info.MinimumCreditAmountForUsage != nil {
		minimum = *info.MinimumCreditAmountForUsage
	}
	addEntry(minimumCreditAmountForUsageKey, buildPrimitiveInt32Binary(minimum))

	return topic
}
// buildEmptyTopic returns the encoding of a Topic with no map entries, which
// is a non-nil, zero-length byte slice.
func buildEmptyTopic() []byte {
	return make([]byte, 0)
}
// ============================================================
// UnifiedStateSyncUpdate builder
// ============================================================
// buildInitialStateUpdate encodes a UnifiedStateSyncUpdate whose initial_state
// arm carries topicData (an already-encoded Topic message).
//
//	message UnifiedStateSyncUpdate {
//	  oneof update_type {
//	    Topic initial_state = 1;
//	    AppliedUpdate applied_update = 2;
//	  }
//	}
func buildInitialStateUpdate(topicData []byte) []byte {
	const initialStateField = 1
	return encodeProtoBytes(initialStateField, topicData)
}
// buildAppliedUpdate encodes a UnifiedStateSyncUpdate whose applied_update arm
// (field 2) carries key as field 1 and, when row is non-empty, the encoded Row
// as field 2. An empty or nil row produces an update that names the key only.
func buildAppliedUpdate(key string, row []byte) []byte {
	inner := encodeProtoString(1, key)
	if len(row) != 0 {
		inner = append(inner, encodeProtoBytes(2, row)...)
	}
	return encodeProtoBytes(2, inner)
}
// ============================================================
// MockExtensionServer
// ============================================================
// MockExtensionServer provides a ConnectRPC-compatible HTTP server that the
// Language Server binary connects to. It implements just enough of the
// ExtensionServerService to keep the LS operational.
//
// The visible methods hold mu while touching tokens, credits, subscribers,
// nextSubID and lastAccountID.
type MockExtensionServer struct {
// listener is the loopback TCP listener the HTTP server accepts on.
listener net.Listener
// server is the HTTP server serving the RPC mux.
server *http.Server
// port is the port chosen by the OS for listener.
port int
// csrf is the expected x-codeium-csrf-token value; when empty the CSRF
// check accepts every request.
csrf string
mu sync.RWMutex
tokens map[string]*TokenInfo // account_id -> token info
credits map[string]*ModelCreditsInfo // account_id -> model credits info
// subscribers maps topic name -> subscriber id -> active subscription.
subscribers map[string]map[int]*stateSubscriber
// nextSubID is the id handed to the next subscriber.
nextSubID int
// lastAccountID is the account most recently passed to SetToken or
// SetModelCredits; it is used as the preferred account for lookups.
lastAccountID string
logger *slog.Logger
// Trajectory callback — when LS pushes trajectory updates, we forward them
onTrajectoryUpdate func(topic, key string, data []byte)
}
// TokenInfo holds OAuth token details for an account.
type TokenInfo struct {
// AccessToken is the OAuth access token; it is also what GetSecretValue
// returns.
AccessToken string
// RefreshToken is optional; it is left out of the encoded token when empty.
RefreshToken string
ExpiresAt time.Time // zero value means unknown; defaults to now+1h
}
// ModelCreditsInfo mirrors the JS uss-modelCredits topic state.
type ModelCreditsInfo struct {
// UseAICredits is the credits toggle; the available-credits entry is only
// published when it is true.
UseAICredits bool
// AvailableCredits is optional; nil is encoded as 9999 when credits are on.
AvailableCredits *int32
// MinimumCreditAmountForUsage is optional; nil is encoded as
// defaultMinimumCreditAmountForUsage.
MinimumCreditAmountForUsage *int32
}
// stateSubscriber is one active SubscribeToUnifiedStateSyncTopic stream.
type stateSubscriber struct {
// id is the key of this subscriber within its topic's subscriber map.
id int
// accountID is the server's lastAccountID at subscription time; empty
// matches updates for any account.
accountID string
// topic is the USS topic name requested by the subscriber.
topic string
// updates carries encoded UnifiedStateSyncUpdate payloads to the stream;
// publishers drop an update rather than block when it is full.
updates chan []byte
}
// Keys of the model-credits topic entries, plus the minimum credit amount used
// when ModelCreditsInfo.MinimumCreditAmountForUsage is nil.
const (
useAICreditsSentinelKey = "useAICreditsSentinelKey"
availableCreditsSentinelKey = "availableCreditsSentinelKey"
minimumCreditAmountForUsageKey = "minimumCreditAmountForUsageKey"
defaultMinimumCreditAmountForUsage = int32(50)
)
// NewMockExtensionServer starts a mock Extension Server on an OS-assigned
// loopback port and returns it once the listener is open.
//
// csrf is the token every request must present in x-codeium-csrf-token; an
// empty csrf disables the check. The HTTP server runs in a background
// goroutine until Close is called. An error is returned only when the
// listener cannot be opened.
func NewMockExtensionServer(csrf string) (*MockExtensionServer, error) {
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		return nil, fmt.Errorf("listen: %w", err)
	}

	m := &MockExtensionServer{
		listener:    ln,
		port:        ln.Addr().(*net.TCPAddr).Port,
		csrf:        csrf,
		tokens:      make(map[string]*TokenInfo),
		credits:     make(map[string]*ModelCreditsInfo),
		subscribers: make(map[string]map[int]*stateSubscriber),
		logger:      slog.Default().With("component", "mock-ext-server"),
	}

	const extService = "/exa.extension_server_pb.ExtensionServerService/"

	// Register all RPCs the LS calls on the Extension Server.
	// Unary RPCs with dedicated handlers — return application/proto
	unaryRoutes := []struct {
		rpc     string
		handler unaryHandler
	}{
		{"LanguageServerStarted", m.onLanguageServerStarted},
		{"Heartbeat", m.onHeartbeat},
		{"GetSecretValue", m.onGetSecretValue},
		{"StoreSecretValue", m.onStoreSecretValue},
		{"IsAgentManagerEnabled", m.onIsAgentManagerEnabled},
		{"PushUnifiedStateSyncUpdate", m.onPushUnifiedStateSyncUpdate},
		{"RecordError", m.onRecordError},
		{"LogEvent", m.onLogEvent},
		{"UpdateCascadeTrajectorySummaries", m.onUpdateTrajectorySummaries},
	}
	// Unary RPCs that are acknowledged with an empty response.
	noopRPCs := []string{
		"BroadcastConversationDeletion",
		"WriteCascadeEdit",
		"OpenDiffZones",
		"HandleAsyncPostMessage",
		"OpenFilePointer",
		"OpenVirtualFile",
		"SaveDocument",
		"RestartUserStatusUpdater",
		"FocusIDEWindow",
		"SmartFocusConversation",
		"RunExtensionCode",
		"UpdateDetailedViewWithCascadeInput",
		"FindAllReferences",
		"GetDefinition",
		"GetLintErrors",
	}
	// Server-streaming RPCs — return application/connect+proto
	streamRoutes := []struct {
		rpc     string
		handler streamHandler
	}{
		{"SubscribeToUnifiedStateSyncTopic", m.onSubscribeStateSyncTopic},
		{"ExecuteCommand", m.onExecuteCommand},
	}

	mux := http.NewServeMux()
	for _, route := range unaryRoutes {
		mux.HandleFunc(extService+route.rpc, m.handleUnary(route.handler))
	}
	for _, rpc := range noopRPCs {
		mux.HandleFunc(extService+rpc, m.handleUnary(m.onDefault))
	}
	for _, route := range streamRoutes {
		mux.HandleFunc(extService+route.rpc, m.handleStream(route.handler))
	}
	// Catch-all for any unregistered RPCs
	mux.HandleFunc("/", m.handleCatchAll)

	m.server = &http.Server{Handler: mux}
	go func() {
		// Serve returns http.ErrServerClosed after Close; anything else is a
		// genuine failure worth logging.
		if err := m.server.Serve(ln); err != http.ErrServerClosed {
			m.logger.Error("extension server error", "err", err)
		}
	}()

	m.logger.Info("mock extension server started", "port", m.port, "csrf_len", len(csrf))
	return m, nil
}
// Port reports the TCP port the mock extension server is listening on, as
// recorded when the listener was opened.
func (m *MockExtensionServer) Port() int {
	return m.port
}
// SetToken stores info as the OAuth token for accountID and makes that account
// the most recently used one. When info is non-nil, the new token is also
// pushed as an applied update to the "uss-oauth" subscribers that match the
// account. A nil info is stored but nothing is published.
func (m *MockExtensionServer) SetToken(accountID string, info *TokenInfo) {
	m.mu.Lock()
	m.lastAccountID = accountID
	m.tokens[accountID] = info
	targets := m.snapshotSubscribersLocked("uss-oauth", accountID)
	m.mu.Unlock()

	if info != nil {
		// Publishing happens outside the lock, on the snapshot taken above.
		encoded := base64.StdEncoding.EncodeToString(
			buildOAuthTokenInfoBinary(info.AccessToken, info.RefreshToken, info.ExpiresAt),
		)
		m.publishTopicUpdate(targets, buildAppliedUpdate("oauthTokenInfoSentinelKey", buildUSSRowBinary(encoded)))
	}
}
// SetModelCredits stores a private copy of info as the uss-modelCredits state
// for accountID (a nil info is stored as the zero value), makes that account
// the most recently used one, and pushes the resulting applied updates to the
// matching "uss-modelCredits" subscribers.
func (m *MockExtensionServer) SetModelCredits(accountID string, info *ModelCreditsInfo) {
	var stored ModelCreditsInfo
	if info != nil {
		stored = *info
	}

	m.mu.Lock()
	m.lastAccountID = accountID
	m.credits[accountID] = &stored
	targets := m.snapshotSubscribersLocked("uss-modelCredits", accountID)
	m.mu.Unlock()

	m.publishTopicUpdate(targets, buildModelCreditsAppliedUpdates(&stored)...)
}
// SetTrajectoryCallback installs fn as the hook invoked when the LS pushes a
// unified-state-sync update. Passing nil removes the hook.
//
// NOTE(review): the field is assigned without holding m.mu — presumably this
// is meant to be called before the LS starts sending requests; confirm.
func (m *MockExtensionServer) SetTrajectoryCallback(fn func(topic, key string, data []byte)) {
	m.onTrajectoryUpdate = fn
}
// currentTokenLocked returns the token of the most recently used account, or,
// if that account has no token, any other stored non-nil token. It returns
// nil when no non-nil token is stored. The caller must hold m.mu.
//
// The fallback choice follows map iteration order and is therefore not
// deterministic when several accounts have tokens.
func (m *MockExtensionServer) currentTokenLocked() *TokenInfo {
	if m.lastAccountID != "" {
		if info := m.tokens[m.lastAccountID]; info != nil {
			return info
		}
	}
	for _, info := range m.tokens {
		// SetToken may store a nil entry; skip it so it cannot mask a valid
		// token held by another account.
		if info != nil {
			return info
		}
	}
	return nil
}
// currentModelCreditsLocked returns the credits state of the most recently
// used account when it has one, otherwise whichever stored entry map iteration
// yields first, or nil when nothing is stored. The caller must hold m.mu.
func (m *MockExtensionServer) currentModelCreditsLocked() *ModelCreditsInfo {
	if id := m.lastAccountID; id != "" {
		if preferred := m.credits[id]; preferred != nil {
			return preferred
		}
	}
	for _, candidate := range m.credits {
		return candidate
	}
	return nil
}
// tokenForAccountLocked returns the token stored for accountID, falling back
// to currentTokenLocked when accountID is empty or has no token. The caller
// must hold m.mu.
func (m *MockExtensionServer) tokenForAccountLocked(accountID string) *TokenInfo {
	if accountID == "" {
		return m.currentTokenLocked()
	}
	if own := m.tokens[accountID]; own != nil {
		return own
	}
	return m.currentTokenLocked()
}
// creditsForAccountLocked returns the credits state stored for accountID,
// falling back to currentModelCreditsLocked when accountID is empty or has no
// entry. The caller must hold m.mu.
func (m *MockExtensionServer) creditsForAccountLocked(accountID string) *ModelCreditsInfo {
	if accountID == "" {
		return m.currentModelCreditsLocked()
	}
	if own := m.credits[accountID]; own != nil {
		return own
	}
	return m.currentModelCreditsLocked()
}
// snapshotSubscribersLocked copies the subscribers of topic that should
// receive an update for accountID into a new slice, so the caller can publish
// after releasing the lock. A subscriber matches when either side's account id
// is empty or the two ids are equal. It returns nil when the topic has no
// subscribers. The caller must hold m.mu.
func (m *MockExtensionServer) snapshotSubscribersLocked(topic, accountID string) []*stateSubscriber {
	registered := m.subscribers[topic]
	if len(registered) == 0 {
		return nil
	}
	matched := make([]*stateSubscriber, 0, len(registered))
	for _, sub := range registered {
		if sub == nil {
			continue
		}
		if accountID == "" || sub.accountID == "" || sub.accountID == accountID {
			matched = append(matched, sub)
		}
	}
	return matched
}
// publishTopicUpdate delivers a private copy of every non-empty update to each
// non-nil subscriber, in order. Sends never block: if a subscriber's channel
// is full the update is dropped and a warning is logged.
func (m *MockExtensionServer) publishTopicUpdate(subscribers []*stateSubscriber, updates ...[]byte) {
	for _, sub := range subscribers {
		if sub == nil {
			continue
		}
		for _, update := range updates {
			if len(update) == 0 {
				continue
			}
			// Each subscriber gets its own copy so no backing array is shared.
			payload := make([]byte, len(update))
			copy(payload, update)
			select {
			case sub.updates <- payload:
			default:
				m.logger.Warn("dropping USS update", "topic", sub.topic, "account", sub.accountID)
			}
		}
	}
}
// buildModelCreditsAppliedUpdates returns the three applied updates that bring
// a subscriber in line with info, in this order: the useAICredits toggle, the
// available credits, and the minimum credit amount for usage. A nil info is
// treated as the zero ModelCreditsInfo.
//
// When credits are disabled the available-credits update is sent with no row;
// when enabled and unset it carries 9999. An unset minimum falls back to
// defaultMinimumCreditAmountForUsage.
func buildModelCreditsAppliedUpdates(info *ModelCreditsInfo) [][]byte {
	if info == nil {
		info = &ModelCreditsInfo{}
	}
	// rowOf wraps an encoded Primitive in a Row, as the topic builder does.
	rowOf := func(primitive []byte) []byte {
		return buildUSSRowBinary(string(primitive))
	}

	// creditsRow stays nil when credits are off, producing a key-only update.
	var creditsRow []byte
	if info.UseAICredits {
		credits := int32(9999)
		if info.AvailableCredits != nil {
			credits = *info.AvailableCredits
		}
		creditsRow = rowOf(buildPrimitiveInt32Binary(credits))
	}

	minimum := defaultMinimumCreditAmountForUsage
	if info.MinimumCreditAmountForUsage != nil {
		minimum = *info.MinimumCreditAmountForUsage
	}

	updates := make([][]byte, 0, 3)
	updates = append(updates,
		buildAppliedUpdate(useAICreditsSentinelKey, rowOf(buildPrimitiveBoolBinary(info.UseAICredits))),
		buildAppliedUpdate(availableCreditsSentinelKey, creditsRow),
		buildAppliedUpdate(minimumCreditAmountForUsageKey, rowOf(buildPrimitiveInt32Binary(minimum))),
	)
	return updates
}
// Close stops the HTTP server by calling its Close method and returns that
// call's error.
func (m *MockExtensionServer) Close() error {
	err := m.server.Close()
	return err
}
// ============================================================
// Middleware
// ============================================================
// unaryHandler processes one unary RPC: it receives the raw proto request body
// and returns the raw proto response body (nil or empty for an empty response).
type unaryHandler func(body []byte) []byte
// streamHandler processes one server-streaming RPC: it receives the request
// body and writes its response frames to w itself.
type streamHandler func(body []byte, w http.ResponseWriter, r *http.Request)
// handleUnary adapts a unaryHandler to an http.HandlerFunc. The returned
// function rejects requests that fail the CSRF check, reads the whole body,
// strips a Connect envelope when the request content-type contains
// "connect+proto", calls handler, and replies 200 with content-type
// "application/proto" and the handler's bytes (if any).
//
// A body read error is logged and answered with an empty 200 response.
func (m *MockExtensionServer) handleUnary(handler unaryHandler) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		// CSRF check
		if !m.checkCSRF(w, r) {
			return
		}

		// Unary ConnectRPC responses always use application/proto, even when
		// the request arrived as connect+proto; the header is the same on the
		// error path, so set it once up front.
		w.Header().Set("Content-Type", "application/proto")

		raw, err := io.ReadAll(r.Body)
		if err != nil {
			m.logger.Error("read body", "err", err, "path", r.URL.Path)
			w.WriteHeader(200)
			return
		}

		// The LS might send with envelope framing (application/connect+proto)
		// or without (application/proto). Detect and unwrap.
		contentType := r.Header.Get("Content-Type")
		payload := raw
		if len(raw) >= 5 && strings.Contains(contentType, "connect+proto") {
			payload = unwrapConnectEnvelope(raw)
		}
		m.logger.Debug("unary RPC", "path", r.URL.Path, "body_len", len(payload), "content_type", contentType)

		reply := handler(payload)
		w.WriteHeader(200)
		if len(reply) > 0 {
			w.Write(reply)
		}
	}
}
// handleStream adapts a streamHandler to an http.HandlerFunc. The returned
// function rejects requests that fail the CSRF check, reads the whole body,
// strips a Connect envelope when the request content-type contains
// "connect+proto" or "connect+json", sends a 200 header with content-type
// "application/connect+proto", and then lets handler write the stream.
//
// A body read error is logged and the handler is not invoked.
func (m *MockExtensionServer) handleStream(handler streamHandler) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		if !m.checkCSRF(w, r) {
			return
		}

		payload, err := io.ReadAll(r.Body)
		if err != nil {
			m.logger.Error("read body", "err", err, "path", r.URL.Path)
			return
		}

		// Unwrap envelope from request
		contentType := r.Header.Get("Content-Type")
		for _, marker := range []string{"connect+proto", "connect+json"} {
			if strings.Contains(contentType, marker) {
				payload = unwrapConnectEnvelope(payload)
				break
			}
		}
		m.logger.Debug("stream RPC", "path", r.URL.Path, "body_len", len(payload))

		// Set streaming response content-type
		w.Header().Set("Content-Type", "application/connect+proto")
		w.WriteHeader(200)
		handler(payload, w, r)
	}
}
// checkCSRF reports whether the request may proceed. It returns true when no
// CSRF token is configured or when the x-codeium-csrf-token header equals the
// configured token. Otherwise it logs at most the first 8 characters of the
// presented token, writes a 403 text/plain response, and returns false.
func (m *MockExtensionServer) checkCSRF(w http.ResponseWriter, r *http.Request) bool {
	presented := r.Header.Get("x-codeium-csrf-token")
	if m.csrf == "" || presented == m.csrf {
		return true
	}
	m.logger.Warn("CSRF mismatch", "path", r.URL.Path, "got", presented[:min(8, len(presented))])
	w.Header().Set("Content-Type", "text/plain")
	w.WriteHeader(403)
	w.Write([]byte("Invalid CSRF token"))
	return false
}
// min returns the smaller of a and b.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}
// ============================================================
// Unary RPC Handlers — each receives raw proto request body,
// returns raw proto response body.
// ============================================================
// onLanguageServerStarted handles the LanguageServerStarted RPC. The request
// is not decoded; only its length is logged, and an empty response is
// returned.
func (m *MockExtensionServer) onLanguageServerStarted(body []byte) []byte {
	// LanguageServerStartedRequest has: https_port(1), http_port(2), lsp_port(3), csrf_token(4)
	// The ports are informational, so the message is left undecoded.
	m.logger.Info("LanguageServerStarted", "body_len", len(body))

	// A nil body is an empty LanguageServerStartedResponse.
	return nil
}
// onHeartbeat handles the Heartbeat RPC. The request body is ignored.
func (m *MockExtensionServer) onHeartbeat(body []byte) []byte {
// Return empty HeartbeatResponse
return nil
}
// onGetSecretValue handles the GetSecretValue RPC. The requested key is
// decoded for logging only: whatever the key, the response carries the access
// token of the current account (see currentTokenLocked) as field 1, or is
// empty when no non-empty access token is available.
func (m *MockExtensionServer) onGetSecretValue(body []byte) []byte {
	// GetSecretValueRequest: key = field 1
	m.logger.Debug("GetSecretValue", "key", decodeProtoString(body, 1))

	accessToken := func() string {
		m.mu.RLock()
		defer m.mu.RUnlock()
		if info := m.currentTokenLocked(); info != nil {
			return info.AccessToken
		}
		return ""
	}()

	if accessToken == "" {
		return nil
	}
	// GetSecretValueResponse: value = field 1
	return encodeProtoString(1, accessToken)
}
// onStoreSecretValue handles the StoreSecretValue RPC. It logs the key
// (field 1 of the request) and returns an empty response; nothing is stored.
func (m *MockExtensionServer) onStoreSecretValue(body []byte) []byte {
	m.logger.Debug("StoreSecretValue", "key", decodeProtoString(body, 1))
	return nil
}
// onIsAgentManagerEnabled handles the IsAgentManagerEnabled RPC and always
// answers with enabled=false, written explicitly as field 1.
func (m *MockExtensionServer) onIsAgentManagerEnabled(body []byte) []byte {
	// IsAgentManagerEnabledResponse: enabled = field 1 (bool)
	const enabledField = 1
	return encodeProtoBool(enabledField, false)
}
// onPushUnifiedStateSyncUpdate handles the PushUnifiedStateSyncUpdate RPC. The
// request is not decoded: if a trajectory callback is registered it is invoked
// with empty topic and key strings and the whole request body. The response is
// always empty.
func (m *MockExtensionServer) onPushUnifiedStateSyncUpdate(body []byte) []byte {
	// PushUnifiedStateSyncUpdateRequest: update = field 1 (UpdateRequest message)
	// UpdateRequest: topic_name = field 1, applied_update = field 5, key = field 2
	m.logger.Debug("PushUnifiedStateSyncUpdate", "body_len", len(body))

	// Extracting topic_name would require decoding the nested UpdateRequest;
	// for now the callback is only told that an update was pushed.
	if notify := m.onTrajectoryUpdate; notify != nil {
		notify("", "", body)
	}

	// A nil body is an empty PushUnifiedStateSyncUpdateResponse.
	return nil
}
// onRecordError handles the RecordError RPC: it logs the request size at debug
// level and returns an empty response.
func (m *MockExtensionServer) onRecordError(body []byte) []byte {
m.logger.Debug("RecordError", "body_len", len(body))
return nil
}
// onLogEvent handles the LogEvent RPC: the request is ignored and an empty
// response is returned.
func (m *MockExtensionServer) onLogEvent(body []byte) []byte {
return nil
}
// onUpdateTrajectorySummaries handles the UpdateCascadeTrajectorySummaries
// RPC. Only the request size is logged (debug level); the reply is nil.
func (m *MockExtensionServer) onUpdateTrajectorySummaries(body []byte) []byte {
	size := len(body)
	m.logger.Debug("UpdateCascadeTrajectorySummaries", "body_len", size)
	return nil
}
// onDefault is a unary handler that ignores its request and returns nil.
func (m *MockExtensionServer) onDefault(body []byte) []byte {
	var emptyResponse []byte
	return emptyResponse
}
// ============================================================
// Streaming RPC Handlers
// ============================================================
// onSubscribeStateSyncTopic serves the SubscribeToUnifiedStateSyncTopic
// streaming RPC.
//
// body is the request message; its field 1 is the topic name. The handler
// registers a subscriber for that topic (attributed to m.lastAccountID),
// writes one initial-state frame, and then forwards every non-empty update
// queued on the subscriber's channel as an envelope frame. It returns when the
// request context is cancelled or a write to the client fails; the subscriber
// is unregistered on the way out.
//
// If w does not implement http.Flusher the handler logs an error and returns
// without writing anything.
func (m *MockExtensionServer) onSubscribeStateSyncTopic(body []byte, w http.ResponseWriter, r *http.Request) {
	// SubscribeToUnifiedStateSyncTopicRequest: topic = field 1
	topic := decodeProtoString(body, 1)
	m.logger.Info("SubscribeToUnifiedStateSyncTopic", "topic", topic)

	flusher, ok := w.(http.Flusher)
	if !ok {
		m.logger.Error("ResponseWriter does not support Flush")
		return
	}

	// Register the subscriber and snapshot the initial topic state under one
	// lock acquisition, so both are taken from the same view of m's state.
	m.mu.Lock()
	accountID := m.lastAccountID
	subID := m.nextSubID
	m.nextSubID++
	sub := &stateSubscriber{
		id:        subID,
		accountID: accountID,
		topic:     topic,
		updates:   make(chan []byte, 16),
	}
	if m.subscribers[topic] == nil {
		m.subscribers[topic] = make(map[int]*stateSubscriber)
	}
	m.subscribers[topic][subID] = sub

	// Build initial state based on topic.
	var topicData []byte
	switch topic {
	case "uss-oauth":
		tokenInfo := m.tokenForAccountLocked(accountID)
		if tokenInfo != nil {
			topicData = buildUSSTopicWithOAuth(tokenInfo.AccessToken, tokenInfo.RefreshToken, tokenInfo.ExpiresAt)
		} else {
			topicData = buildEmptyTopic()
		}
	case "uss-modelCredits":
		creditsInfo := m.creditsForAccountLocked(accountID)
		if creditsInfo != nil {
			topicData = buildUSSTopicWithModelCredits(creditsInfo)
		} else {
			topicData = buildEmptyTopic()
		}
	default:
		// For all other topics (browserPreferences, enterprisePreferences, etc.),
		// return empty topic data.
		topicData = buildEmptyTopic()
	}
	m.mu.Unlock()

	// Unregister on exit and drop the per-topic map once it is empty.
	defer func() {
		m.mu.Lock()
		if topicSubs := m.subscribers[topic]; topicSubs != nil {
			delete(topicSubs, subID)
			if len(topicSubs) == 0 {
				delete(m.subscribers, topic)
			}
		}
		m.mu.Unlock()
	}()

	// Send initial state as envelope-framed message. A failed write means the
	// client is already gone, so stop instead of entering the forward loop.
	initialUpdate := buildInitialStateUpdate(topicData)
	if _, err := w.Write(connectEnvelope(0x00, initialUpdate)); err != nil {
		m.logger.Debug("SubscribeToUnifiedStateSyncTopic write failed", "topic", topic, "err", err)
		return
	}
	flusher.Flush()

	done := r.Context().Done()
	for {
		select {
		case <-done:
			m.logger.Debug("SubscribeToUnifiedStateSyncTopic stream closed", "topic", topic)
			return
		case update := <-sub.updates:
			if len(update) == 0 {
				continue
			}
			if _, err := w.Write(connectEnvelope(0x00, update)); err != nil {
				m.logger.Debug("SubscribeToUnifiedStateSyncTopic write failed", "topic", topic, "err", err)
				return
			}
			flusher.Flush()
		}
	}
}
// onExecuteCommand handles the ExecuteCommand streaming RPC. The mock never
// executes anything: it logs the request size and immediately writes the
// end-of-stream frame. Nothing is written when w cannot be flushed.
func (m *MockExtensionServer) onExecuteCommand(body []byte, w http.ResponseWriter, r *http.Request) {
	m.logger.Debug("ExecuteCommand (mock)", "body_len", len(body))

	if flusher, canFlush := w.(http.Flusher); canFlush {
		w.Write(connectEndOfStream())
		flusher.Flush()
	}
}
// ============================================================
// Catch-all handler
// ============================================================
// handleCatchAll answers any RPC that has no dedicated handler.
//
// After the CSRF check passes, the request is logged, its body is drained and
// an empty "application/proto" response with status 200 is sent. The same
// response is used whether or not the request's Content-Type contains
// "connect+"; that fact is only recorded in the log as streaming_hint.
func (m *MockExtensionServer) handleCatchAll(w http.ResponseWriter, r *http.Request) {
	if !m.checkCSRF(w, r) {
		return
	}

	streamingHint := strings.Contains(r.Header.Get("Content-Type"), "connect+")
	m.logger.Debug("unhandled RPC (returning empty proto)",
		"path", r.URL.Path,
		"method", r.Method,
		"streaming_hint", streamingHint)

	// Drain the request body without buffering it; its content is not needed
	// and a read error does not change the response.
	_, _ = io.Copy(io.Discard, r.Body)

	w.Header().Set("Content-Type", "application/proto")
	w.WriteHeader(http.StatusOK)
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,346 @@
package lspool
import (
"bytes"
"context"
"encoding/binary"
"errors"
"io"
"net/http"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
"github.com/stretchr/testify/require"
)
// TestBuildLSEnvKeepsExistingSSLValues checks that buildLSEnv carries
// SSL_CERT_FILE / SSL_CERT_DIR entries from the inherited environment through
// unchanged and adds ANTIGRAVITY_EDITOR_APP_ROOT for the given app root.
func TestBuildLSEnvKeepsExistingSSLValues(t *testing.T) {
	inherited := []string{
		"SSL_CERT_FILE=/custom/ca.pem",
		"SSL_CERT_DIR=/custom/certs",
	}

	got := buildLSEnv(inherited, "/opt/antigravity", "")

	for _, want := range []string{
		"ANTIGRAVITY_EDITOR_APP_ROOT=/opt/antigravity",
		"SSL_CERT_FILE=/custom/ca.pem",
		"SSL_CERT_DIR=/custom/certs",
	} {
		require.Contains(t, got, want)
	}
}
// TestBuildLSEnvClearsInheritedProxyWhenUnset checks that, when no proxy URL
// is passed, buildLSEnv emits every proxy variable (upper- and lower-case)
// with an empty value instead of the inherited one.
func TestBuildLSEnvClearsInheritedProxyWhenUnset(t *testing.T) {
	inherited := []string{
		"HTTPS_PROXY=http://old-proxy:8080",
		"HTTP_PROXY=http://old-proxy:8080",
		"ALL_PROXY=socks5://old-proxy:1080",
		"https_proxy=http://old-proxy:8080",
		"http_proxy=http://old-proxy:8080",
		"all_proxy=socks5://old-proxy:1080",
	}

	got := buildLSEnv(inherited, "/opt/antigravity", "")

	for _, name := range []string{
		"HTTPS_PROXY", "HTTP_PROXY", "ALL_PROXY",
		"https_proxy", "http_proxy", "all_proxy",
	} {
		require.Contains(t, got, name+"=")
	}
}
// TestShortAccountID checks that shortAccountID leaves IDs of up to eight
// characters untouched and cuts longer ones down to their first eight.
func TestShortAccountID(t *testing.T) {
	cases := []struct {
		id   string
		want string
	}{
		{id: "9", want: "9"},
		{id: "12345678", want: "12345678"},
		{id: "123456789", want: "12345678"},
	}
	for _, tc := range cases {
		require.Equal(t, tc.want, shortAccountID(tc.id))
	}
}
// TestFrameConnectMessage checks the frame layout produced by
// frameConnectMessage: a zero flag byte, a big-endian uint32 payload length,
// then the payload itself.
func TestFrameConnectMessage(t *testing.T) {
	const payload = `{"x":1}`

	framed := frameConnectMessage([]byte(payload))
	require.Len(t, framed, 5+len(payload))

	header, rest := framed[:5], framed[5:]
	require.Equal(t, byte(0), header[0])
	require.Equal(t, uint32(len(payload)), binary.BigEndian.Uint32(header[1:]))
	require.Equal(t, payload, string(rest))
}
// TestConnectEnvelope checks that connectEnvelope prefixes the payload with
// the given flag byte and a big-endian uint32 length.
func TestConnectEnvelope(t *testing.T) {
	const message = "hello"

	framed := connectEnvelope(0x00, []byte(message))
	require.Len(t, framed, 5+len(message))

	require.Equal(t, byte(0x00), framed[0])
	require.Equal(t, uint32(len(message)), binary.BigEndian.Uint32(framed[1:5]))
	require.Equal(t, message, string(framed[5:]))
}
func TestUnwrapConnectEnvelope(t *testing.T) {
payload := []byte("test data")
env := connectEnvelope(0x00, payload)
unwrapped := unwrapConnectEnvelope(env)
require.Equal(t, payload, unwrapped)
short := []byte{1, 2}
require.Equal(t, short, unwrapConnectEnvelope(short))
}
// TestExtractPromptAndModel checks prompt/model extraction for both request
// shapes: the wrapped form (model + nested "request") and the direct form
// with top-level "contents".
func TestExtractPromptAndModel(t *testing.T) {
	wrapped := []byte(`{"model":"gemini-2.5-pro","project":"p","request":{"contents":[{"role":"user","parts":[{"text":"hello world"}]}]}}`)
	gotPrompt, gotModel := extractPromptAndModel(wrapped)
	require.Equal(t, "hello world", gotPrompt)
	require.Equal(t, "gemini-2.5-pro", gotModel)

	direct := []byte(`{"contents":[{"role":"user","parts":[{"text":"test prompt"}]}]}`)
	directPrompt, _ := extractPromptAndModel(direct)
	require.Equal(t, "test prompt", directPrompt)
}
// TestResolveModelEnum checks that resolveModelEnum yields a positive enum
// for known names, "models/"-prefixed names and unknown names alike.
func TestResolveModelEnum(t *testing.T) {
	// Without dynamic mapping loaded, should return fallback (312 = gemini-2.5-flash)
	for _, name := range []string{
		"gemini-2.5-flash",
		"models/gemini-2.5-flash",
		"claude-sonnet-4-6",
		"unknown-model",
	} {
		require.True(t, resolveModelEnum(name) > 0)
	}
}
// TestBuildCascadeConfigIncludesRequestedModel checks that the cascade config
// for a Gemini model has a plannerConfig whose only entry is requestedModel,
// carrying a non-empty "model" value.
func TestBuildCascadeConfigIncludesRequestedModel(t *testing.T) {
	config := buildCascadeConfig("models/gemini-2.5-flash")
	require.NotNil(t, config)

	planner, isMap := config["plannerConfig"].(map[string]any)
	require.True(t, isMap)

	requested, isMap := planner["requestedModel"].(map[string]any)
	require.True(t, isMap)

	require.NotEmpty(t, requested["model"])
	require.Len(t, planner, 1)
}
// TestBuildCascadeConfigClaudeIncludesRequestedModel is the Claude
// counterpart of the Gemini cascade-config test: plannerConfig must hold
// exactly one entry, requestedModel, with a non-empty "model".
func TestBuildCascadeConfigClaudeIncludesRequestedModel(t *testing.T) {
	config := buildCascadeConfig("claude-sonnet-4-6")
	require.NotNil(t, config)

	planner, isMap := config["plannerConfig"].(map[string]any)
	require.True(t, isMap)

	requested, isMap := planner["requestedModel"].(map[string]any)
	require.True(t, isMap)

	require.NotEmpty(t, requested["model"])
	require.Len(t, planner, 1)
}
// TestDoNonStreamGeneratePassesThrough checks that a non-streaming
// generateContent request sent through the LS pool upstream is delegated to
// the fallback upstream exactly once.
func TestDoNonStreamGeneratePassesThrough(t *testing.T) {
	fallback := &recordingUpstream{}
	upstream := NewLSPoolUpstream(&Pool{}, fallback)

	req, err := http.NewRequest(http.MethodPost, "https://example.com/v1beta/models/gemini:generateContent", bytes.NewReader([]byte(`{}`)))
	require.NoError(t, err)

	resp, err := upstream.Do(req, "", 1, 1)
	require.NoError(t, err)
	require.NotNil(t, resp)
	// Release the response body even though its content is not inspected.
	if resp.Body != nil {
		t.Cleanup(func() { _ = resp.Body.Close() })
	}

	require.Equal(t, 1, fallback.doCalls)
}
// TestExtractPlannerResponseText checks that the planner response text, the
// generating flag and the cascade status are pulled out of a finished
// trajectory payload.
func TestExtractPlannerResponseText(t *testing.T) {
	payload := []byte(`{"status":"CASCADE_RUN_STATUS_IDLE","trajectory":{"steps":[
{"type":"CORTEX_STEP_TYPE_USER_INPUT","status":"CORTEX_STEP_STATUS_DONE"},
{"type":"CORTEX_STEP_TYPE_PLANNER_RESPONSE","status":"CORTEX_STEP_STATUS_DONE",
"plannerResponse":{"response":"Hello world"}}
]}}`)

	gotText, stillGenerating, gotStatus := extractPlannerResponseText(payload)

	require.Equal(t, "Hello world", gotText)
	require.False(t, stillGenerating)
	require.Equal(t, "CASCADE_RUN_STATUS_IDLE", gotStatus)
}
// TestExtractPlannerResponseState_ErrorDetails checks that executor error
// details (short error plus long details) surface in ErrorMessage when the
// trajectory ended in an error and produced no planner text.
func TestExtractPlannerResponseState_ErrorDetails(t *testing.T) {
	payload := []byte(`{
"status":"CASCADE_RUN_STATUS_IDLE",
"trajectory":{
"steps":[
{"type":"CORTEX_STEP_TYPE_USER_INPUT","status":"CORTEX_STEP_STATUS_DONE"}
],
"executorMetadata":{
"terminationReason":"ERROR",
"errorDetails":{
"errorCode":429,
"shortError":"Model quota reached",
"details":"You have exhausted your capacity on this model. Your quota will reset after 1h59m40s."
}
}
}
}`)

	got := extractPlannerResponseState(payload)

	require.Equal(t, "CASCADE_RUN_STATUS_IDLE", got.Status)
	require.False(t, got.Generating)
	require.Empty(t, got.Text)
	for _, fragment := range []string{"Model quota reached", "quota will reset after"} {
		require.Contains(t, got.ErrorMessage, fragment)
	}
}
// TestBuildGeminiSSEChunk checks that an SSE chunk carries the "data: "
// prefix, the text and model role in its JSON, and ends with a blank line.
func TestBuildGeminiSSEChunk(t *testing.T) {
	chunk := buildGeminiSSEChunk("hello")

	for _, fragment := range []string{"data: ", `"text":"hello"`, `"role":"model"`} {
		require.Contains(t, chunk, fragment)
	}
	require.True(t, strings.HasSuffix(chunk, "\n\n"))
}
func TestRequestHasTools(t *testing.T) {
// Wrapped format with tools
require.True(t, requestHasTools([]byte(`{"model":"m","project":"p","request":{"contents":[],"tools":[{"functionDeclarations":[{"name":"get_weather"}]}]}}`)))
// Direct format with tools
require.True(t, requestHasTools([]byte(`{"contents":[],"tools":[{"functionDeclarations":[{"name":"f"}]}]}`)))
// No tools
require.False(t, requestHasTools([]byte(`{"model":"m","project":"p","request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`)))
// Empty tools array
require.False(t, requestHasTools([]byte(`{"contents":[],"tools":[]}`)))
}
// TestCurrentLSStrategy checks that ANTIGRAVITY_LS_STRATEGY selects the
// js-parity strategy and that an unrecognised value yields the direct one.
func TestCurrentLSStrategy(t *testing.T) {
	const envName = "ANTIGRAVITY_LS_STRATEGY"

	t.Setenv(envName, "js-parity")
	require.Equal(t, LSStrategyJSParity, CurrentLSStrategy())

	t.Setenv(envName, "unknown")
	require.Equal(t, LSStrategyDirect, CurrentLSStrategy())
}
// TestParseLSReplicaCountDefaultAndEnv checks that the replica count comes
// from ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT and is 5 when the variable is
// empty or set to "0".
func TestParseLSReplicaCountDefaultAndEnv(t *testing.T) {
	cases := []struct {
		envValue string
		want     int
	}{
		{envValue: "", want: 5},
		{envValue: "3", want: 3},
		{envValue: "0", want: 5},
	}
	for _, tc := range cases {
		t.Setenv("ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT", tc.envValue)
		require.Equal(t, tc.want, parseLSReplicaCount())
	}
}
func TestPoolGetUsesStickyReplicaSlot(t *testing.T) {
pool := &Pool{
config: Config{ReplicasPerAccount: 5},
instances: map[string][]*Instance{
"acc-1": {
{AccountID: "acc-1", Replica: 0, healthy: true},
{AccountID: "acc-1", Replica: 1, healthy: true},
{AccountID: "acc-1", Replica: 2, healthy: true},
{AccountID: "acc-1", Replica: 3, healthy: true},
{AccountID: "acc-1", Replica: 4, healthy: true},
},
},
}
routingKey := "acc-1:user-a:session-1"
slot := replicaSlotIndex(routingKey, pool.replicaCount())
inst := pool.Get("acc-1", routingKey)
require.NotNil(t, inst)
require.Equal(t, slot, inst.Replica)
}
// TestPoolGetWithoutRoutingKeyPrefersLeastBusyReplica checks that, with no
// routing key, Pool.Get picks the healthy replica with fewer in-flight
// requests.
func TestPoolGetWithoutRoutingKeyPrefersLeastBusyReplica(t *testing.T) {
	heavy := &Instance{AccountID: "acc-1", Replica: 0, healthy: true}
	light := &Instance{AccountID: "acc-1", Replica: 1, healthy: true}
	atomic.StoreInt64(&heavy.inflight, 4)
	atomic.StoreInt64(&light.inflight, 1)

	pool := &Pool{
		config:    Config{ReplicasPerAccount: 5},
		instances: map[string][]*Instance{"acc-1": {heavy, light}},
	}

	picked := pool.Get("acc-1", "")
	require.NotNil(t, picked)
	require.Equal(t, 1, picked.Replica)
}
// TestWaitForInstanceReadyProbesImmediately checks that the first readiness
// probe runs without waiting for the retry interval: one successful probe
// must return in well under the 200ms interval.
func TestWaitForInstanceReadyProbesImmediately(t *testing.T) {
	begin := time.Now()

	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()

	alwaysReady := func(context.Context) error { return nil }
	attempts, err := waitForInstanceReady(ctx, 200*time.Millisecond, alwaysReady)

	require.NoError(t, err)
	require.Equal(t, 1, attempts)
	require.Less(t, time.Since(begin), 100*time.Millisecond)
}
// TestWaitForInstanceReadyRetriesUntilSuccess checks that probing repeats
// until the probe succeeds and that the reported attempt count matches the
// number of probe invocations.
func TestWaitForInstanceReadyRetriesUntilSuccess(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()

	var probeCalls int
	readyOnThirdCall := func(context.Context) error {
		probeCalls++
		if probeCalls >= 3 {
			return nil
		}
		return errors.New("not ready")
	}

	attempts, err := waitForInstanceReady(ctx, 10*time.Millisecond, readyOnThirdCall)

	require.NoError(t, err)
	require.Equal(t, 3, attempts)
	require.Equal(t, 3, probeCalls)
}
// TestDecideJSParityRoute checks the js-parity routing decision: a text
// request with a session ID goes to the LS, an image-modality request does
// not (with a reason mentioning "image"), and neither does a request without
// a session ID.
func TestDecideJSParityRoute(t *testing.T) {
	// Text request with a session ID.
	textBody := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"s1","contents":[{"role":"user","parts":[{"text":"hello"}]}]}}`)
	textReq, err := parseGeminiRequest(textBody)
	require.NoError(t, err)
	textDecision := decideJSParityRoute(textReq, textBody)
	require.True(t, textDecision.UseLS)

	// Image generation request.
	imageBody := []byte(`{"model":"gemini-2.5-flash-image","request":{"sessionId":"s1","contents":[{"role":"user","parts":[{"text":"draw"}]}],"generationConfig":{"responseModalities":["TEXT","IMAGE"]}}}`)
	imageReq, err := parseGeminiRequest(imageBody)
	require.NoError(t, err)
	imageDecision := decideJSParityRoute(imageReq, imageBody)
	require.False(t, imageDecision.UseLS)
	require.Contains(t, strings.ToLower(imageDecision.Reason), "image")

	// Text request lacking a session ID.
	sessionlessBody := []byte(`{"model":"gemini-2.5-flash","request":{"contents":[{"role":"user","parts":[{"text":"hello"}]}]}}`)
	sessionlessReq, err := parseGeminiRequest(sessionlessBody)
	require.NoError(t, err)
	require.False(t, decideJSParityRoute(sessionlessReq, sessionlessBody).UseLS)
}
// TestUserNamespacePrefersExplicitHeader checks that userNamespace yields a
// non-anonymous namespace both with the explicit namespace header and with
// only an Authorization header, and that the two namespaces differ.
func TestUserNamespacePrefersExplicitHeader(t *testing.T) {
	req, err := http.NewRequest(http.MethodPost, "https://example.com", nil)
	require.NoError(t, err)
	req.Header.Set(userNamespaceHeader, "tenant-a")
	req.Header.Set("Authorization", "Bearer oauth-token")

	fromHeader := userNamespace(req)
	require.NotEqual(t, "anon", fromHeader)

	// Drop the explicit header so only the Authorization header remains.
	req.Header.Del(userNamespaceHeader)
	fromAuth := userNamespace(req)
	require.NotEqual(t, "anon", fromAuth)
	require.NotEqual(t, fromHeader, fromAuth)
}
// TestConversationPrefixEqual checks that conversationPrefixEqual accepts a
// shorter conversation as a prefix of a longer one, but not the reverse.
func TestConversationPrefixEqual(t *testing.T) {
	textTurn := func(role, text string) geminiConversationTurn {
		return geminiConversationTurn{
			Role:  role,
			Parts: []geminiConversationPart{{Kind: "text", Text: text}},
		}
	}

	prefix := []geminiConversationTurn{
		textTurn("user", "hello"),
		textTurn("model", "world"),
	}
	full := append(cloneConversationTurns(prefix), textTurn("user", "follow up"))

	require.True(t, conversationPrefixEqual(full, prefix))
	require.False(t, conversationPrefixEqual(prefix, full))
}
func TestSystemTextCompatible(t *testing.T) {
require.True(t, systemTextCompatible("You are helpful", ""))
require.True(t, systemTextCompatible("You are helpful", "You are helpful"))
require.False(t, systemTextCompatible("", "You are helpful"))
require.False(t, systemTextCompatible("You are helpful", "You are different"))
}
// recordingUpstream is a test double for an upstream: it answers every
// request with a canned 200 response and counts how often Do was invoked.
type recordingUpstream struct {
// doCalls is incremented on every Do call (DoWithTLS delegates to Do).
doCalls int
}
// Do records the call and returns a canned 200 response whose body is "ok"
// and whose Request field points back at req. It never fails.
func (r *recordingUpstream) Do(req *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
	r.doCalls++

	canned := &http.Response{
		StatusCode: 200,
		Header:     make(http.Header),
		Body:       io.NopCloser(bytes.NewBufferString("ok")),
		Request:    req,
	}
	return canned, nil
}
// DoWithTLS ignores the TLS fingerprint profile and forwards to Do, so the
// call is counted in doCalls as well.
func (r *recordingUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, c int, _ *tlsfingerprint.Profile) (*http.Response, error) {
	resp, err := r.Do(req, proxyURL, accountID, c)
	return resp, err
}

View File

@ -0,0 +1,268 @@
package lspool
import (
"bufio"
"context"
"fmt"
"io"
"log/slog"
"net"
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
"golang.org/x/net/proxy"
)
// lsProxyBridge is a local HTTP CONNECT proxy that forwards tunnels through
// an upstream proxy dialer. It is created by newLSProxyBridge.
type lsProxyBridge struct {
// listener is the loopback TCP listener the bridge server accepts on.
listener net.Listener
// server is the HTTP server whose handler is connectHandler.
server *http.Server
// url is the bridge's own address in the form "http://<listener address>".
url string
// upstream is the redacted form of the upstream proxy URL, used in logs.
upstream string
}
// lsProxyBridgeManager caches one lsProxyBridge per upstream proxy key so the
// same upstream is bridged only once.
type lsProxyBridgeManager struct {
// mu is held by ensure and closeAll while they read or modify bridges.
mu sync.Mutex
// bridges maps the key passed to ensure to the bridge created for it.
bridges map[string]*lsProxyBridge
// logger is handed to every bridge this manager creates.
logger *slog.Logger
}
// globalLSProxyBridgeManager is the package-wide bridge cache used by
// prepareLSProxyURL. Its logger is derived from slog.Default() at package
// initialisation time.
var globalLSProxyBridgeManager = &lsProxyBridgeManager{
bridges: make(map[string]*lsProxyBridge),
logger: slog.Default().With("component", "lspool-proxy-bridge"),
}
var (
// lsProxyBridgeDialTimeout bounds each dial made through the upstream
// proxy, both for CONNECT requests and for connectivity probes.
lsProxyBridgeDialTimeout = 10 * time.Second
// lsProxyBridgeProbeTargets lists the host:port targets that
// probeConnectivity dials once when a bridge is created.
lsProxyBridgeProbeTargets = []string{
"cloudcode-pa.googleapis.com:443",
"oauthaccountmanager.googleapis.com:443",
}
)
// prepareLSProxyURL turns a raw proxy setting into a proxy URL for the LS.
//
// An input that proxyurl.Parse reports as absent yields "" with no error.
// http/https proxies are returned in their normalized form. socks5/socks5h
// proxies are fronted by a local HTTP CONNECT bridge (created once per
// normalized URL) and the bridge's URL is returned. Any other scheme is an
// error.
func prepareLSProxyURL(raw string) (string, error) {
	normalized, parsed, err := proxyurl.Parse(raw)
	switch {
	case err != nil:
		return "", err
	case parsed == nil:
		return "", nil
	}

	scheme := strings.ToLower(parsed.Scheme)
	if scheme == "http" || scheme == "https" {
		return normalized, nil
	}
	if scheme == "socks5" || scheme == "socks5h" {
		return globalLSProxyBridgeManager.ensure(normalized, parsed)
	}
	return "", fmt.Errorf("unsupported LS proxy scheme: %s", parsed.Scheme)
}
// ensure returns the URL of the bridge registered under key, creating and
// registering a new bridge for upstream when none exists yet. The manager
// lock is held for the whole call, including bridge creation.
func (m *lsProxyBridgeManager) ensure(key string, upstream *url.URL) (string, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	current := m.bridges[key]
	if current == nil {
		created, err := newLSProxyBridge(upstream, m.logger)
		if err != nil {
			return "", err
		}
		m.bridges[key] = created
		current = created
	}
	return current.url, nil
}
// closeAll shuts down every registered bridge (server first, then listener)
// and empties the registry. Close errors are ignored.
func (m *lsProxyBridgeManager) closeAll() {
	m.mu.Lock()
	defer m.mu.Unlock()

	for key, bridge := range m.bridges {
		delete(m.bridges, key)
		if bridge == nil {
			continue
		}
		_ = bridge.server.Close()
		_ = bridge.listener.Close()
	}
}
// closeAllLSProxyBridgesForTest tears down every bridge held by the global
// manager so tests do not leak listeners into one another.
func closeAllLSProxyBridgesForTest() {
	manager := globalLSProxyBridgeManager
	manager.closeAll()
}
// newLSProxyBridge starts a local HTTP CONNECT proxy on an ephemeral loopback
// port that tunnels through the given upstream proxy.
//
// It returns once the listener is bound. Serving runs in a background
// goroutine, and a second goroutine probes connectivity through the upstream;
// probe results are only logged. An error is returned when the upstream
// dialer cannot be built or the listener cannot be opened.
func newLSProxyBridge(upstream *url.URL, logger *slog.Logger) (*lsProxyBridge, error) {
	upstreamDialer, err := proxy.FromURL(upstream, proxy.Direct)
	if err != nil {
		return nil, fmt.Errorf("create SOCKS dialer: %w", err)
	}

	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		return nil, fmt.Errorf("listen LS proxy bridge: %w", err)
	}

	b := &lsProxyBridge{
		listener: ln,
		url:      "http://" + ln.Addr().String(),
		upstream: upstream.Redacted(),
	}
	srv := &http.Server{
		Handler:           b.connectHandler(upstreamDialer, logger),
		ReadHeaderTimeout: 10 * time.Second,
		IdleTimeout:       2 * time.Minute,
	}
	b.server = srv

	go func() {
		serveErr := srv.Serve(ln)
		if serveErr == nil || serveErr == http.ErrServerClosed {
			return
		}
		logger.Error("LS proxy bridge serve failed", "upstream", b.upstream, "err", serveErr)
	}()

	logger.Info("LS proxy bridge started", "upstream", b.upstream, "listen", b.url)
	go b.probeConnectivity(upstreamDialer, logger)
	return b, nil
}
// connectHandler returns the bridge's HTTP handler. It accepts only CONNECT
// requests, dials the requested target through dialer, writes
// "200 Connection Established" on the hijacked client connection and then
// relays bytes in both directions via tunnelConns.
//
// Failures before the hijack are reported as HTTP errors: 405 for methods
// other than CONNECT, 400 when no target host is present, 502 when the dial
// through the proxy fails and 500 when the ResponseWriter cannot be hijacked.
func (b *lsProxyBridge) connectHandler(dialer proxy.Dialer, logger *slog.Logger) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodConnect {
http.Error(w, "CONNECT only", http.StatusMethodNotAllowed)
return
}
// The target comes from r.Host, falling back to r.URL.Host when that is blank.
targetAddr := strings.TrimSpace(r.Host)
if targetAddr == "" {
targetAddr = strings.TrimSpace(r.URL.Host)
}
if targetAddr == "" {
http.Error(w, "missing target host", http.StatusBadRequest)
return
}
// When targetAddr does not parse as host:port, treat it as a bare host
// and append port 443.
if _, _, err := net.SplitHostPort(targetAddr); err != nil {
targetAddr = net.JoinHostPort(targetAddr, "443")
}
startedAt := time.Now()
logger.Info("LS proxy bridge CONNECT", "upstream", b.upstream, "target", targetAddr)
// The dial is bounded by lsProxyBridgeDialTimeout and by the request context.
dialCtx, cancel := context.WithTimeout(r.Context(), lsProxyBridgeDialTimeout)
defer cancel()
targetConn, err := dialViaProxy(dialCtx, dialer, targetAddr)
if err != nil {
logger.Warn("LS proxy bridge dial failed",
"upstream", b.upstream,
"target", targetAddr,
"elapsed", time.Since(startedAt).Truncate(time.Millisecond),
"err", err)
http.Error(w, "proxy dial failed", http.StatusBadGateway)
return
}
logger.Info("LS proxy bridge CONNECT established",
"upstream", b.upstream,
"target", targetAddr,
"elapsed", time.Since(startedAt).Truncate(time.Millisecond))
// From here on every failure path must close targetConn itself.
hijacker, ok := w.(http.Hijacker)
if !ok {
_ = targetConn.Close()
http.Error(w, "hijack unsupported", http.StatusInternalServerError)
return
}
clientConn, rw, err := hijacker.Hijack()
if err != nil {
_ = targetConn.Close()
return
}
// After the hijack the response is written by hand on the raw connection.
if _, err := clientConn.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n")); err != nil {
_ = targetConn.Close()
_ = clientConn.Close()
return
}
// Forward any bytes already buffered from the client before starting
// the tunnel, so they are not lost.
if rw != nil && rw.Reader.Buffered() > 0 {
if _, err := io.CopyN(targetConn, rw, int64(rw.Reader.Buffered())); err != nil {
_ = targetConn.Close()
_ = clientConn.Close()
return
}
}
// tunnelConns starts the relay goroutines and returns immediately; they
// close both connections when either direction ends.
tunnelConns(clientConn, targetConn)
}
}
// dialViaProxy opens a TCP connection to targetAddr through dialer, honouring
// ctx for cancellation and deadline.
//
// Dialers implementing proxy.ContextDialer are used directly. For plain
// dialers the blocking Dial runs in a goroutine and the call returns
// ctx.Err() as soon as ctx is done. In that case the dial may still complete
// later; a reaper goroutine closes such a late connection so it is not
// leaked.
func dialViaProxy(ctx context.Context, dialer proxy.Dialer, targetAddr string) (net.Conn, error) {
	if contextDialer, ok := dialer.(proxy.ContextDialer); ok {
		return contextDialer.DialContext(ctx, "tcp", targetAddr)
	}

	type dialResult struct {
		conn net.Conn
		err  error
	}
	// Buffered so the dialing goroutine never blocks on send.
	resultCh := make(chan dialResult, 1)
	go func() {
		conn, err := dialer.Dial("tcp", targetAddr)
		resultCh <- dialResult{conn: conn, err: err}
	}()

	select {
	case <-ctx.Done():
		// Nobody will use a connection that arrives after cancellation;
		// wait for the dial to finish in the background and close it.
		go func() {
			if late := <-resultCh; late.conn != nil {
				_ = late.conn.Close()
			}
		}()
		return nil, ctx.Err()
	case result := <-resultCh:
		return result.conn, result.err
	}
}
// probeConnectivity dials each entry of lsProxyBridgeProbeTargets through
// dialer, one after another, and logs whether the dial succeeded and how long
// it took. Each dial is bounded by lsProxyBridgeDialTimeout; successful
// connections are closed immediately.
func (b *lsProxyBridge) probeConnectivity(dialer proxy.Dialer, logger *slog.Logger) {
	for _, target := range lsProxyBridgeProbeTargets {
		begin := time.Now()

		ctx, cancel := context.WithTimeout(context.Background(), lsProxyBridgeDialTimeout)
		conn, dialErr := dialViaProxy(ctx, dialer, target)
		cancel()

		if dialErr == nil {
			_ = conn.Close()
			logger.Info("LS proxy bridge probe succeeded",
				"upstream", b.upstream,
				"target", target,
				"elapsed", time.Since(begin).Truncate(time.Millisecond))
			continue
		}
		logger.Warn("LS proxy bridge probe failed",
			"upstream", b.upstream,
			"target", target,
			"elapsed", time.Since(begin).Truncate(time.Millisecond),
			"err", dialErr)
	}
}
// tunnelConns relays bytes between clientConn and targetConn in both
// directions. It starts two copy goroutines and returns immediately; as soon
// as either direction finishes, both connections are closed exactly once.
func tunnelConns(clientConn net.Conn, targetConn net.Conn) {
	var shutdown sync.Once

	relay := func(dst, src net.Conn) {
		_, _ = io.Copy(dst, src)
		shutdown.Do(func() {
			_ = clientConn.Close()
			_ = targetConn.Close()
		})
	}

	go relay(targetConn, clientConn)
	go relay(clientConn, targetConn)
}
// readConnectResponse parses an HTTP response from br as the reply to a
// CONNECT request.
func readConnectResponse(br *bufio.Reader) (*http.Response, error) {
	connectReq := &http.Request{Method: http.MethodConnect}
	return http.ReadResponse(br, connectReq)
}

View File

@ -0,0 +1,193 @@
package lspool
import (
"bufio"
"encoding/binary"
"fmt"
"io"
"net"
"strings"
"testing"
"github.com/stretchr/testify/require"
)
// TestPrepareLSProxyURLPassesThroughHTTPProxy checks that an HTTP proxy URL
// is returned as-is, with no bridge in front of it.
func TestPrepareLSProxyURLPassesThroughHTTPProxy(t *testing.T) {
	t.Cleanup(closeAllLSProxyBridgesForTest)

	const httpProxy = "http://proxy.example.com:8080"
	effective, err := prepareLSProxyURL(httpProxy)

	require.NoError(t, err)
	require.Equal(t, httpProxy, effective)
}
// TestPrepareLSProxyURLBridgesSOCKS5ForLS runs the SOCKS5 bridge end to end:
// a SOCKS5 URL must yield a loopback HTTP bridge URL, the same URL must be
// reused for the same upstream, and a CONNECT through the bridge must reach
// the echo server behind the SOCKS5 proxy.
func TestPrepareLSProxyURLBridgesSOCKS5ForLS(t *testing.T) {
	t.Cleanup(closeAllLSProxyBridgesForTest)

	echoAddr, stopEcho := startBridgeEchoServer(t)
	defer stopEcho()
	socksProxy, stopSOCKS := startBridgeSOCKS5Server(t)
	defer stopSOCKS()

	firstURL, err := prepareLSProxyURL(socksProxy)
	require.NoError(t, err)
	require.True(t, strings.HasPrefix(firstURL, "http://127.0.0.1:"))

	// Same SOCKS upstream should reuse the same local bridge.
	secondURL, err := prepareLSProxyURL(socksProxy)
	require.NoError(t, err)
	require.Equal(t, firstURL, secondURL)

	// Open a tunnel to the echo server through the bridge.
	bridgeConn, err := net.Dial("tcp", strings.TrimPrefix(firstURL, "http://"))
	require.NoError(t, err)
	defer bridgeConn.Close()

	_, err = fmt.Fprintf(bridgeConn, "CONNECT %s HTTP/1.1\r\nHost: %s\r\n\r\n", echoAddr, echoAddr)
	require.NoError(t, err)

	br := bufio.NewReader(bridgeConn)
	connectResp, err := readConnectResponse(br)
	require.NoError(t, err)
	require.Equal(t, 200, connectResp.StatusCode)

	// Exchange one ping/pong over the established tunnel.
	_, err = bridgeConn.Write([]byte("ping"))
	require.NoError(t, err)

	answer := make([]byte, 4)
	_, err = io.ReadFull(br, answer)
	require.NoError(t, err)
	require.Equal(t, "pong", string(answer))
}
// startBridgeEchoServer starts a loopback TCP server that answers "pong" to a
// four-byte "ping" on each connection and then closes it. It returns the
// server address and a stop function that closes the listener and waits for
// the accept loop to exit.
func startBridgeEchoServer(t *testing.T) (string, func()) {
	t.Helper()

	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)

	serve := func(c net.Conn) {
		defer c.Close()
		msg := make([]byte, 4)
		if _, readErr := io.ReadFull(c, msg); readErr != nil || string(msg) != "ping" {
			return
		}
		_, _ = c.Write([]byte("pong"))
	}

	stopped := make(chan struct{})
	go func() {
		defer close(stopped)
		for {
			c, acceptErr := listener.Accept()
			if acceptErr != nil {
				return
			}
			go serve(c)
		}
	}()

	stop := func() {
		_ = listener.Close()
		<-stopped
	}
	return listener.Addr().String(), stop
}
// startBridgeSOCKS5Server starts a minimal loopback SOCKS5 server backed by
// handleBridgeSOCKS5Conn. It returns the server's socks5:// URL and a stop
// function that closes the listener and waits for the accept loop to exit.
func startBridgeSOCKS5Server(t *testing.T) (string, func()) {
	t.Helper()

	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)

	stopped := make(chan struct{})
	go func() {
		defer close(stopped)
		for {
			c, acceptErr := listener.Accept()
			if acceptErr != nil {
				return
			}
			go handleBridgeSOCKS5Conn(c)
		}
	}()

	stop := func() {
		_ = listener.Close()
		<-stopped
	}
	return "socks5://" + listener.Addr().String(), stop
}
// handleBridgeSOCKS5Conn serves one connection of the test SOCKS5 server.
//
// It reads the client greeting, always selects "no authentication", accepts
// only version-5 CONNECT requests, dials the requested target directly and
// then tunnels traffic between the two connections. On any protocol or read
// error the client connection is closed; on a dial failure a failure reply
// is sent first.
func handleBridgeSOCKS5Conn(conn net.Conn) {
	// Until the connection is handed to tunnelConns, every exit closes it.
	handedOff := false
	defer func() {
		if !handedOff {
			_ = conn.Close()
		}
	}()

	readN := func(n int) ([]byte, bool) {
		buf := make([]byte, n)
		_, err := io.ReadFull(conn, buf)
		return buf, err == nil
	}

	// Greeting: version, method count, then that many method bytes.
	greeting, ok := readN(2)
	if !ok {
		return
	}
	if _, ok := readN(int(greeting[1])); !ok {
		return
	}
	_, _ = conn.Write([]byte{0x05, 0x00})

	// Request: version, command, reserved, address type.
	request, ok := readN(4)
	if !ok {
		return
	}
	if request[0] != 0x05 || request[1] != 0x01 {
		return
	}
	host, ok := readSOCKS5Addr(conn, request[3])
	if !ok {
		return
	}
	rawPort, ok := readN(2)
	if !ok {
		return
	}

	target := fmt.Sprintf("%s:%d", host, binary.BigEndian.Uint16(rawPort))
	upstream, err := net.Dial("tcp", target)
	if err != nil {
		_, _ = conn.Write([]byte{0x05, 0x01, 0x00, 0x01, 0, 0, 0, 0, 0, 0})
		return
	}
	_, _ = conn.Write([]byte{0x05, 0x00, 0x00, 0x01, 0, 0, 0, 0, 0, 0})

	handedOff = true
	tunnelConns(conn, upstream)
}
// readSOCKS5Addr reads the destination address of a SOCKS5 request from conn
// according to the address type atyp: 0x01 is a 4-byte IPv4 address, 0x03 a
// length-prefixed domain name and 0x04 a 16-byte IPv6 address. It returns the
// address in text form and true, or "" and false on a read error or an
// unknown address type (in which case nothing is read).
func readSOCKS5Addr(conn net.Conn, atyp byte) (string, bool) {
	readN := func(n int) ([]byte, bool) {
		buf := make([]byte, n)
		_, err := io.ReadFull(conn, buf)
		return buf, err == nil
	}

	switch atyp {
	case 0x01, 0x04:
		size := net.IPv4len
		if atyp == 0x04 {
			size = net.IPv6len
		}
		raw, ok := readN(size)
		if !ok {
			return "", false
		}
		return net.IP(raw).String(), true
	case 0x03:
		prefix, ok := readN(1)
		if !ok {
			return "", false
		}
		name, ok := readN(int(prefix[0]))
		if !ok {
			return "", false
		}
		return string(name), true
	}
	return "", false
}

View File

@ -0,0 +1,138 @@
package lspool
import (
"fmt"
"net/url"
"os"
"os/exec"
"strings"
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
)
// lsLaunchPlan describes how to start the LS binary for a given proxy
// setting. It is produced by prepareLSLaunchPlan.
type lsLaunchPlan struct {
// cmd is the command to run: the LS binary itself, or proxychains4
// wrapping it when proxyMode is "proxychains4".
cmd *exec.Cmd
// effectiveProxyURL is the HTTP proxy URL chosen for the LS. It is set in
// the "env-http-proxy" and "http-connect-bridge" modes and empty otherwise.
effectiveProxyURL string
// proxyMode is one of "direct", "env-http-proxy", "proxychains4" or
// "http-connect-bridge".
proxyMode string
// cleanup removes the temporary proxychains config file. It is only set in
// "proxychains4" mode and is nil in every other mode.
cleanup func()
}
// prepareLSLaunchPlan decides how the LS binary is launched for the given
// proxy setting.
//
//   - no proxy: run the binary directly ("direct").
//   - http/https proxy: run directly and report the normalized proxy URL
//     ("env-http-proxy").
//   - socks5/socks5h proxy: when proxychains4 is on PATH, wrap the binary with
//     it using a temporary config file ("proxychains4", with a cleanup hook
//     that removes the file); otherwise front the proxy with a local HTTP
//     CONNECT bridge ("http-connect-bridge").
//
// An unparsable proxy URL or an unsupported scheme is returned as an error.
func prepareLSLaunchPlan(binPath string, args []string, rawProxyURL string) (*lsLaunchPlan, error) {
	normalized, parsed, err := proxyurl.Parse(rawProxyURL)
	if err != nil {
		return nil, err
	}

	plan := &lsLaunchPlan{
		cmd:       exec.Command(binPath, args...),
		proxyMode: "direct",
	}
	if parsed == nil {
		return plan, nil
	}

	switch strings.ToLower(parsed.Scheme) {
	case "http", "https":
		plan.effectiveProxyURL = normalized
		plan.proxyMode = "env-http-proxy"
		return plan, nil

	case "socks5", "socks5h":
		wrapperPath, lookErr := exec.LookPath("proxychains4")
		if lookErr != nil {
			// No proxychains4 available: fall back to the local bridge.
			bridgeURL, err := prepareLSProxyURL(normalized)
			if err != nil {
				return nil, err
			}
			plan.effectiveProxyURL = bridgeURL
			plan.proxyMode = "http-connect-bridge"
			return plan, nil
		}

		cfgPath, err := writeProxychainsConfig(parsed)
		if err != nil {
			return nil, err
		}
		wrappedArgs := make([]string, 0, len(args)+3)
		wrappedArgs = append(wrappedArgs, "-f", cfgPath, binPath)
		wrappedArgs = append(wrappedArgs, args...)

		plan.cmd = exec.Command(wrapperPath, wrappedArgs...)
		plan.proxyMode = "proxychains4"
		plan.cleanup = func() {
			_ = os.Remove(cfgPath)
		}
		return plan, nil

	default:
		return nil, fmt.Errorf("unsupported LS proxy scheme: %s", parsed.Scheme)
	}
}
// writeProxychainsConfig renders a proxychains config for proxyURL into a new
// temporary file and returns the file's path. The file is removed again when
// writing or closing it fails; on success the caller owns it.
func writeProxychainsConfig(proxyURL *url.URL) (string, error) {
	content, err := buildProxychainsConfig(proxyURL)
	if err != nil {
		return "", err
	}

	tmp, err := os.CreateTemp("", "sub2api-proxychains-*.conf")
	if err != nil {
		return "", fmt.Errorf("create proxychains config: %w", err)
	}
	cfgPath := tmp.Name()

	if _, writeErr := tmp.WriteString(content); writeErr != nil {
		_ = tmp.Close()
		_ = os.Remove(cfgPath)
		return "", fmt.Errorf("write proxychains config: %w", writeErr)
	}
	if closeErr := tmp.Close(); closeErr != nil {
		_ = os.Remove(cfgPath)
		return "", fmt.Errorf("close proxychains config: %w", closeErr)
	}
	return cfgPath, nil
}
// buildProxychainsConfig renders the text of a proxychains config that routes
// traffic through the given socks5/socks5h proxy.
//
// The config uses a strict chain with proxied DNS, fixed connect/read
// timeouts and loopback exclusions, followed by a single [ProxyList] entry
// "socks5 <host> <port> [<user> [<password>]]". The port defaults to 1080.
//
// An error is returned when proxyURL is nil, has a scheme other than
// socks5/socks5h, has no host, or carries credentials containing whitespace
// (the entry is whitespace-separated, so such credentials cannot be written).
func buildProxychainsConfig(proxyURL *url.URL) (string, error) {
	if proxyURL == nil {
		return "", fmt.Errorf("proxy url is nil")
	}
	switch strings.ToLower(proxyURL.Scheme) {
	case "socks5", "socks5h":
	default:
		return "", fmt.Errorf("proxychains only supports socks5/socks5h, got %s", proxyURL.Scheme)
	}

	host := strings.TrimSpace(proxyURL.Hostname())
	if host == "" {
		return "", fmt.Errorf("proxy host is empty")
	}
	port := strings.TrimSpace(proxyURL.Port())
	if port == "" {
		port = "1080"
	}

	user := proxyURL.User.Username()
	pass, _ := proxyURL.User.Password()
	const blanks = " \t\r\n"
	if strings.ContainsAny(user, blanks) || strings.ContainsAny(pass, blanks) {
		return "", fmt.Errorf("proxychains credentials cannot contain whitespace")
	}

	// The password is only written when a username is present.
	entry := []string{"socks5", host, port}
	if user != "" {
		entry = append(entry, user)
		if pass != "" {
			entry = append(entry, pass)
		}
	}

	lines := []string{
		"strict_chain",
		"proxy_dns",
		"remote_dns_subnet 224",
		"tcp_connect_time_out 8000",
		"tcp_read_time_out 15000",
		"localnet 127.0.0.0/255.0.0.0",
		"localnet ::1/128",
		"[ProxyList]",
		strings.Join(entry, " "),
	}
	return strings.Join(lines, "\n") + "\n", nil
}

View File

@ -0,0 +1,31 @@
package lspool
import (
"net/url"
"strings"
"testing"
"github.com/stretchr/testify/require"
)
// TestBuildProxychainsConfigIncludesAuthAndLocalBypass checks that the
// rendered config enables proxied DNS, excludes loopback ranges and lists the
// SOCKS5 proxy together with its credentials.
func TestBuildProxychainsConfigIncludesAuthAndLocalBypass(t *testing.T) {
	upstream, err := url.Parse("socks5h://gostuser:fastapipwd@216.167.85.31:1080")
	require.NoError(t, err)

	rendered, err := buildProxychainsConfig(upstream)
	require.NoError(t, err)

	for _, line := range []string{
		"proxy_dns\n",
		"localnet 127.0.0.0/255.0.0.0\n",
		"localnet ::1/128\n",
		"[ProxyList]\n",
		"socks5 216.167.85.31 1080 gostuser fastapipwd\n",
	} {
		require.Contains(t, rendered, line)
	}
}
// TestBuildProxychainsConfigRejectsWhitespaceCredentials checks that a
// password containing a space is rejected with an error mentioning
// "whitespace".
func TestBuildProxychainsConfigRejectsWhitespaceCredentials(t *testing.T) {
	upstream, parseErr := url.Parse("socks5h://user:bad%20pass@127.0.0.1:1080")
	require.NoError(t, parseErr)

	_, buildErr := buildProxychainsConfig(upstream)
	require.Error(t, buildErr)

	mentionsWhitespace := strings.Contains(buildErr.Error(), "whitespace")
	require.True(t, mentionsWhitespace)
}

View File

@ -0,0 +1,99 @@
package lspool
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
)
// callWorkerUnary sends one unary RPC to the worker's /rpc/unary endpoint and
// returns the complete response body.
//
// service, method and mode are passed as query parameters; mode "json" sends
// the body as application/json, anything else as application/octet-stream.
// The worker token goes in the X-Worker-Token header. On a non-200 status
// the response body is returned together with an error that quotes a
// truncated copy of it.
func (i *Instance) callWorkerUnary(ctx context.Context, service, method, mode string, body []byte) ([]byte, error) {
	endpoint, err := i.workerEndpoint("/rpc/unary", service, method, mode)
	if err != nil {
		return nil, err
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
	if err != nil {
		return nil, err
	}
	contentType := "application/octet-stream"
	if mode == "json" {
		contentType = "application/json"
	}
	req.Header.Set("X-Worker-Token", i.workerToken)
	req.Header.Set("Content-Type", contentType)

	resp, err := i.client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("worker rpc %s/%s: %w", service, method, err)
	}
	defer resp.Body.Close()

	payload, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("worker rpc read response: %w", err)
	}
	if resp.StatusCode == http.StatusOK {
		return payload, nil
	}
	return payload, fmt.Errorf("worker rpc %s/%s HTTP %d: %s", service, method, resp.StatusCode, truncate(string(payload), 200))
}
// callWorkerStream starts a streaming RPC against the worker's /rpc/stream
// endpoint and returns the live HTTP response.
//
// The request is built like a unary call (query parameters, X-Worker-Token,
// Content-Type chosen by mode). On success the caller owns resp.Body and must
// close it. On a non-200 status the body is read and closed here, and an
// error quoting a truncated copy of it is returned.
func (i *Instance) callWorkerStream(ctx context.Context, service, method, mode string, body []byte) (*http.Response, error) {
	endpoint, err := i.workerEndpoint("/rpc/stream", service, method, mode)
	if err != nil {
		return nil, err
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
	if err != nil {
		return nil, err
	}
	contentType := "application/octet-stream"
	if mode == "json" {
		contentType = "application/json"
	}
	req.Header.Set("X-Worker-Token", i.workerToken)
	req.Header.Set("Content-Type", contentType)

	resp, err := i.client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("worker stream rpc %s/%s: %w", service, method, err)
	}
	if resp.StatusCode == http.StatusOK {
		return resp, nil
	}

	defer resp.Body.Close()
	errBody, _ := io.ReadAll(resp.Body)
	return nil, fmt.Errorf("worker stream rpc %s/%s HTTP %d: %s", service, method, resp.StatusCode, truncate(string(errBody), 200))
}
// workerEndpoint builds the plain-HTTP URL for a worker RPC route on this
// instance's address. service, method and mode are always sent as query
// parameters; routing_key is added only when the instance has one.
// The returned error is currently always nil.
func (i *Instance) workerEndpoint(path, service, method, mode string) (string, error) {
	query := url.Values{
		"service": {service},
		"method":  {method},
		"mode":    {mode},
	}
	if i.routingKey != "" {
		query.Set("routing_key", i.routingKey)
	}
	endpoint := url.URL{
		Scheme:   "http",
		Host:     i.Address,
		Path:     path,
		RawQuery: query.Encode(),
	}
	return endpoint.String(), nil
}
// marshalWorkerJSONBody encodes input as JSON for a worker request body.
// An untyped nil input is encoded as the empty object "{}".
func marshalWorkerJSONBody(input any) ([]byte, error) {
	if input != nil {
		encoded, err := json.Marshal(input)
		if err != nil {
			return nil, err
		}
		return encoded, nil
	}
	return []byte("{}"), nil
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,649 @@
package lspool
import (
"bytes"
"context"
"crypto/sha256"
"encoding/json"
"fmt"
"io"
"log/slog"
"net"
"net/http"
"net/url"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/filters"
"github.com/docker/docker/api/types/network"
"github.com/docker/docker/client"
ocispec "github.com/opencontainers/image-spec/specs-go/v1"
)
// Container labels applied to every worker created by the manager, plus the
// TCP port of the worker's HTTP control server.
const (
// lsWorkerManagedByLabel / lsWorkerManagedByValue tag a container as owned by
// this manager; reconcileManagedContainers filters on this pair.
lsWorkerManagedByLabel = "managed-by"
lsWorkerManagedByValue = "sub2api"
// Labels recording which account, proxy hash and image a worker was created for.
lsWorkerAccountLabel = "account_id"
lsWorkerProxyHashLabel = "proxy_hash"
lsWorkerImageTagLabel = "image_tag"
// lsWorkerControlPort is the port the worker is told to listen on and the
// port the manager dials.
lsWorkerControlPort = 18081
)
// workerManagerConfig holds the resolved settings of a workerManager.
type workerManagerConfig struct {
// Image is the container image used for new workers.
Image string
// Network is the Docker network workers are attached to and reached over.
Network string
// DockerSocket is the Docker daemon host passed to the Docker client.
DockerSocket string
// IdleTTL is how long a worker may go unused before the cleanup loop removes it.
IdleTTL time.Duration
// MaxActive caps the number of tracked workers; creation fails beyond it.
MaxActive int
// StartupTimeout bounds each of the health and readiness waits.
StartupTimeout time.Duration
// RequestTimeout is the HTTP client timeout and the state-sync deadline.
RequestTimeout time.Duration
}
// dockerClient is the subset of the Docker API client that the worker manager
// uses. It exists so tests can substitute a fake implementation.
type dockerClient interface {
ContainerList(ctx context.Context, options container.ListOptions) ([]container.Summary, error)
ContainerCreate(ctx context.Context, config *container.Config, hostConfig *container.HostConfig, networkingConfig *network.NetworkingConfig, platform *ocispec.Platform, containerName string) (container.CreateResponse, error)
ContainerStart(ctx context.Context, containerID string, options container.StartOptions) error
ContainerInspect(ctx context.Context, containerID string) (container.InspectResponse, error)
ContainerStop(ctx context.Context, containerID string, options container.StopOptions) error
ContainerRemove(ctx context.Context, containerID string, options container.RemoveOptions) error
Close() error
}
// workerManager creates, tracks and removes per-account worker containers and
// hands out Instances that talk to them over HTTP.
type workerManager struct {
cfg workerManagerConfig
docker dockerClient
// http is the shared client used for worker control calls and handed to
// every Instance returned by GetOrCreate.
http *http.Client
// mu guards workers and state.
mu sync.Mutex
// workers maps a worker key (see buildWorkerKey) to its handle.
workers map[string]*workerHandle
// state maps an account ID to the credentials/credits pushed to its workers.
state map[string]*workerAccountState
// ctx is cancelled by Close; it stops the cleanup loop and is used for
// container create/start/inspect calls.
ctx context.Context
cancel context.CancelFunc
logger *slog.Logger
}
// workerHandle describes one running worker container.
type workerHandle struct {
// Key is the manager's map key for this worker (account ID plus proxy hash).
Key string
AccountID string
// ProxyURL is the normalized proxy URL the worker was created with.
ProxyURL string
ProxyHash string
// ContainerID is the Docker ID; Container is the container name.
ContainerID string
Container string
// Address is the host:port of the worker's control server.
Address string
// AuthToken is sent as the X-Worker-Token header on every call to the worker.
AuthToken string
// LastUsed is refreshed by GetOrCreate and compared against IdleTTL.
LastUsed time.Time
// LastStateSHA is the hex SHA-256 of the last account state successfully
// pushed; it lets syncWorkerState skip identical payloads.
LastStateSHA string
}
// workerAccountState is the JSON payload exchanged between the manager and a
// worker via /account/state: the account's OAuth token and model-credit
// settings.
type workerAccountState struct {
// HasToken is set once SetAccountToken has been called for the account.
HasToken bool `json:"has_token"`
AccessToken string `json:"access_token,omitempty"`
RefreshToken string `json:"refresh_token,omitempty"`
// ExpiresAt is nil when no expiry was supplied.
ExpiresAt *time.Time `json:"expires_at,omitempty"`
// HasModelCredits is set once SetAccountModelCredits has been called.
HasModelCredits bool `json:"has_model_credits"`
UseAICredits bool `json:"use_ai_credits"`
AvailableCredits *int32 `json:"available_credits,omitempty"`
MinimumCreditAmount *int32 `json:"minimum_credit_amount,omitempty"`
}
// NewWorkerManagerFromConfig builds a Docker-backed worker manager from the
// gateway configuration.
//
// Blank or non-positive settings fall back to defaults: image
// "weishaw/sub2api-lsworker:latest", network "sub2api-network", socket
// "unix:///var/run/docker.sock", idle TTL 15m, max 50 workers, startup
// timeout 45s and request timeout 60s.
//
// It returns an error when cfg is nil, when the Docker client cannot be
// created, or when the manager fails to initialize. On error the returned
// Backend is a true nil interface.
func NewWorkerManagerFromConfig(cfg *config.Config) (Backend, error) {
	if cfg == nil {
		return nil, fmt.Errorf("config is nil")
	}
	managerCfg := workerManagerConfig{
		Image:          strings.TrimSpace(cfg.Gateway.AntigravityLSWorker.Image),
		Network:        strings.TrimSpace(cfg.Gateway.AntigravityLSWorker.Network),
		DockerSocket:   strings.TrimSpace(cfg.Gateway.AntigravityLSWorker.DockerSocket),
		IdleTTL:        cfg.Gateway.AntigravityLSWorker.IdleTTL,
		MaxActive:      cfg.Gateway.AntigravityLSWorker.MaxActive,
		StartupTimeout: cfg.Gateway.AntigravityLSWorker.StartupTimeout,
		RequestTimeout: cfg.Gateway.AntigravityLSWorker.RequestTimeout,
	}
	if managerCfg.Image == "" {
		managerCfg.Image = "weishaw/sub2api-lsworker:latest"
	}
	if managerCfg.Network == "" {
		managerCfg.Network = "sub2api-network"
	}
	if managerCfg.DockerSocket == "" {
		managerCfg.DockerSocket = "unix:///var/run/docker.sock"
	}
	if managerCfg.IdleTTL <= 0 {
		managerCfg.IdleTTL = 15 * time.Minute
	}
	if managerCfg.MaxActive < 1 {
		managerCfg.MaxActive = 50
	}
	if managerCfg.StartupTimeout <= 0 {
		managerCfg.StartupTimeout = 45 * time.Second
	}
	if managerCfg.RequestTimeout <= 0 {
		managerCfg.RequestTimeout = 60 * time.Second
	}

	// DockerSocket is never empty here (it was defaulted above), so the host
	// is always set explicitly; the former client.FromEnv fallback could
	// never be reached.
	cli, err := client.NewClientWithOpts(
		client.WithAPIVersionNegotiation(),
		client.WithHost(managerCfg.DockerSocket),
	)
	if err != nil {
		return nil, fmt.Errorf("create docker client: %w", err)
	}

	// Return an explicit nil on failure: forwarding newWorkerManager's
	// (*workerManager)(nil) directly would wrap it in a non-nil Backend.
	mgr, err := newWorkerManager(managerCfg, cli)
	if err != nil {
		return nil, err
	}
	return mgr, nil
}
// newWorkerManager wires a manager around the given Docker client, removes
// any containers left over from a previous run, and starts the idle-cleanup
// goroutine. If the initial reconcile fails the Docker client is closed and
// the error is returned.
func newWorkerManager(cfg workerManagerConfig, docker dockerClient) (*workerManager, error) {
ctx, cancel := context.WithCancel(context.Background())
mgr := &workerManager{
cfg: cfg,
docker: docker,
// NOTE(review): http.Client.Timeout also covers reading the response body,
// and this client is handed to Instances that issue streaming RPCs — confirm
// that cutting streams off after RequestTimeout is intended.
http: &http.Client{
Timeout: cfg.RequestTimeout,
Transport: &http.Transport{
// Never route worker traffic through an environment-configured proxy.
Proxy: nil,
DialContext: (&net.Dialer{
Timeout: 5 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
MaxIdleConnsPerHost: 8,
},
},
workers: make(map[string]*workerHandle),
state: make(map[string]*workerAccountState),
ctx: ctx,
cancel: cancel,
logger: slog.Default().With("component", "lspool-worker-manager"),
}
// Start from a clean slate: every previously managed container is removed.
if err := mgr.reconcileManagedContainers(ctx); err != nil {
cancel()
_ = docker.Close()
return nil, err
}
go mgr.cleanupLoop()
return mgr, nil
}
// Close stops the cleanup loop, removes every tracked worker container and
// closes the Docker client. Container removal failures are logged and do not
// interrupt the shutdown of the remaining workers.
func (m *workerManager) Close() {
	m.cancel()

	// Detach the handles under the lock, then talk to Docker without it.
	m.mu.Lock()
	workers := make([]*workerHandle, 0, len(m.workers))
	for _, handle := range m.workers {
		workers = append(workers, handle)
	}
	m.workers = make(map[string]*workerHandle)
	m.mu.Unlock()

	for _, handle := range workers {
		// m.ctx is already cancelled, so removal uses a fresh context.
		if err := m.removeWorkerContainer(context.Background(), handle); err != nil {
			m.logger.Warn("failed to remove ls worker on close", "container", handle.Container, "error", err)
		}
	}
	if m.docker != nil {
		_ = m.docker.Close()
	}
}
// Stats reports how many accounts have stored state and how many workers are
// tracked. "total" and "active" both carry the tracked-worker count.
func (m *workerManager) Stats() map[string]any {
	m.mu.Lock()
	accountCount, workerCount := len(m.state), len(m.workers)
	m.mu.Unlock()

	return map[string]any{
		"accounts": accountCount,
		"total":    workerCount,
		"active":   workerCount,
	}
}
// SetAccountToken records the OAuth credentials for accountID. A zero
// expiresAt clears the stored expiry; otherwise the expiry is stored in UTC.
func (m *workerManager) SetAccountToken(accountID, accessToken, refreshToken string, expiresAt time.Time) {
	var expiry *time.Time
	if !expiresAt.IsZero() {
		utc := expiresAt.UTC()
		expiry = &utc
	}

	m.mu.Lock()
	defer m.mu.Unlock()

	entry := m.ensureStateLocked(accountID)
	entry.HasToken = true
	entry.AccessToken = accessToken
	entry.RefreshToken = refreshToken
	entry.ExpiresAt = expiry
}
// SetAccountModelCredits records the model-credit settings for accountID.
// The pointer arguments are copied, so later changes by the caller do not
// affect the stored state.
func (m *workerManager) SetAccountModelCredits(accountID string, useAICredits bool, availableCredits, minimumCreditAmountForUsage *int32) {
	available := cloneInt32Ptr(availableCredits)
	minimum := cloneInt32Ptr(minimumCreditAmountForUsage)

	m.mu.Lock()
	defer m.mu.Unlock()

	entry := m.ensureStateLocked(accountID)
	entry.HasModelCredits = true
	entry.UseAICredits = useAICredits
	entry.AvailableCredits = available
	entry.MinimumCreditAmount = minimum
}
// GetOrCreate returns an Instance bound to the worker container for
// (accountID, proxy), creating the container when none is tracked yet.
//
// Only the first element of proxyURL is used. The call fails when the proxy
// does not resolve to a socks5/socks5h URL, when no access token has been set
// for the account, when MaxActive workers are already tracked, or when the
// worker does not become healthy, accept the account state, or become ready
// for routingKey.
func (m *workerManager) GetOrCreate(accountID, routingKey string, proxyURL ...string) (*Instance, error) {
rawProxy := ""
if len(proxyURL) > 0 {
rawProxy = proxyURL[0]
}
normalizedProxy, parsedProxy, err := resolveWorkerProxy(rawProxy)
if err != nil {
return nil, err
}
// A nil parsed proxy means no proxy was resolved; workers refuse to run direct.
if parsedProxy == nil {
return nil, fmt.Errorf("ls worker requires a socks5/socks5h proxy for account %s", accountID)
}
replica := replicaSlotIndex(routingKey, parseLSReplicaCount())
proxyHash := proxyHash(normalizedProxy)
workerKey := buildWorkerKey(accountID, proxyHash)
m.mu.Lock()
// Work on a private copy of the account state so it can be used after unlocking.
state := cloneWorkerAccountState(m.state[accountID])
if state == nil || !state.HasToken || strings.TrimSpace(state.AccessToken) == "" {
m.mu.Unlock()
return nil, fmt.Errorf("ls worker missing access token for account %s", accountID)
}
handle := m.workers[workerKey]
if handle == nil {
if len(m.workers) >= m.cfg.MaxActive {
m.mu.Unlock()
return nil, fmt.Errorf("ls worker limit reached (%d active)", m.cfg.MaxActive)
}
// The container is created while m.mu is held.
handle, err = m.createWorkerLocked(accountID, normalizedProxy, proxyHash, parsedProxy)
if err != nil {
m.mu.Unlock()
return nil, err
}
m.workers[workerKey] = handle
}
handle.LastUsed = time.Now()
m.mu.Unlock()
// The remaining steps run without the lock. On failure the handle stays
// registered in m.workers.
if err := m.waitForWorkerHealthy(handle); err != nil {
return nil, err
}
if err := m.syncWorkerState(handle, state); err != nil {
return nil, err
}
if err := m.waitForWorkerReady(handle, routingKey); err != nil {
return nil, err
}
inst := &Instance{
AccountID: accountID,
Replica: replica,
Address: handle.Address,
client: m.http,
healthy: true,
lastUsed: time.Now(),
modelMapReady: 1,
remote: true,
workerToken: handle.AuthToken,
routingKey: routingKey,
}
return inst, nil
}
// cleanupLoop sweeps idle workers once a minute until the manager's context
// is cancelled.
func (m *workerManager) cleanupLoop() {
	sweep := time.NewTicker(time.Minute)
	defer sweep.Stop()

	stopped := m.ctx.Done()
	for {
		select {
		case <-sweep.C:
			m.collectIdleWorkers()
		case <-stopped:
			return
		}
	}
}
// collectIdleWorkers drops every worker that has been unused for longer than
// IdleTTL (and any nil map entries) from the registry, then removes the
// corresponding containers. Removal failures are logged; the worker is no
// longer tracked either way.
func (m *workerManager) collectIdleWorkers() {
	now := time.Now()
	var expired []*workerHandle

	m.mu.Lock()
	for key, handle := range m.workers {
		if handle == nil {
			delete(m.workers, key)
			continue
		}
		if now.Sub(handle.LastUsed) <= m.cfg.IdleTTL {
			continue
		}
		expired = append(expired, handle)
		delete(m.workers, key)
	}
	m.mu.Unlock()

	// Docker calls happen outside the lock so lookups are not blocked.
	for _, handle := range expired {
		if err := m.removeWorkerContainer(context.Background(), handle); err != nil {
			m.logger.Warn("failed to remove idle ls worker", "container", handle.Container, "error", err)
		}
	}
}
// reconcileManagedContainers lists every container (running or not) carrying
// the manager's ownership label and removes each one. It stops at, and
// returns, the first listing or removal error.
func (m *workerManager) reconcileManagedContainers(ctx context.Context) error {
	ownedBy := filters.NewArgs()
	ownedBy.Add("label", fmt.Sprintf("%s=%s", lsWorkerManagedByLabel, lsWorkerManagedByValue))

	leftovers, err := m.docker.ContainerList(ctx, container.ListOptions{All: true, Filters: ownedBy})
	if err != nil {
		return fmt.Errorf("list managed ls workers: %w", err)
	}

	for _, leftover := range leftovers {
		stale := &workerHandle{
			ContainerID: leftover.ID,
			Container:   strings.TrimPrefix(firstContainerName(leftover.Names), "/"),
		}
		if removeErr := m.removeWorkerContainer(ctx, stale); removeErr != nil {
			return removeErr
		}
	}
	return nil
}
// createWorkerLocked creates and starts a worker container for the account
// and proxy, then resolves its control address. The only visible caller,
// GetOrCreate, invokes it while holding m.mu.
//
// The proxy endpoint and credentials, a freshly generated auth token and the
// control port are passed to the container as environment variables. If the
// container cannot be started, inspected or addressed it is force-removed
// before the error is returned.
func (m *workerManager) createWorkerLocked(accountID, proxyURL, proxyHash string, parsedProxy *url.URL) (*workerHandle, error) {
// NOTE(review): proxyHash[:8] assumes at least 8 characters; proxyHash()
// returns the 6-character "direct" for a blank URL — confirm that callers
// never reach here without a proxy.
containerName := fmt.Sprintf("sub2api-ls-%s-%s", accountID, proxyHash[:8])
authToken := generateUUID()
proxyHost := parsedProxy.Hostname()
proxyPort := parsedProxy.Port()
// Fall back to port 1080 when the proxy URL carries no port.
if proxyPort == "" {
proxyPort = "1080"
}
proxyUser := parsedProxy.User.Username()
proxyPass, _ := parsedProxy.User.Password()
labels := map[string]string{
lsWorkerManagedByLabel: lsWorkerManagedByValue,
lsWorkerAccountLabel: accountID,
lsWorkerProxyHashLabel: proxyHash,
lsWorkerImageTagLabel: m.cfg.Image,
}
env := []string{
"ANTIGRAVITY_APP_ROOT=/app/ls",
fmt.Sprintf("LSWORKER_ACCOUNT_ID=%s", accountID),
fmt.Sprintf("LSWORKER_AUTH_TOKEN=%s", authToken),
fmt.Sprintf("LSWORKER_LISTEN_ADDR=0.0.0.0:%d", lsWorkerControlPort),
fmt.Sprintf("LSWORKER_NETWORK_READY_FILE=%s", "/run/lsworker/network-ready"),
fmt.Sprintf("LSWORKER_PROXY_URL=%s", proxyURL),
fmt.Sprintf("LSWORKER_PROXY_HOST=%s", proxyHost),
fmt.Sprintf("LSWORKER_PROXY_PORT=%s", proxyPort),
fmt.Sprintf("LSWORKER_PROXY_USER=%s", proxyUser),
fmt.Sprintf("LSWORKER_PROXY_PASS=%s", proxyPass),
fmt.Sprintf("LSWORKER_CONTROL_PORT=%d", lsWorkerControlPort),
fmt.Sprintf("ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT=%d", parseLSReplicaCount()),
}
// Propagate the host process's timezone when one is configured.
if tz := strings.TrimSpace(os.Getenv("TZ")); tz != "" {
env = append(env, "TZ="+tz)
}
createResp, err := m.docker.ContainerCreate(m.ctx, &container.Config{
Image: m.cfg.Image,
Labels: labels,
Env: env,
}, &container.HostConfig{
// The container is granted the NET_ADMIN capability.
CapAdd: []string{"NET_ADMIN"},
}, &network.NetworkingConfig{
EndpointsConfig: map[string]*network.EndpointSettings{
m.cfg.Network: {},
},
}, nil, containerName)
if err != nil {
return nil, fmt.Errorf("create ls worker container: %w", err)
}
if err := m.docker.ContainerStart(m.ctx, createResp.ID, container.StartOptions{}); err != nil {
_ = m.docker.ContainerRemove(m.ctx, createResp.ID, container.RemoveOptions{Force: true})
return nil, fmt.Errorf("start ls worker container: %w", err)
}
// Inspect after start to learn the IP assigned on the worker network.
inspect, err := m.docker.ContainerInspect(m.ctx, createResp.ID)
if err != nil {
_ = m.docker.ContainerRemove(m.ctx, createResp.ID, container.RemoveOptions{Force: true})
return nil, fmt.Errorf("inspect ls worker container: %w", err)
}
address, err := workerAddressFromInspect(inspect, m.cfg.Network)
if err != nil {
_ = m.docker.ContainerRemove(m.ctx, createResp.ID, container.RemoveOptions{Force: true})
return nil, err
}
m.logger.Info("created ls worker",
"account", shortAccountID(accountID),
"container", containerName,
"address", address,
"proxy_hash", proxyHash[:8])
return &workerHandle{
Key: buildWorkerKey(accountID, proxyHash),
AccountID: accountID,
ProxyURL: proxyURL,
ProxyHash: proxyHash,
ContainerID: createResp.ID,
Container: containerName,
Address: address,
AuthToken: authToken,
LastUsed: time.Now(),
}, nil
}
// workerAddressFromInspect derives the worker's control address (IP plus
// lsWorkerControlPort) from a container inspect result. The endpoint on
// networkName is preferred; failing that, the first endpoint found with a
// non-blank IP is used. An error is returned when no endpoint has an IP.
func workerAddressFromInspect(inspect container.InspectResponse, networkName string) (string, error) {
	settings := inspect.NetworkSettings
	if settings == nil {
		return "", fmt.Errorf("ls worker inspect missing network settings")
	}
	port := strconv.Itoa(lsWorkerControlPort)

	if preferred, found := settings.Networks[networkName]; found && preferred != nil {
		if ip := strings.TrimSpace(preferred.IPAddress); ip != "" {
			return net.JoinHostPort(ip, port), nil
		}
	}
	for _, candidate := range settings.Networks {
		if candidate == nil {
			continue
		}
		if ip := strings.TrimSpace(candidate.IPAddress); ip != "" {
			return net.JoinHostPort(ip, port), nil
		}
	}
	return "", fmt.Errorf("ls worker missing IP address on network %s", networkName)
}
// firstContainerName returns the first entry of names, or "" when names is
// empty.
func firstContainerName(names []string) string {
	if len(names) > 0 {
		return names[0]
	}
	return ""
}
// waitForWorkerHealthy polls the worker's /healthz endpoint every 500ms until
// it answers 200 or StartupTimeout elapses, in which case an error wrapping
// the context error is returned.
func (m *workerManager) waitForWorkerHealthy(handle *workerHandle) error {
	deadlineCtx, stop := context.WithTimeout(context.Background(), m.cfg.StartupTimeout)
	defer stop()

	for {
		probe, err := http.NewRequestWithContext(deadlineCtx, http.MethodGet, workerURL(handle, "/healthz", nil), nil)
		if err != nil {
			return err
		}
		probe.Header.Set("X-Worker-Token", handle.AuthToken)

		// Transport errors are expected while the container boots; retry.
		if reply, doErr := m.http.Do(probe); doErr == nil {
			_ = reply.Body.Close()
			if reply.StatusCode == http.StatusOK {
				return nil
			}
		}

		select {
		case <-time.After(500 * time.Millisecond):
		case <-deadlineCtx.Done():
			return fmt.Errorf("worker %s failed health check: %w", handle.Container, deadlineCtx.Err())
		}
	}
}
// waitForWorkerReady polls the worker's /readyz endpoint every 500ms until it
// answers 200 or StartupTimeout elapses. A non-blank routingKey is forwarded
// as the routing_key query parameter. Non-200 replies that carry a body are
// logged as warnings.
func (m *workerManager) waitForWorkerReady(handle *workerHandle, routingKey string) error {
	deadlineCtx, stop := context.WithTimeout(context.Background(), m.cfg.StartupTimeout)
	defer stop()

	query := url.Values{}
	if strings.TrimSpace(routingKey) != "" {
		query.Set("routing_key", routingKey)
	}

	for {
		probe, err := http.NewRequestWithContext(deadlineCtx, http.MethodGet, workerURL(handle, "/readyz", query), nil)
		if err != nil {
			return err
		}
		probe.Header.Set("X-Worker-Token", handle.AuthToken)

		if reply, doErr := m.http.Do(probe); doErr == nil {
			detail, _ := io.ReadAll(reply.Body)
			_ = reply.Body.Close()
			switch {
			case reply.StatusCode == http.StatusOK:
				return nil
			case len(detail) > 0:
				m.logger.Warn("ls worker not ready yet", "container", handle.Container, "status", reply.StatusCode, "body", truncate(string(detail), 200))
			}
		}

		select {
		case <-time.After(500 * time.Millisecond):
		case <-deadlineCtx.Done():
			return fmt.Errorf("worker %s not ready for routing key %q: %w", handle.Container, routingKey, deadlineCtx.Err())
		}
	}
}
// syncWorkerState pushes the account state to the worker's /account/state
// endpoint. The push is skipped when the JSON payload's SHA-256 matches the
// last payload the worker accepted.
//
// It must be called without m.mu held: the handle's LastStateSHA is read and
// updated under m.mu because handles are shared between concurrent
// GetOrCreate calls.
func (m *workerManager) syncWorkerState(handle *workerHandle, state *workerAccountState) error {
	if state == nil {
		return fmt.Errorf("ls worker state is nil")
	}
	body, err := json.Marshal(state)
	if err != nil {
		return fmt.Errorf("marshal worker state: %w", err)
	}
	sum := fmt.Sprintf("%x", sha256.Sum256(body))

	m.mu.Lock()
	unchanged := handle.LastStateSHA == sum
	m.mu.Unlock()
	if unchanged {
		return nil
	}

	ctx, cancel := context.WithTimeout(context.Background(), m.cfg.RequestTimeout)
	defer cancel()
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, workerURL(handle, "/account/state", nil), bytes.NewReader(body))
	if err != nil {
		return err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("X-Worker-Token", handle.AuthToken)
	resp, err := m.http.Do(req)
	if err != nil {
		return fmt.Errorf("sync worker state: %w", err)
	}
	defer resp.Body.Close()
	// Best effort: the response body is only used for the error text.
	respBody, _ := io.ReadAll(resp.Body)
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("sync worker state HTTP %d: %s", resp.StatusCode, truncate(string(respBody), 200))
	}

	// Remember the digest only after the worker has accepted the payload.
	m.mu.Lock()
	handle.LastStateSHA = sum
	m.mu.Unlock()
	return nil
}
// workerURL builds a plain-HTTP URL for path on the worker's control address.
// The query string is set only when values is non-nil.
func workerURL(handle *workerHandle, path string, values url.Values) string {
	target := url.URL{Scheme: "http", Host: handle.Address, Path: path}
	if values != nil {
		target.RawQuery = values.Encode()
	}
	return target.String()
}
// removeWorkerContainer stops the handle's container with a 5-second stop
// timeout and then force-removes it. A nil handle or blank container ID is a
// no-op. Only the removal error is reported.
func (m *workerManager) removeWorkerContainer(ctx context.Context, handle *workerHandle) error {
	if handle == nil {
		return nil
	}
	if strings.TrimSpace(handle.ContainerID) == "" {
		return nil
	}

	graceSeconds := 5
	// The stop result is ignored on purpose: the forced removal below runs
	// regardless.
	_ = m.docker.ContainerStop(ctx, handle.ContainerID, container.StopOptions{Timeout: &graceSeconds})

	err := m.docker.ContainerRemove(ctx, handle.ContainerID, container.RemoveOptions{Force: true})
	if err == nil {
		return nil
	}
	return fmt.Errorf("remove ls worker container %s: %w", handle.ContainerID, err)
}
// ensureStateLocked returns the stored state for accountID, inserting an
// empty one first when there is none. Its callers in this file hold m.mu.
func (m *workerManager) ensureStateLocked(accountID string) *workerAccountState {
	if existing := m.state[accountID]; existing != nil {
		return existing
	}
	fresh := &workerAccountState{}
	m.state[accountID] = fresh
	return fresh
}
func resolveWorkerProxy(proxyURL string) (string, *url.URL, error) {
resolved := resolveLSProxy(proxyURL)
normalized, parsed, err := proxyurl.Parse(resolved)
if err != nil {
return "", nil, err
}
if parsed == nil {
return "", nil, nil
}
switch strings.ToLower(parsed.Scheme) {
case "socks5", "socks5h":
return normalized, parsed, nil
default:
return "", nil, fmt.Errorf("ls worker only supports socks5/socks5h proxies, got %s", parsed.Scheme)
}
}
// proxyHash returns the lowercase hex SHA-256 of the whitespace-trimmed proxy
// URL, or the literal "direct" when the URL is blank.
func proxyHash(proxyURL string) string {
	trimmed := strings.TrimSpace(proxyURL)
	if trimmed == "" {
		return "direct"
	}
	digest := sha256.Sum256([]byte(trimmed))
	return fmt.Sprintf("%x", digest[:])
}
// buildWorkerKey joins the account ID and proxy hash with a colon to form the
// key under which a worker is tracked.
func buildWorkerKey(accountID, proxyHash string) string {
	const separator = ":"
	key := accountID + separator
	return key + proxyHash
}
// cloneInt32Ptr returns a pointer to a fresh copy of *v, or nil when v is nil.
func cloneInt32Ptr(v *int32) *int32 {
	if v != nil {
		dup := new(int32)
		*dup = *v
		return dup
	}
	return nil
}
// cloneWorkerAccountState returns a deep copy of state: the pointer-typed
// credit and expiry fields are duplicated so the copy shares no memory with
// the original. A nil state yields nil.
func cloneWorkerAccountState(state *workerAccountState) *workerAccountState {
	if state == nil {
		return nil
	}

	dup := new(workerAccountState)
	*dup = *state
	dup.AvailableCredits = cloneInt32Ptr(state.AvailableCredits)
	dup.MinimumCreditAmount = cloneInt32Ptr(state.MinimumCreditAmount)
	if expiry := state.ExpiresAt; expiry != nil {
		when := *expiry
		dup.ExpiresAt = &when
	}
	return dup
}

View File

@ -0,0 +1,266 @@
package lspool
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/filters"
"github.com/docker/docker/api/types/network"
ocispec "github.com/opencontainers/image-spec/specs-go/v1"
"github.com/stretchr/testify/require"
)
// fakeDockerClient is an in-memory dockerClient for tests. It counts calls,
// records created configs and removed IDs, and returns the canned list and
// inspect responses. mu guards every field.
type fakeDockerClient struct {
mu sync.Mutex
// listResp is returned (as a copy) by ContainerList.
listResp []container.Summary
listCalls int
createCalls int
startCalls int
stopCalls int
removeCalls int
inspectCalls int
// removedIDs collects the IDs passed to ContainerRemove, in call order.
removedIDs []string
// createdConfigs collects the configs passed to ContainerCreate.
createdConfigs []*container.Config
// inspectResp is returned by ContainerInspect.
inspectResp container.InspectResponse
}
// ContainerList counts the call and returns a copy of listResp. The options,
// including any filters, are ignored.
func (f *fakeDockerClient) ContainerList(ctx context.Context, options container.ListOptions) ([]container.Summary, error) {
f.mu.Lock()
defer f.mu.Unlock()
f.listCalls++
return append([]container.Summary(nil), f.listResp...), nil
}
// ContainerCreate counts the call, records cfg, and always reports the fixed
// container ID "worker-created".
func (f *fakeDockerClient) ContainerCreate(ctx context.Context, cfg *container.Config, hostConfig *container.HostConfig, networkingConfig *network.NetworkingConfig, platform *ocispec.Platform, containerName string) (container.CreateResponse, error) {
f.mu.Lock()
defer f.mu.Unlock()
f.createCalls++
f.createdConfigs = append(f.createdConfigs, cfg)
return container.CreateResponse{ID: "worker-created"}, nil
}
// ContainerStart counts the call and always succeeds.
func (f *fakeDockerClient) ContainerStart(ctx context.Context, containerID string, options container.StartOptions) error {
f.mu.Lock()
defer f.mu.Unlock()
f.startCalls++
return nil
}
// ContainerInspect counts the call and returns the canned inspectResp.
func (f *fakeDockerClient) ContainerInspect(ctx context.Context, containerID string) (container.InspectResponse, error) {
f.mu.Lock()
defer f.mu.Unlock()
f.inspectCalls++
return f.inspectResp, nil
}
// ContainerStop counts the call and always succeeds.
func (f *fakeDockerClient) ContainerStop(ctx context.Context, containerID string, options container.StopOptions) error {
f.mu.Lock()
defer f.mu.Unlock()
f.stopCalls++
return nil
}
// ContainerRemove counts the call, records the removed ID, and always
// succeeds.
func (f *fakeDockerClient) ContainerRemove(ctx context.Context, containerID string, options container.RemoveOptions) error {
f.mu.Lock()
defer f.mu.Unlock()
f.removeCalls++
f.removedIDs = append(f.removedIDs, containerID)
return nil
}
// Close is a no-op that satisfies dockerClient.
func (f *fakeDockerClient) Close() error { return nil }
// TestResolveWorkerProxyRejectsHTTP checks that an http:// proxy is refused
// with the socks5-only error.
func TestResolveWorkerProxyRejectsHTTP(t *testing.T) {
_, _, err := resolveWorkerProxy("http://127.0.0.1:7890")
require.Error(t, err)
require.Contains(t, err.Error(), "only supports socks5/socks5h")
}
// TestProxyHashUsesNormalizedProxy checks that a socks5:// URL is normalized
// to socks5h:// and that hashing the normalized form equals hashing the
// socks5h:// URL directly.
func TestProxyHashUsesNormalizedProxy(t *testing.T) {
normalized, _, err := resolveWorkerProxy("socks5://user:pass@127.0.0.1:1080")
require.NoError(t, err)
require.Equal(t, "socks5h://user:pass@127.0.0.1:1080", normalized)
hash1 := proxyHash(normalized)
hash2 := proxyHash("socks5h://user:pass@127.0.0.1:1080")
require.Equal(t, hash1, hash2)
}
// TestWorkerManagerRequiresToken checks that GetOrCreate fails with a
// "missing access token" error when no token was set for the account.
func TestWorkerManagerRequiresToken(t *testing.T) {
fakeDocker := &fakeDockerClient{}
manager, err := newWorkerManager(workerManagerConfig{
Image: "worker:latest",
Network: "sub2api-network",
DockerSocket: "unix:///var/run/docker.sock",
IdleTTL: time.Minute,
MaxActive: 2,
StartupTimeout: time.Second,
RequestTimeout: time.Second,
}, fakeDocker)
require.NoError(t, err)
defer manager.Close()
// No SetAccountToken call precedes this lookup.
_, err = manager.GetOrCreate("9", "rk-1", "socks5h://user:pass@127.0.0.1:1080")
require.Error(t, err)
require.Contains(t, err.Error(), "missing access token")
}
// TestWorkerManagerReusesExistingHandleAndDedupesStateSync points a
// pre-registered worker handle at an httptest server and calls GetOrCreate
// twice. Both calls must succeed against that handle, and the account state
// must be POSTed only once because its payload hash does not change.
func TestWorkerManagerReusesExistingHandleAndDedupesStateSync(t *testing.T) {
// mu guards the counters below, which are written from the server's handler.
var mu sync.Mutex
var healthCalls int
var readyCalls int
var stateCalls int
var stateBodies [][]byte
// Stand-in worker: answers 200 on the three control endpoints and records calls.
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/healthz":
mu.Lock()
healthCalls++
mu.Unlock()
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
case "/readyz":
mu.Lock()
readyCalls++
mu.Unlock()
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ready"))
case "/account/state":
body, _ := io.ReadAll(r.Body)
mu.Lock()
stateCalls++
stateBodies = append(stateBodies, body)
mu.Unlock()
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
default:
http.NotFound(w, r)
}
}))
defer server.Close()
fakeDocker := &fakeDockerClient{}
manager, err := newWorkerManager(workerManagerConfig{
Image: "worker:latest",
Network: "sub2api-network",
DockerSocket: "unix:///var/run/docker.sock",
IdleTTL: time.Minute,
MaxActive: 4,
StartupTimeout: time.Second,
RequestTimeout: time.Second,
}, fakeDocker)
require.NoError(t, err)
defer manager.Close()
accountID := "9"
proxyURL := "socks5h://user:pass@127.0.0.1:1080"
hash := proxyHash(proxyURL)
key := buildWorkerKey(accountID, hash)
manager.SetAccountToken(accountID, "ya29.test", "refresh", time.Now().Add(time.Hour))
// Register a handle whose address is the test server, so no container is created.
manager.mu.Lock()
manager.workers[key] = &workerHandle{
Key: key,
AccountID: accountID,
ProxyURL: proxyURL,
ProxyHash: hash,
ContainerID: "existing-worker",
Container: "sub2api-ls-9-test",
Address: strings.TrimPrefix(server.URL, "http://"),
AuthToken: "worker-token",
LastUsed: time.Now(),
}
manager.mu.Unlock()
inst1, err := manager.GetOrCreate(accountID, "rk-1", proxyURL)
require.NoError(t, err)
require.True(t, inst1.remote)
require.Equal(t, replicaSlotIndex("rk-1", parseLSReplicaCount()), inst1.Replica)
inst2, err := manager.GetOrCreate(accountID, "rk-1", proxyURL)
require.NoError(t, err)
require.True(t, inst2.remote)
mu.Lock()
defer mu.Unlock()
// Health and readiness are probed on every call; state is synced once.
require.GreaterOrEqual(t, healthCalls, 2)
require.GreaterOrEqual(t, readyCalls, 2)
require.Equal(t, 1, stateCalls, "state sync should be skipped when the payload hash is unchanged")
require.Len(t, stateBodies, 1)
var synced workerAccountState
require.NoError(t, json.Unmarshal(stateBodies[0], &synced))
require.True(t, synced.HasToken)
require.Equal(t, "ya29.test", synced.AccessToken)
}
// TestWorkerManagerMaxActiveStopsNewWorkerCreation checks that, with
// MaxActive already reached, GetOrCreate for a new key fails with a
// "limit reached" error and never asks Docker to create a container.
func TestWorkerManagerMaxActiveStopsNewWorkerCreation(t *testing.T) {
fakeDocker := &fakeDockerClient{}
manager, err := newWorkerManager(workerManagerConfig{
Image: "worker:latest",
Network: "sub2api-network",
DockerSocket: "unix:///var/run/docker.sock",
IdleTTL: time.Minute,
MaxActive: 1,
StartupTimeout: time.Second,
RequestTimeout: time.Second,
}, fakeDocker)
require.NoError(t, err)
defer manager.Close()
manager.SetAccountToken("9", "ya29.test", "refresh", time.Now().Add(time.Hour))
// Occupy the single available slot with an unrelated worker.
manager.mu.Lock()
manager.workers["existing"] = &workerHandle{ContainerID: "existing", Container: "existing", LastUsed: time.Now()}
manager.mu.Unlock()
_, err = manager.GetOrCreate("9", "rk-new", "socks5h://user:pass@127.0.0.1:1080")
require.Error(t, err)
require.Contains(t, err.Error(), "limit reached")
require.Equal(t, 0, fakeDocker.createCalls)
}
// TestWorkerManagerReconcileRemovesManagedContainers checks that constructing
// a manager lists managed containers exactly once and removes every container
// the listing returned.
func TestWorkerManagerReconcileRemovesManagedContainers(t *testing.T) {
fakeDocker := &fakeDockerClient{
listResp: []container.Summary{
{
ID: "old-worker-1",
Names: []string{"/sub2api-ls-9-deadbeef"},
},
{
ID: "old-worker-2",
Names: []string{"/sub2api-ls-10-beadfeed"},
},
},
}
manager, err := newWorkerManager(workerManagerConfig{
Image: "worker:latest",
Network: "sub2api-network",
DockerSocket: "unix:///var/run/docker.sock",
IdleTTL: time.Minute,
MaxActive: 4,
StartupTimeout: time.Second,
RequestTimeout: time.Second,
}, fakeDocker)
require.NoError(t, err)
defer manager.Close()
require.Equal(t, 1, fakeDocker.listCalls)
require.ElementsMatch(t, []string{"old-worker-1", "old-worker-2"}, fakeDocker.removedIDs)
}
// TestFakeDockerClientImplementsFilterAwareList checks that the fake's
// ContainerList accepts list options carrying a filter set without error.
func TestFakeDockerClientImplementsFilterAwareList(t *testing.T) {
fakeDocker := &fakeDockerClient{}
_, err := fakeDocker.ContainerList(context.Background(), container.ListOptions{Filters: filters.NewArgs()})
require.NoError(t, err)
}

View File

@ -0,0 +1,368 @@
package lspool
import (
	"context"
	"crypto/subtle"
	"encoding/json"
	"fmt"
	"io"
	"log/slog"
	"net/http"
	"os"
	"strconv"
	"strings"
	"sync"
	"time"
)
// WorkerServerConfig holds the settings of a WorkerServer.
type WorkerServerConfig struct {
// AccountID is the single account this worker serves; required.
AccountID string
// AuthToken is the value callers must present in X-Worker-Token; required.
AuthToken string
// ListenAddr defaults to 0.0.0.0:<lsWorkerControlPort> when blank.
ListenAddr string
// AppRoot defaults to "/app/ls" when blank and is passed to the pool config.
AppRoot string
// NetworkReadyFile, when set, must exist before the worker reports ready.
NetworkReadyFile string
// MaxIdleTime defaults to 15 minutes and is passed to the pool config.
MaxIdleTime time.Duration
// HealthInterval defaults to 30 seconds and is passed to the pool config.
HealthInterval time.Duration
}
// WorkerServer is the HTTP control server that runs inside a worker
// container. It authenticates callers by token and relays RPCs to instances
// obtained from a local Pool.
type WorkerServer struct {
cfg WorkerServerConfig
pool *Pool
logger *slog.Logger
// mu guards state.
mu sync.RWMutex
// state is the most recent account state received via /account/state.
state workerAccountState
}
// NewWorkerServer validates cfg, fills in defaults for the optional fields,
// and creates the server together with its pool. It returns an error when
// the account ID or auth token is blank.
func NewWorkerServer(cfg WorkerServerConfig) (*WorkerServer, error) {
	blank := func(value string) bool { return strings.TrimSpace(value) == "" }

	switch {
	case blank(cfg.AccountID):
		return nil, fmt.Errorf("worker account id is required")
	case blank(cfg.AuthToken):
		return nil, fmt.Errorf("worker auth token is required")
	}

	if blank(cfg.ListenAddr) {
		cfg.ListenAddr = fmt.Sprintf("0.0.0.0:%d", lsWorkerControlPort)
	}
	if blank(cfg.AppRoot) {
		cfg.AppRoot = "/app/ls"
	}
	if cfg.MaxIdleTime <= 0 {
		cfg.MaxIdleTime = 15 * time.Minute
	}
	if cfg.HealthInterval <= 0 {
		cfg.HealthInterval = 30 * time.Second
	}

	poolCfg := DefaultConfig()
	poolCfg.AppRoot = cfg.AppRoot
	poolCfg.MaxIdleTime = cfg.MaxIdleTime
	poolCfg.HealthCheckInterval = cfg.HealthInterval

	server := &WorkerServer{
		cfg:    cfg,
		pool:   NewPool(poolCfg),
		logger: slog.Default().With("component", "lsworker"),
	}
	return server, nil
}
// NewWorkerServerFromEnv builds a WorkerServer from environment variables:
// LSWORKER_ACCOUNT_ID, LSWORKER_AUTH_TOKEN, LSWORKER_LISTEN_ADDR,
// ANTIGRAVITY_APP_ROOT, LSWORKER_NETWORK_READY_FILE,
// LSWORKER_POOL_MAX_IDLE_TIME and LSWORKER_POOL_HEALTH_INTERVAL.
// Blank or unparsable durations fall back to 15m and 30s respectively.
func NewWorkerServerFromEnv() (*WorkerServer, error) {
	env := func(key string) string { return strings.TrimSpace(os.Getenv(key)) }
	durationOr := func(key string, fallback time.Duration) time.Duration {
		raw := env(key)
		if raw == "" {
			return fallback
		}
		parsed, err := time.ParseDuration(raw)
		if err != nil {
			// An invalid value silently keeps the default.
			return fallback
		}
		return parsed
	}

	maxIdleTime := durationOr("LSWORKER_POOL_MAX_IDLE_TIME", 15*time.Minute)
	healthInterval := durationOr("LSWORKER_POOL_HEALTH_INTERVAL", 30*time.Second)

	return NewWorkerServer(WorkerServerConfig{
		AccountID:        env("LSWORKER_ACCOUNT_ID"),
		AuthToken:        env("LSWORKER_AUTH_TOKEN"),
		ListenAddr:       env("LSWORKER_LISTEN_ADDR"),
		AppRoot:          env("ANTIGRAVITY_APP_ROOT"),
		NetworkReadyFile: env("LSWORKER_NETWORK_READY_FILE"),
		MaxIdleTime:      maxIdleTime,
		HealthInterval:   healthInterval,
	})
}
// Close shuts down the server's pool, if it has one.
func (s *WorkerServer) Close() {
	if s.pool == nil {
		return
	}
	s.pool.Close()
}
// Handler returns the HTTP handler exposing the worker's control endpoints:
// /healthz, /readyz, /account/state, /rpc/unary and /rpc/stream.
func (s *WorkerServer) Handler() http.Handler {
	routes := []struct {
		pattern string
		serve   func(http.ResponseWriter, *http.Request)
	}{
		{"/healthz", s.handleHealthz},
		{"/readyz", s.handleReadyz},
		{"/account/state", s.handleAccountState},
		{"/rpc/unary", s.handleRPCUnary},
		{"/rpc/stream", s.handleRPCStream},
	}

	mux := http.NewServeMux()
	for _, route := range routes {
		mux.HandleFunc(route.pattern, route.serve)
	}
	return mux
}
// handleHealthz answers 200 "ok" to any authorized request; it signals only
// that the control server itself is up.
func (s *WorkerServer) handleHealthz(w http.ResponseWriter, r *http.Request) {
	if authorized := s.authorize(w, r); authorized {
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte("ok"))
	}
}
// handleReadyz reports whether an instance is ready for the routing_key given
// in the query string: 200 with the replica index when it is, 503 with the
// error text otherwise.
func (s *WorkerServer) handleReadyz(w http.ResponseWriter, r *http.Request) {
	if !s.authorize(w, r) {
		return
	}

	key := strings.TrimSpace(r.URL.Query().Get("routing_key"))
	inst, err := s.ensureReady(r.Context(), key)
	if err != nil {
		http.Error(w, err.Error(), http.StatusServiceUnavailable)
		return
	}

	w.WriteHeader(http.StatusOK)
	_, _ = fmt.Fprintf(w, "ready replica=%d", inst.Replica)
}
// handleAccountState accepts a POSTed workerAccountState JSON document,
// stores a copy as the server's current state, applies it, and answers
// 200 "ok". Other methods get 405 and undecodable bodies get 400.
func (s *WorkerServer) handleAccountState(w http.ResponseWriter, r *http.Request) {
	if !s.authorize(w, r) {
		return
	}
	if r.Method != http.MethodPost {
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}
	defer r.Body.Close()

	var incoming workerAccountState
	if decodeErr := json.NewDecoder(r.Body).Decode(&incoming); decodeErr != nil {
		http.Error(w, "invalid account state payload", http.StatusBadRequest)
		return
	}

	snapshot := cloneWorkerAccountState(&incoming)
	s.mu.Lock()
	s.state = *snapshot
	s.mu.Unlock()

	s.applyState()

	w.WriteHeader(http.StatusOK)
	_, _ = w.Write([]byte("ok"))
}
// handleRPCUnary relays a unary RPC to the instance selected by the request's
// routing key and writes the instance's response body back with status 200.
//
// The target and encoding come from the query string (see parseRPCRequest).
// In "json" mode an empty body is treated as "{}" and must be valid JSON; in
// "proto" mode the body is forwarded byte-for-byte, so an empty body stays an
// empty message. Failures map to 503 (not ready), 400 (bad body or mode) and
// 502 (RPC error).
func (s *WorkerServer) handleRPCUnary(w http.ResponseWriter, r *http.Request) {
	if !s.authorize(w, r) {
		return
	}
	service, method, mode, routingKey, ok := parseRPCRequest(w, r)
	if !ok {
		return
	}
	inst, err := s.ensureReady(r.Context(), routingKey)
	if err != nil {
		http.Error(w, err.Error(), http.StatusServiceUnavailable)
		return
	}
	body, err := io.ReadAll(r.Body)
	if err != nil {
		http.Error(w, "read request body failed", http.StatusBadRequest)
		return
	}

	var respBody []byte
	switch mode {
	case "json":
		// Default to an empty object only for JSON: injecting "{}" into a
		// binary payload would corrupt an otherwise valid empty message.
		if len(body) == 0 {
			body = []byte("{}")
		}
		var input any
		if err := json.Unmarshal(body, &input); err != nil {
			http.Error(w, "invalid json rpc body", http.StatusBadRequest)
			return
		}
		respBody, err = inst.CallUnaryJSON(r.Context(), service, method, input)
	case "proto":
		respBody, err = inst.CallRPC(r.Context(), service, method, body)
	default:
		http.Error(w, "unsupported rpc mode", http.StatusBadRequest)
		return
	}
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadGateway)
		return
	}
	w.WriteHeader(http.StatusOK)
	_, _ = w.Write(respBody)
}
// handleRPCStream relays a streaming RPC to the instance selected by the
// request's routing key and pipes the upstream response (headers, status and
// body) back to the caller.
//
// Each chunk read from upstream is flushed to the client immediately when the
// ResponseWriter supports http.Flusher, so small stream messages are not held
// back in the server's write buffers. Failures before streaming starts map to
// 503 (not ready), 400 (bad body or mode) and 502 (RPC error).
func (s *WorkerServer) handleRPCStream(w http.ResponseWriter, r *http.Request) {
	if !s.authorize(w, r) {
		return
	}
	service, method, mode, routingKey, ok := parseRPCRequest(w, r)
	if !ok {
		return
	}
	inst, err := s.ensureReady(r.Context(), routingKey)
	if err != nil {
		http.Error(w, err.Error(), http.StatusServiceUnavailable)
		return
	}
	body, err := io.ReadAll(r.Body)
	if err != nil {
		http.Error(w, "read request body failed", http.StatusBadRequest)
		return
	}

	var resp *http.Response
	switch mode {
	case "json":
		var input any
		if len(body) == 0 {
			body = []byte("{}")
		}
		if err := json.Unmarshal(body, &input); err != nil {
			http.Error(w, "invalid json rpc body", http.StatusBadRequest)
			return
		}
		resp, err = inst.StreamRPCJSON(r.Context(), service, method, input)
	case "proto":
		resp, err = inst.StreamRPC(r.Context(), service, method, body)
	default:
		http.Error(w, "unsupported rpc mode", http.StatusBadRequest)
		return
	}
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadGateway)
		return
	}
	defer resp.Body.Close()

	for key, values := range resp.Header {
		for _, value := range values {
			w.Header().Add(key, value)
		}
	}
	w.WriteHeader(resp.StatusCode)

	// Relay chunk by chunk and flush after every write. Headers are already
	// sent, so a mid-stream failure can only end the response.
	flusher, canFlush := w.(http.Flusher)
	if canFlush {
		flusher.Flush()
	}
	buf := make([]byte, 32*1024)
	for {
		n, readErr := resp.Body.Read(buf)
		if n > 0 {
			if _, writeErr := w.Write(buf[:n]); writeErr != nil {
				return
			}
			if canFlush {
				flusher.Flush()
			}
		}
		if readErr != nil {
			return
		}
	}
}
// authorize checks the request's X-Worker-Token header against the configured
// auth token. On mismatch it writes a 401 response and returns false.
func (s *WorkerServer) authorize(w http.ResponseWriter, r *http.Request) bool {
	presented := r.Header.Get("X-Worker-Token")
	if !subtleHeaderEqual(presented, s.cfg.AuthToken) {
		http.Error(w, "unauthorized", http.StatusUnauthorized)
		return false
	}
	return true
}
// subtleHeaderEqual reports whether two secret header values are equal.
// Empty values never match. The comparison takes time independent of where
// the values differ, so an attacker cannot recover the token byte by byte
// from response timing; only a length mismatch returns early.
func subtleHeaderEqual(left, right string) bool {
	if left == "" || right == "" {
		return false
	}
	return subtle.ConstantTimeCompare([]byte(left), []byte(right)) == 1
}
// parseRPCRequest validates an RPC proxy request and extracts its routing
// parameters from the query string.
//
// Only POST is accepted (otherwise 405). "service" and "method" are required
// (otherwise 400). "mode" is lower-cased and defaults to "proto" when absent;
// "routing_key" is optional. All values are whitespace-trimmed. When ok is
// false an error response has already been written and the other results are
// empty.
func parseRPCRequest(w http.ResponseWriter, r *http.Request) (service, method, mode, routingKey string, ok bool) {
	if r.Method != http.MethodPost {
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return "", "", "", "", false
	}

	params := r.URL.Query()
	trimmed := func(name string) string {
		return strings.TrimSpace(params.Get(name))
	}

	service, method = trimmed("service"), trimmed("method")
	if service == "" || method == "" {
		http.Error(w, "missing rpc target", http.StatusBadRequest)
		return "", "", "", "", false
	}

	if mode = strings.ToLower(trimmed("mode")); mode == "" {
		mode = "proto"
	}
	routingKey = trimmed("routing_key")
	return service, method, mode, routingKey, true
}
// ensureReady returns a Language Server instance that is ready to serve RPCs
// for the given routing key, or an error describing why the worker cannot
// serve yet.
//
// Checks are performed in order:
//  1. If a network-ready file is configured, it must exist.
//  2. The current account state is pushed into the pool (applyState) and must
//     contain a non-blank access token.
//  3. An instance is obtained from the pool for the configured account.
//  4. If the instance has no model mapping yet, RefreshModelMapping is called
//     and must succeed.
//
// NOTE(review): RefreshModelMapping does not accept a context, so the refresh
// is bounded neither by ctx nor by lsModelConfigTimeout. The previous code
// built a timeout context here and discarded it, which had no effect; it was
// removed rather than left to suggest a bound that does not exist.
func (s *WorkerServer) ensureReady(ctx context.Context, routingKey string) (*Instance, error) {
	if path := strings.TrimSpace(s.cfg.NetworkReadyFile); path != "" {
		if _, err := os.Stat(path); err != nil {
			return nil, fmt.Errorf("worker network not ready: %w", err)
		}
	}

	s.applyState()

	s.mu.RLock()
	state := cloneWorkerAccountState(&s.state)
	s.mu.RUnlock()
	if state == nil || !state.HasToken || strings.TrimSpace(state.AccessToken) == "" {
		return nil, fmt.Errorf("worker access token not configured")
	}

	inst, err := s.pool.GetOrCreate(s.cfg.AccountID, routingKey, "")
	if err != nil {
		return nil, err
	}
	if inst.HasModelMappingReady() {
		return inst, nil
	}
	if !RefreshModelMapping(inst) {
		return nil, fmt.Errorf("worker model mapping not ready for replica %d", inst.Replica)
	}
	return inst, nil
}
// applyState copies the current worker account state into the instance pool.
//
// A snapshot of the state is taken under the read lock. If it carries a token,
// the access token, refresh token and expiry (converted to UTC, or the zero
// time when unset) are registered for the configured account. If it carries
// model-credit information, that is registered as well.
func (s *WorkerServer) applyState() {
	s.mu.RLock()
	snapshot := cloneWorkerAccountState(&s.state)
	s.mu.RUnlock()

	if snapshot == nil {
		return
	}

	if snapshot.HasToken {
		var expiry time.Time
		if snapshot.ExpiresAt != nil {
			expiry = snapshot.ExpiresAt.UTC()
		}
		s.pool.SetAccountToken(s.cfg.AccountID, snapshot.AccessToken, snapshot.RefreshToken, expiry)
	}

	if !snapshot.HasModelCredits {
		return
	}
	s.pool.SetAccountModelCredits(s.cfg.AccountID, snapshot.UseAICredits, snapshot.AvailableCredits, snapshot.MinimumCreditAmount)
}
// workerHTTPServer builds the HTTP server for the worker control API, bound to
// listenAddr and serving handler, with a 10-second limit on reading request
// headers. No other timeouts are set.
func workerHTTPServer(listenAddr string, handler http.Handler) *http.Server {
	srv := new(http.Server)
	srv.Addr = listenAddr
	srv.Handler = handler
	srv.ReadHeaderTimeout = 10 * time.Second
	return srv
}
// workerExitCode maps a terminal error to a process exit status: 1 for any
// non-nil error, 0 otherwise.
func workerExitCode(err error) int {
	if err != nil {
		return 1
	}
	return 0
}
// parseWorkerControlPort returns the control API port taken from the
// LSWORKER_CONTROL_PORT environment variable.
//
// The default lsWorkerControlPort is returned when the variable is unset,
// blank, not an integer, or outside the valid TCP port range 1-65535.
func parseWorkerControlPort() int {
	raw := strings.TrimSpace(os.Getenv("LSWORKER_CONTROL_PORT"))
	if raw == "" {
		return lsWorkerControlPort
	}
	port, err := strconv.Atoi(raw)
	if err != nil || port < 1 || port > 65535 {
		return lsWorkerControlPort
	}
	return port
}

View File

@ -13,6 +13,8 @@ import (
"strings"
"sync"
"sync/atomic"
"github.com/Wei-Shaw/sub2api/internal/pkg/lspool"
"time"
"github.com/andybalholm/brotli"
@ -99,16 +101,34 @@ type httpUpstreamService struct {
// NewHTTPUpstream 创建通用 HTTP 上游服务
// 使用配置中的连接池参数构建 Transport
//
// 当环境变量 ANTIGRAVITY_LS_MODE=true 时,自动包装 LS 池拦截层:
// - 仅对已知兼容的 LS 请求形态启用转发
// - 对普通 streamGenerateContent 请求保留原有直连路径,避免误送到不兼容的 LS RPC
//
// 参数:
// - cfg: 全局配置,包含连接池参数和隔离策略
//
// 返回:
// - service.HTTPUpstream 接口实现
func NewHTTPUpstream(cfg *config.Config) service.HTTPUpstream {
return &httpUpstreamService{
base := &httpUpstreamService{
cfg: cfg,
clients: make(map[string]*upstreamClientEntry),
}
// LS 池模式: 包装一层拦截, streamGenerateContent 走 LS
if lspool.IsLSModeEnabled() {
pool := lspool.GlobalPool(cfg)
if pool != nil {
slog.Info("LS pool mode enabled — streamGenerateContent will route through Language Server",
"component", "http_upstream")
return lspool.NewLSPoolUpstream(pool, base)
}
slog.Warn("LS pool mode enabled but pool is nil — falling back to direct mode",
"component", "http_upstream")
}
return base
}
// Do 执行 HTTP 请求

View File

@ -104,6 +104,10 @@ func classifyAntigravity429(body []byte) antigravity429Category {
return antigravity429QuotaExhausted
}
}
if strings.Contains(lowerBody, "exhausted your capacity on this model") &&
strings.Contains(lowerBody, "quota will reset after") {
return antigravity429QuotaExhausted
}
if info := parseAntigravitySmartRetryInfo(body); info != nil && !info.IsModelCapacityExhausted {
return antigravity429RateLimited
}

View File

@ -21,6 +21,16 @@ func TestClassifyAntigravity429(t *testing.T) {
require.Equal(t, antigravity429QuotaExhausted, classifyAntigravity429(body))
})
t.Run("模型配额耗尽文案也视为可切 AI Credits", func(t *testing.T) {
body := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"message": "You have exhausted your capacity on this model. Your quota will reset after 1h59m40s."
}
}`)
require.Equal(t, antigravity429QuotaExhausted, classifyAntigravity429(body))
})
t.Run("结构化限流", func(t *testing.T) {
body := []byte(`{
"error": {
@ -146,6 +156,68 @@ func TestHandleSmartRetry_QuotaExhausted_UsesCreditsAndStoresIndependentState(t
require.Empty(t, repo.modelRateLimitCalls, "overages 成功后不应写入普通 model_rate_limits")
}
// TestHandleSmartRetry_ModelQuotaMessage_UsesCredits verifies that a 429 whose
// message is the plain-text "exhausted your capacity on this model ... quota
// will reset after" wording leads handleSmartRetry to issue one retry whose
// request body contains "enabledCreditTypes", and to return that retry's
// response.
func TestHandleSmartRetry_ModelQuotaMessage_UsesCredits(t *testing.T) {
// Response the mocked upstream hands back for the single expected retry.
successResp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(`{"ok":true}`)),
}
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{successResp},
errors: []error{nil},
}
repo := &stubAntigravityAccountRepo{}
// Account fixture with "allow_overages" set in Extra and an identity model
// mapping for the requested model.
account := &Account{
ID: 151,
Name: "acc-151",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Extra: map[string]any{
"allow_overages": true,
},
Credentials: map[string]any{
"model_mapping": map[string]any{
"claude-opus-4-6": "claude-opus-4-6",
},
},
}
// Upstream 429 body carrying only status and message, with no structured
// retry details.
respBody := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"message": "You have exhausted your capacity on this model. Your quota will reset after 1h59m40s."
}
}`)
resp := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
params := antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"model":"claude-opus-4-6","request":{}}`),
httpUpstream: upstream,
accountRepo: repo,
requestedModel: "claude-opus-4-6",
// Stub error handler that always returns nil.
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
svc := &AntigravityGatewayService{}
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, []string{"https://ag-1.test"})
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.NotNil(t, result.resp)
// Exactly one upstream retry, and it must carry the credits field.
require.Len(t, upstream.requestBodies, 1)
require.Contains(t, string(upstream.requestBodies[0]), "enabledCreditTypes")
}
func TestHandleSmartRetry_RateLimited_DoesNotUseCredits(t *testing.T) {
successResp := &http.Response{
StatusCode: http.StatusOK,

View File

@ -112,6 +112,144 @@ func IsAntigravityAccountSwitchError(err error) (*AntigravityAccountSwitchError,
return nil, false
}
// injectLSPoolHeaders attaches internal X-Antigravity-* headers to req that
// describe the account's OAuth and AI-credits state for the LS pool layer:
//   - X-Antigravity-Refresh-Token / X-Antigravity-Token-Expiry, when the
//     account credentials hold non-empty string values for "refresh_token" /
//     "expires_at";
//   - X-Antigravity-Use-AI-Credits, always, from account.IsOveragesEnabled();
//   - X-Antigravity-Available-Credits / X-Antigravity-Minimum-Credit-Amount,
//     when resolveLSPoolModelCreditsState yields a value.
//
// A nil request or account is a no-op.
//
// NOTE(review): these headers are intended to be consumed and stripped by
// LSPoolUpstream before the request leaves the process. Nothing in this
// function removes them when the LS wrapper is not installed, so in direct
// mode the refresh token may be forwarded upstream unless another layer
// strips it — confirm.
func injectLSPoolHeaders(req *http.Request, account *Account) {
	if req == nil || account == nil {
		return
	}

	headers := req.Header

	// Copy a credential into a header only when it is a non-empty string.
	copyCredential := func(headerName, credentialKey string) {
		value, isString := account.Credentials[credentialKey].(string)
		if isString && value != "" {
			headers.Set(headerName, value)
		}
	}
	copyCredential("X-Antigravity-Refresh-Token", "refresh_token")
	copyCredential("X-Antigravity-Token-Expiry", "expires_at")

	headers.Set("X-Antigravity-Use-AI-Credits", strconv.FormatBool(account.IsOveragesEnabled()))

	available, minimum := resolveLSPoolModelCreditsState(account)
	if available != nil {
		headers.Set("X-Antigravity-Available-Credits", strconv.FormatInt(int64(*available), 10))
	}
	if minimum != nil {
		headers.Set("X-Antigravity-Minimum-Credit-Amount", strconv.FormatInt(int64(*minimum), 10))
	}
}
// resolveLSPoolModelCreditsState extracts the AI-credits balance and the
// minimum usable credit amount from account.Extra.
//
// Credit entries are read first from Extra["ai_credits"] and then from
// Extra["load_code_assist"]["paidTier"]["availableCredits"]. Entries that
// isGoogleOneAICreditsEntry rejects are ignored. For each of the two results
// the first entry that yields a parseable value wins.
//
// Returns (availableCredits, minimumCreditAmount). availableCredits is nil
// when no entry provides it; minimumCreditAmount falls back to 50 and is never
// nil.
func resolveLSPoolModelCreditsState(account *Account) (*int32, *int32) {
	var available, minimum *int32

	// Walk a list of raw entries, filling whichever result is still unset.
	scan := func(items []any) {
		for _, item := range items {
			entry, isMap := item.(map[string]any)
			if !isMap || entry == nil || !isGoogleOneAICreditsEntry(entry) {
				continue
			}
			if available == nil {
				if amount, ok := parseAICreditsInt32(firstPresent(entry, "Amount", "amount", "creditAmount")); ok {
					available = &amount
				}
			}
			if minimum == nil {
				if floor, ok := parseAICreditsInt32(firstPresent(entry, "MinimumBalance", "minimum_balance", "minimumCreditAmountForUsage")); ok {
					minimum = &floor
				}
			}
		}
	}

	if account != nil && account.Extra != nil {
		if list, ok := account.Extra["ai_credits"].([]any); ok {
			scan(list)
		}
		if codeAssist, ok := account.Extra["load_code_assist"].(map[string]any); ok {
			if tier, ok := codeAssist["paidTier"].(map[string]any); ok {
				if list, ok := tier["availableCredits"].([]any); ok {
					scan(list)
				}
			}
		}
	}

	if minimum == nil {
		fallback := int32(50)
		minimum = &fallback
	}
	return available, minimum
}
// isGoogleOneAICreditsEntry reports whether a credits entry should be treated
// as a GOOGLE_ONE_AI entry. The credit type is read from the first present of
// "CreditType", "credit_type" or "creditType" and compared case-insensitively
// after trimming. A missing, non-string or blank type also counts as a match.
func isGoogleOneAICreditsEntry(entry map[string]any) bool {
	rawType, _ := firstPresent(entry, "CreditType", "credit_type", "creditType").(string)
	switch strings.TrimSpace(strings.ToUpper(rawType)) {
	case "", "GOOGLE_ONE_AI":
		return true
	default:
		return false
	}
}
// firstPresent returns the value stored under the first of keys that exists
// in entry, in the order given. A key that is present with a nil value still
// counts as present. It returns nil when none of the keys exist.
func firstPresent(entry map[string]any, keys ...string) any {
	for i := 0; i < len(keys); i++ {
		value, exists := entry[keys[i]]
		if !exists {
			continue
		}
		return value
	}
	return nil
}
// parseAICreditsInt32 converts a loosely-typed credits value into an int32.
//
// Accepted inputs are Go integers and floats, json.Number, and numeric
// strings (surrounding whitespace ignored). Fractional values are truncated
// toward zero. Finite values outside the int32 range are clamped to the
// nearest bound instead of wrapping around. NaN and infinities are rejected.
//
// The second result is false when raw is nil, of an unsupported type, blank,
// or not a usable number.
func parseAICreditsInt32(raw any) (int32, bool) {
	const (
		maxInt32 = 1<<31 - 1
		minInt32 = -1 << 31
	)

	// fromInt narrows with saturation so that large balances do not wrap
	// into negative numbers.
	fromInt := func(v int64) (int32, bool) {
		switch {
		case v > maxInt32:
			return maxInt32, true
		case v < minInt32:
			return minInt32, true
		}
		return int32(v), true
	}

	// fromFloat truncates toward zero with saturation. Converting a
	// non-finite or out-of-range float to an integer is implementation-
	// defined in Go, so those cases are handled explicitly.
	fromFloat := func(f float64) (int32, bool) {
		switch {
		case f-f != 0: // true only for NaN and ±Inf
			return 0, false
		case f >= maxInt32:
			return maxInt32, true
		case f <= minInt32:
			return minInt32, true
		}
		return int32(f), true
	}

	switch v := raw.(type) {
	case int:
		return fromInt(int64(v))
	case int32:
		return v, true
	case int64:
		return fromInt(v)
	case float32:
		return fromFloat(float64(v))
	case float64:
		return fromFloat(v)
	case json.Number:
		if parsed, err := v.Int64(); err == nil {
			return fromInt(parsed)
		}
		floatVal, err := strconv.ParseFloat(v.String(), 64)
		if err != nil {
			return 0, false
		}
		return fromFloat(floatVal)
	case string:
		trimmed := strings.TrimSpace(v)
		if trimmed == "" {
			return 0, false
		}
		if parsed, err := strconv.ParseInt(trimmed, 10, 64); err == nil {
			return fromInt(parsed)
		}
		floatVal, err := strconv.ParseFloat(trimmed, 64)
		if err != nil {
			return 0, false
		}
		return fromFloat(floatVal)
	default:
		return 0, false
	}
}
// PromptTooLongError 表示上游明确返回 prompt too long
type PromptTooLongError struct {
StatusCode int
@ -305,6 +443,7 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
}
}
injectLSPoolHeaders(retryReq, p.account)
retryResp, retryErr := p.httpUpstream.Do(retryReq, p.proxyURL, p.account.ID, p.account.Concurrency)
if retryErr == nil && retryResp != nil && retryResp.StatusCode != http.StatusTooManyRequests && retryResp.StatusCode != http.StatusServiceUnavailable {
log.Printf("%s status=%d smart_retry_success attempt=%d/%d", p.prefix, retryResp.StatusCode, attempt, maxAttempts)
@ -489,6 +628,7 @@ func (s *AntigravityGatewayService) handleSingleAccountRetryInPlace(
logger.LegacyPrintf("service.antigravity_gateway", "%s single_account_503_retry: request_build_failed error=%v", p.prefix, err)
break
}
injectLSPoolHeaders(retryReq, p.account)
retryResp, retryErr := p.httpUpstream.Do(retryReq, p.proxyURL, p.account.ID, p.account.Concurrency)
if retryErr == nil && retryResp != nil && retryResp.StatusCode != http.StatusTooManyRequests && retryResp.StatusCode != http.StatusServiceUnavailable {
@ -627,6 +767,7 @@ urlFallbackLoop:
if err != nil {
return nil, err
}
injectLSPoolHeaders(upstreamReq, p.account)
// Capture upstream request body for ops retry of this attempt.
if p.c != nil && len(p.body) > 0 {
@ -1289,9 +1430,19 @@ func injectIdentityPatchToGeminiRequest(body []byte) ([]byte, error) {
}
// wrapV1InternalRequest 包装请求为 v1internal 格式
func (s *AntigravityGatewayService) wrapV1InternalRequest(projectID, model string, originalBody []byte) ([]byte, error) {
func (s *AntigravityGatewayService) wrapV1InternalRequest(projectID, model string, originalBody []byte, preferredSessionID ...string) ([]byte, error) {
sessionID := ""
if len(preferredSessionID) > 0 {
sessionID = preferredSessionID[0]
}
bodyWithSessionID, err := antigravity.EnsureGeminiRequestSessionID(originalBody, sessionID)
if err != nil {
return nil, fmt.Errorf("补全 sessionId 失败: %w", err)
}
var request any
if err := json.Unmarshal(originalBody, &request); err != nil {
if err := json.Unmarshal(bodyWithSessionID, &request); err != nil {
return nil, fmt.Errorf("解析请求体失败: %w", err)
}
@ -2156,7 +2307,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
}
// 包装请求
wrappedBody, err := s.wrapV1InternalRequest(projectID, mappedModel, injectedBody)
wrappedBody, err := s.wrapV1InternalRequest(projectID, mappedModel, injectedBody, sessionID)
if err != nil {
return nil, s.writeGoogleError(c, http.StatusInternalServerError, "Failed to build upstream request")
}
@ -2220,10 +2371,11 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
if fallbackModel != "" && fallbackModel != mappedModel {
logger.LegacyPrintf("service.antigravity_gateway", "[Antigravity] Model not found (%s), retrying with fallback model %s (account: %s)", mappedModel, fallbackModel, account.Name)
fallbackWrapped, err := s.wrapV1InternalRequest(projectID, fallbackModel, injectedBody)
fallbackWrapped, err := s.wrapV1InternalRequest(projectID, fallbackModel, injectedBody, sessionID)
if err == nil {
fallbackReq, err := antigravity.NewAPIRequest(ctx, upstreamAction, accessToken, fallbackWrapped)
if err == nil {
injectLSPoolHeaders(fallbackReq, account)
fallbackResp, err := s.httpUpstream.Do(fallbackReq, proxyURL, account.ID, account.Concurrency)
if err == nil && fallbackResp.StatusCode < 400 {
_ = resp.Body.Close()
@ -2263,7 +2415,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
logger.LegacyPrintf("service.antigravity_gateway", "Antigravity Gemini account %d: detected signature-related 400, retrying with cleaned thought signatures", account.ID)
cleanedInjectedBody := CleanGeminiNativeThoughtSignatures(injectedBody)
retryWrappedBody, wrapErr := s.wrapV1InternalRequest(projectID, mappedModel, cleanedInjectedBody)
retryWrappedBody, wrapErr := s.wrapV1InternalRequest(projectID, mappedModel, cleanedInjectedBody, sessionID)
if wrapErr == nil {
retryResult, retryErr := s.antigravityRetryLoop(antigravityRetryLoopParams{
ctx: ctx,

View File

@ -600,6 +600,63 @@ func TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel(t *testing
require.Equal(t, mappedModel, result.UpstreamModel)
}
// TestAntigravityGatewayService_ForwardGemini_InjectsSessionIDIntoWrappedRequest
// verifies that ForwardGemini copies the incoming "session_id" request header
// into the "sessionId" field of the "request" object in the body it sends
// upstream.
func TestAntigravityGatewayService_ForwardGemini_InjectsSessionIDIntoWrappedRequest(t *testing.T) {
gin.SetMode(gin.TestMode)
writer := httptest.NewRecorder()
c, _ := gin.CreateTestContext(writer)
// Client body deliberately has no sessionId of its own.
body, err := json.Marshal(map[string]any{
"contents": []map[string]any{
{"role": "user", "parts": []map[string]any{{"text": "hello"}}},
},
})
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body))
// The session id is supplied only through this header.
req.Header.Set("session_id", "session-header-1")
c.Request = req
// Canned upstream reply: a single SSE "data:" event wrapping a response.
upstreamBody := []byte("data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"}]},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":8,\"candidatesTokenCount\":3}}}\n\n")
upstream := &queuedHTTPUpstreamStub{
responses: []*http.Response{
{
StatusCode: http.StatusOK,
Header: http.Header{"X-Request-Id": []string{"req-session-1"}},
Body: io.NopCloser(bytes.NewReader(upstreamBody)),
},
},
}
svc := &AntigravityGatewayService{
settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}),
tokenProvider: &AntigravityTokenProvider{},
httpUpstream: upstream,
}
account := &Account{
ID: 16,
Name: "acc-gemini-session",
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "token",
},
}
result, err := svc.ForwardGemini(context.Background(), c, account, "gemini-2.5-flash", "generateContent", false, body, false)
require.NoError(t, err)
require.NotNil(t, result)
// Exactly one upstream call is expected; inspect the body it carried.
require.Len(t, upstream.requestBodies, 1)
var wrapped map[string]any
require.NoError(t, json.Unmarshal(upstream.requestBodies[0], &wrapped))
requestNode, ok := wrapped["request"].(map[string]any)
require.True(t, ok)
require.Equal(t, "session-header-1", requestNode["sessionId"])
}
func TestAntigravityGatewayService_ForwardGemini_RetriesCorruptedThoughtSignature(t *testing.T) {
gin.SetMode(gin.TestMode)
writer := httptest.NewRecorder()

View File

@ -2500,6 +2500,57 @@
"supports_vision": true,
"supports_web_search": true
},
"gemini-3-flash": {
"cache_read_input_token_cost": 5e-08,
"cache_read_input_token_cost_priority": 9e-08,
"input_cost_per_audio_token": 1e-06,
"input_cost_per_audio_token_priority": 1.8e-06,
"input_cost_per_token": 5e-07,
"input_cost_per_token_priority": 9e-07,
"litellm_provider": "vertex_ai-language-models",
"max_audio_length_hours": 8.4,
"max_audio_per_prompt": 1,
"max_images_per_prompt": 3000,
"max_input_tokens": 1048576,
"max_output_tokens": 65535,
"max_pdf_size_mb": 30,
"max_tokens": 65535,
"max_video_length": 1,
"max_videos_per_prompt": 10,
"mode": "chat",
"output_cost_per_reasoning_token": 3e-06,
"output_cost_per_token": 3e-06,
"output_cost_per_token_priority": 5.4e-06,
"source": "https://ai.google.dev/pricing/gemini-3",
"supported_endpoints": [
"/v1/chat/completions",
"/v1/completions",
"/v1/batch"
],
"supported_modalities": [
"text",
"image",
"audio",
"video"
],
"supported_output_modalities": [
"text"
],
"supports_audio_output": false,
"supports_function_calling": true,
"supports_native_streaming": true,
"supports_parallel_function_calling": true,
"supports_pdf_input": true,
"supports_prompt_caching": true,
"supports_reasoning": true,
"supports_response_schema": true,
"supports_service_tier": true,
"supports_system_messages": true,
"supports_tool_choice": true,
"supports_url_context": true,
"supports_vision": true,
"supports_web_search": true
},
"gemini-3-flash-preview": {
"cache_read_input_token_cost": 5e-08,
"cache_read_input_token_cost_priority": 9e-08,

View File

@ -20,6 +20,9 @@ SERVER_PORT=8080
# Server mode: release or debug
SERVER_MODE=release
# Main application image override
SUB2API_IMAGE=zfc931912343/sub2api:latest
# -----------------------------------------------------------------------------
# Logging Configuration
# 日志配置
@ -389,3 +392,32 @@ OPS_ENABLED=true
# Leave empty for direct connection (recommended for overseas servers)
# 留空表示直连(适用于海外服务器)
UPDATE_PROXY_URL=
# -----------------------------------------------------------------------------
# Language Server Pool Mode (Enhanced Security)
# -----------------------------------------------------------------------------
# Enable to route requests through real AntiGravity LS binary
# Makes upstream traffic indistinguishable from real IDE
# ANTIGRAVITY_LS_MODE=true
# LS replicas per account. Default is 5.
# Increase for higher concurrency, but each replica is an extra LS process.
# ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT=5
# Optional global fallback proxy for accounts without dedicated LS proxy.
# Must be socks5/socks5h in worker mode.
ANTIGRAVITY_LS_PROXY=
# LS routing strategy (default js-parity)
ANTIGRAVITY_LS_STRATEGY=js-parity
# Dynamic LS worker container image. Build/pull this image before enabling LS mode.
GATEWAY_ANTIGRAVITY_LS_WORKER_IMAGE=zfc931912343/sub2api-lsworker:latest
# Docker network name shared by sub2api and dynamic ls-worker containers.
GATEWAY_ANTIGRAVITY_LS_WORKER_NETWORK=sub2api-network
# Docker socket used by sub2api to create dynamic ls-worker containers.
GATEWAY_ANTIGRAVITY_LS_WORKER_DOCKER_SOCKET=unix:///var/run/docker.sock
# Idle TTL before worker container is reaped.
GATEWAY_ANTIGRAVITY_LS_WORKER_IDLE_TTL=15m
# Maximum number of active worker containers on this node.
GATEWAY_ANTIGRAVITY_LS_WORKER_MAX_ACTIVE=50
# Maximum time allowed for worker cold start and readiness.
GATEWAY_ANTIGRAVITY_LS_WORKER_STARTUP_TIMEOUT=45s
# Per-request timeout when sub2api talks to worker control API.
GATEWAY_ANTIGRAVITY_LS_WORKER_REQUEST_TIMEOUT=60s

View File

@ -65,6 +65,20 @@ docker compose -f docker-compose.local.yml logs sub2api | grep "admin password"
# http://localhost:8080
```
### LS Worker Image
When `ANTIGRAVITY_LS_MODE=true`, Sub2API creates dynamic `ls-worker`
containers through the Docker socket. Build or pull the worker image before
enabling LS mode:
```bash
cd /path/to/sub2api
docker build -f deploy/lsworker.Dockerfile -t weishaw/sub2api-lsworker:latest .
```
The `sub2api` container must also be able to access `/var/run/docker.sock`.
The Docker network set in `GATEWAY_ANTIGRAVITY_LS_WORKER_NETWORK` (default
`sub2api-network`) must be the network shared by `sub2api` and the workers.
### Method 2: Manual Deployment
If you prefer manual control:

View File

@ -283,6 +283,30 @@ gateway:
queue: 0.7
error_rate: 0.8
ttft: 0.5
# Antigravity LS worker container configuration
# Antigravity LS worker 容器控制平面配置
antigravity_ls_worker:
# Worker image used by sub2api to create dynamic LS containers
# sub2api 用于创建动态 LS worker 的镜像
image: "weishaw/sub2api-lsworker:latest"
# Docker network name shared by sub2api and workers
# sub2api 与 worker 共享的 Docker network 名称
network: "sub2api-network"
# Docker socket path or host used by sub2api control plane
# sub2api 控制面访问的 Docker socket / host
docker_socket: "unix:///var/run/docker.sock"
# Idle TTL before a worker container is recycled
# worker 容器空闲回收时间
idle_ttl: 15m
# Max active worker containers per node
# 单节点最大 worker 容器数量
max_active: 50
# Worker cold-start timeout
# worker 冷启动超时
startup_timeout: 45s
# Timeout for control-plane calls from sub2api to worker
# sub2api 调用 worker 控制接口的超时
request_timeout: 60s
# HTTP upstream connection pool settings (HTTP/2 + multi-proxy scenario defaults)
# HTTP 上游连接池配置HTTP/2 + 多代理场景默认值)
# Max idle connections across all hosts

View File

@ -36,6 +36,7 @@ services:
volumes:
# Local directory mapping for easy migration
- ./data:/app/data
- /var/run/docker.sock:/var/run/docker.sock
# Optional: Mount custom config.yaml (uncomment and create the file first)
# Copy config.example.yaml to config.yaml, modify it, then uncomment:
# - ./config.yaml:/app/data/config.yaml
@ -128,6 +129,22 @@ services:
- GEMINI_CLI_OAUTH_CLIENT_SECRET=${GEMINI_CLI_OAUTH_CLIENT_SECRET:-}
- ANTIGRAVITY_OAUTH_CLIENT_SECRET=${ANTIGRAVITY_OAUTH_CLIENT_SECRET:-}
# =======================================================================
# Language Server Worker Mode
# =======================================================================
- ANTIGRAVITY_LS_MODE=${ANTIGRAVITY_LS_MODE:-false}
- ANTIGRAVITY_APP_ROOT=/app/ls
- ANTIGRAVITY_LS_PROXY=${ANTIGRAVITY_LS_PROXY:-}
- ANTIGRAVITY_LS_STRATEGY=${ANTIGRAVITY_LS_STRATEGY:-js-parity}
- ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT=${ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT:-5}
- GATEWAY_ANTIGRAVITY_LS_WORKER_IMAGE=${GATEWAY_ANTIGRAVITY_LS_WORKER_IMAGE:-weishaw/sub2api-lsworker:latest}
- GATEWAY_ANTIGRAVITY_LS_WORKER_NETWORK=${GATEWAY_ANTIGRAVITY_LS_WORKER_NETWORK:-sub2api-network}
- GATEWAY_ANTIGRAVITY_LS_WORKER_DOCKER_SOCKET=${GATEWAY_ANTIGRAVITY_LS_WORKER_DOCKER_SOCKET:-unix:///var/run/docker.sock}
- GATEWAY_ANTIGRAVITY_LS_WORKER_IDLE_TTL=${GATEWAY_ANTIGRAVITY_LS_WORKER_IDLE_TTL:-15m}
- GATEWAY_ANTIGRAVITY_LS_WORKER_MAX_ACTIVE=${GATEWAY_ANTIGRAVITY_LS_WORKER_MAX_ACTIVE:-50}
- GATEWAY_ANTIGRAVITY_LS_WORKER_STARTUP_TIMEOUT=${GATEWAY_ANTIGRAVITY_LS_WORKER_STARTUP_TIMEOUT:-45s}
- GATEWAY_ANTIGRAVITY_LS_WORKER_REQUEST_TIMEOUT=${GATEWAY_ANTIGRAVITY_LS_WORKER_REQUEST_TIMEOUT:-60s}
# =======================================================================
# Security Configuration (URL Allowlist)
# =======================================================================
@ -230,4 +247,5 @@ services:
# =============================================================================
networks:
sub2api-network:
name: sub2api-network
driver: bridge

View File

@ -16,7 +16,8 @@ services:
# Sub2API Application
# ===========================================================================
sub2api:
image: weishaw/sub2api:latest
# Override with SUB2API_IMAGE to use a private registry or pinned tag.
image: ${SUB2API_IMAGE:-weishaw/sub2api:latest}
container_name: sub2api
restart: unless-stopped
ulimits:
@ -28,6 +29,7 @@ services:
volumes:
# Data persistence (config.yaml will be auto-generated here)
- sub2api_data:/app/data
- /var/run/docker.sock:/var/run/docker.sock
# Optional: Mount custom config.yaml (uncomment and create the file first)
# Copy config.example.yaml to config.yaml, modify it, then uncomment:
# - ./config.yaml:/app/data/config.yaml
@ -120,6 +122,26 @@ services:
- GEMINI_CLI_OAUTH_CLIENT_SECRET=${GEMINI_CLI_OAUTH_CLIENT_SECRET:-}
- ANTIGRAVITY_OAUTH_CLIENT_SECRET=${ANTIGRAVITY_OAUTH_CLIENT_SECRET:-}
# =======================================================================
# Language Server Pool Mode (Enhanced Security)
# =======================================================================
# Enable to route requests through real LS binary (Google's own code)
# This makes upstream traffic indistinguishable from real IDE
- ANTIGRAVITY_LS_MODE=${ANTIGRAVITY_LS_MODE:-false}
- ANTIGRAVITY_APP_ROOT=/app/ls
# SOCKS5/HTTP proxy fallback used when account has no dedicated LS proxy
- ANTIGRAVITY_LS_PROXY=${ANTIGRAVITY_LS_PROXY:-}
- ANTIGRAVITY_LS_STRATEGY=${ANTIGRAVITY_LS_STRATEGY:-js-parity}
- ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT=${ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT:-5}
# Keep the worker image aligned with the main image release when overriding.
- GATEWAY_ANTIGRAVITY_LS_WORKER_IMAGE=${GATEWAY_ANTIGRAVITY_LS_WORKER_IMAGE:-weishaw/sub2api-lsworker:latest}
- GATEWAY_ANTIGRAVITY_LS_WORKER_NETWORK=${GATEWAY_ANTIGRAVITY_LS_WORKER_NETWORK:-sub2api-network}
- GATEWAY_ANTIGRAVITY_LS_WORKER_DOCKER_SOCKET=${GATEWAY_ANTIGRAVITY_LS_WORKER_DOCKER_SOCKET:-unix:///var/run/docker.sock}
- GATEWAY_ANTIGRAVITY_LS_WORKER_IDLE_TTL=${GATEWAY_ANTIGRAVITY_LS_WORKER_IDLE_TTL:-15m}
- GATEWAY_ANTIGRAVITY_LS_WORKER_MAX_ACTIVE=${GATEWAY_ANTIGRAVITY_LS_WORKER_MAX_ACTIVE:-50}
- GATEWAY_ANTIGRAVITY_LS_WORKER_STARTUP_TIMEOUT=${GATEWAY_ANTIGRAVITY_LS_WORKER_STARTUP_TIMEOUT:-45s}
- GATEWAY_ANTIGRAVITY_LS_WORKER_REQUEST_TIMEOUT=${GATEWAY_ANTIGRAVITY_LS_WORKER_REQUEST_TIMEOUT:-60s}
# =======================================================================
# Security Configuration (URL Allowlist)
# =======================================================================
@ -234,4 +256,5 @@ volumes:
# =============================================================================
networks:
sub2api-network:
name: sub2api-network
driver: bridge

View File

@ -8,9 +8,27 @@ if [ "$(id -u)" = "0" ]; then
mkdir -p /app/data
# Use || true to avoid failure on read-only mounted files (e.g. config.yaml:ro)
chown -R sub2api:sub2api /app/data 2>/dev/null || true
if [ -S /var/run/docker.sock ]; then
DOCKER_GID="$(stat -c '%g' /var/run/docker.sock 2>/dev/null || true)"
if [ -n "${DOCKER_GID}" ]; then
DOCKER_GROUP="$(getent group "${DOCKER_GID}" | cut -d: -f1 || true)"
if [ -z "${DOCKER_GROUP}" ]; then
DOCKER_GROUP="dockersock"
groupadd -for -g "${DOCKER_GID}" "${DOCKER_GROUP}" 2>/dev/null || true
fi
usermod -aG "${DOCKER_GROUP}" sub2api 2>/dev/null || true
fi
fi
# Re-invoke this script as sub2api so the flag-detection below
# also runs under the correct user.
exec su-exec sub2api "$0" "$@"
# Use gosu if available (Debian), fall back to su-exec (Alpine)
if command -v gosu >/dev/null 2>&1; then
exec gosu sub2api "$0" "$@"
elif command -v su-exec >/dev/null 2>&1; then
exec su-exec sub2api "$0" "$@"
else
exec su -s /bin/sh sub2api -c "exec $0 $*"
fi
fi
# Compatibility: if the first arg looks like a flag (e.g. --help),

21
deploy/ls-bin/cert.pem Normal file
View File

@ -0,0 +1,21 @@
-----BEGIN CERTIFICATE-----
MIIDXTCCAkWgAwIBAgIUVoRddTlTFh3+shRe6g4kSLo2n0MwDQYJKoZIhvcNAQEL
BQAwSTESMBAGA1UEAwwJbG9jYWxob3N0MRYwFAYDVQQKDA1FTkFCTEVTIEhUVFAy
MRswGQYDVQQLDBJidW5kbGVkIG9uIHB1cnBvc2UwHhcNMjUwOTA0MjA1NTA0WhcN
MjYwOTA0MjA1NTA0WjBJMRIwEAYDVQQDDAlsb2NhbGhvc3QxFjAUBgNVBAoMDUVO
QUJMRVMgSFRUUDIxGzAZBgNVBAsMEmJ1bmRsZWQgb24gcHVycG9zZTCCASIwDQYJ
KoZIhvcNAQEBBQADggEPADCCAQoCggEBAJVpU6IyIMgwB6CJHkOeEAgYtzvyH6fM
lkZSbemTrD9RCWZ4Fati1/6vbbMyWsM2XNJQMhJo0JTEoLDddN1iV/xGJCO/3dgw
4+wLqqEeck4R1pHygCkb40TycmyygSWsidkEUH0xp51nCapIdPr/WL6O+Gbpl6DA
onerUmWIO39VG2SpV7x3iXZOSbIGMsOiNZBmGwBZcL8ZejBIDjwvNjnX/d2tejH5
/Mo4KVEXl5jsqaNbDIkhSs5BXtCMhoi1dqt75M8FyuNZd50AGFSa9Lj6pHTpwepD
k2x4h+czPcvscF7TQG31TK1VYFPUThDim+by0+LQKkpy/UGVWnbC4dsCAwEAAaM9
MDswGgYDVR0RBBMwEYIJbG9jYWxob3N0hwR/AAABMB0GA1UdDgQWBBSonSKmHCVt
yBoVH1xEb3vtCng80DANBgkqhkiG9w0BAQsFAAOCAQEAinBO/uYe8ExHeiskt2P/
Oxkd5sHSY9deLVuyX/TFnUEfktMfYKM2Juy+MfH4vfrcEhYkYJJcm25UGrtiT0Jh
bUooDkR53549Xzg/70HU/ls1eNIe0zYqmS12H5W4Q1LAWTVpePscB4dgOrps6xIk
Q4nlF7dst93E3swAe81rgCEd7VZEZy5VQcE9K+CIZXaAUJwUAsAtJbrP+5JMe9pt
q52Zq5ZVkBS+4xeaMrasN0iTgsS4Lxo2a0GFDIJ84V66oeX7a5SXfSNn7rMVIDai
KNZ2Cf2xNXUwq25Z6tjpQCqwYn3SE8b/Yi6fFZmy5D8kmY7dMh8ghVOc7rD+Vsk6
/Q==
-----END CERTIFICATE-----

Binary file not shown.

Binary file not shown.

View File

@ -0,0 +1,70 @@
#!/bin/sh
# Entrypoint for the LS worker container.
#
# Starts redsocks pointed at the configured SOCKS5 proxy, installs iptables NAT
# rules that redirect outgoing TCP traffic from this container into redsocks,
# marks the network as ready, and finally execs the lsworker binary as the
# unprivileged sub2api user.
#
# Environment:
#   LSWORKER_PROXY_HOST          proxy host name or address (required)
#   LSWORKER_PROXY_PORT          proxy port (default 1080)
#   LSWORKER_PROXY_USER/PASS     optional proxy credentials
#   LSWORKER_CONTROL_PORT        control API port excluded from redirection (default 18081)
#   LSWORKER_REDSOCKS_PORT       local redsocks listen port (default 12345)
#   LSWORKER_NETWORK_READY_FILE  marker file created once the rules are in place
set -eu

PROXY_HOST="${LSWORKER_PROXY_HOST:-}"
PROXY_PORT="${LSWORKER_PROXY_PORT:-1080}"
PROXY_USER="${LSWORKER_PROXY_USER:-}"
PROXY_PASS="${LSWORKER_PROXY_PASS:-}"
CONTROL_PORT="${LSWORKER_CONTROL_PORT:-18081}"
REDSOCKS_PORT="${LSWORKER_REDSOCKS_PORT:-12345}"
NETWORK_READY_FILE="${LSWORKER_NETWORK_READY_FILE:-/run/lsworker/network-ready}"
REDSOCKS_CONF=/tmp/redsocks.conf
REDSOCKS_LOG=/tmp/redsocks.log

mkdir -p "$(dirname "${NETWORK_READY_FILE}")"

if [ -z "${PROXY_HOST}" ]; then
  echo "LSWORKER_PROXY_HOST is required" >&2
  exit 1
fi

# redsocks is configured with an IP address, so resolve the proxy host once.
PROXY_IP="$(getent ahostsv4 "${PROXY_HOST}" | awk 'NR==1 {print $1}')"
if [ -z "${PROXY_IP}" ]; then
  echo "failed to resolve proxy host: ${PROXY_HOST}" >&2
  exit 1
fi

# The config may contain proxy credentials: create it owner-only before any
# content is written so the unprivileged worker process cannot read it.
: >"${REDSOCKS_CONF}"
chmod 600 "${REDSOCKS_CONF}"

cat >>"${REDSOCKS_CONF}" <<EOF
base {
  log_debug = off;
  log_info = on;
  daemon = off;
  redirector = iptables;
}
redsocks {
  local_ip = 0.0.0.0;
  local_port = ${REDSOCKS_PORT};
  ip = ${PROXY_IP};
  port = ${PROXY_PORT};
  type = socks5;
EOF
if [ -n "${PROXY_USER}" ]; then
  printf ' login = "%s";\n' "${PROXY_USER}" >>"${REDSOCKS_CONF}"
fi
if [ -n "${PROXY_PASS}" ]; then
  printf ' password = "%s";\n' "${PROXY_PASS}" >>"${REDSOCKS_CONF}"
fi
cat >>"${REDSOCKS_CONF}" <<EOF
}
EOF

redsocks -c "${REDSOCKS_CONF}" >"${REDSOCKS_LOG}" 2>&1 &
REDSOCKS_PID="$!"
# Only fires if this script exits before the final exec (i.e. on a setup
# failure); after exec the shell is replaced by the worker process.
trap 'kill "${REDSOCKS_PID}" >/dev/null 2>&1 || true' EXIT
sleep 1

# Do not redirect traffic or report readiness if redsocks died on startup
# (bad config, port in use, ...): every outgoing connection would fail.
if ! kill -0 "${REDSOCKS_PID}" 2>/dev/null; then
  echo "redsocks failed to start" >&2
  cat "${REDSOCKS_LOG}" >&2 || true
  exit 1
fi

# Dedicated NAT chain: loopback, Docker's embedded DNS address, the proxy
# itself and the control port are left alone; all other TCP goes to redsocks.
iptables -t nat -N REDSOCKS 2>/dev/null || true
iptables -t nat -F REDSOCKS
iptables -t nat -A REDSOCKS -d 127.0.0.0/8 -j RETURN
iptables -t nat -A REDSOCKS -d 127.0.0.11/32 -j RETURN
iptables -t nat -A REDSOCKS -d "${PROXY_IP}/32" -j RETURN
iptables -t nat -A REDSOCKS -p tcp --dport "${CONTROL_PORT}" -j RETURN
iptables -t nat -A REDSOCKS -p tcp -j REDIRECT --to-ports "${REDSOCKS_PORT}"

# Remove a previous jump (if any) before adding it so the rule is not duplicated.
iptables -t nat -D OUTPUT -p tcp -j REDSOCKS 2>/dev/null || true
iptables -t nat -A OUTPUT -p tcp -j REDSOCKS

touch "${NETWORK_READY_FILE}"

exec gosu sub2api /app/lsworker

View File

@ -0,0 +1,52 @@
# Image for the dynamically created LS worker containers.
# Stage 1 builds the lsworker binary; stage 2 is a Debian runtime that bundles
# the binary, the Language Server executable and the proxy-redirect tooling.
ARG GOLANG_IMAGE=golang:1.26.1-alpine
ARG DEBIAN_IMAGE=debian:bookworm-slim
# ---- Build stage: compile ./cmd/lsworker with cgo disabled ----
FROM ${GOLANG_IMAGE} AS builder
WORKDIR /app/backend
RUN apk add --no-cache git ca-certificates tzdata
# Copy the module files first so the dependency download layer is cached.
COPY backend/go.mod backend/go.sum ./
RUN go mod download
COPY backend/ ./
RUN CGO_ENABLED=0 GOOS=linux go build -trimpath -ldflags="-s -w" -o /app/lsworker ./cmd/lsworker
# ---- Runtime stage ----
FROM ${DEBIAN_IMAGE}
# gosu, iptables and redsocks are used by lsworker-entrypoint.sh.
RUN apt-get update && apt-get install -y --no-install-recommends \
ca-certificates \
curl \
gosu \
iproute2 \
iptables \
redsocks \
tzdata \
&& rm -rf /var/lib/apt/lists/*
# Unprivileged user (uid/gid 1000) that the entrypoint switches to.
RUN groupadd -g 1000 sub2api && \
useradd -u 1000 -g sub2api -m -s /bin/sh sub2api
WORKDIR /app
COPY --from=builder /app/lsworker /app/lsworker
# Stage every bundled Language Server binary; the RUN step below keeps only
# the one matching the target architecture.
COPY deploy/ls-bin/language_server_linux_* /tmp/ls-bin/
COPY deploy/ls-bin/cert.pem /app/ls/extensions/antigravity/dist/languageServer/
ARG TARGETARCH
# arm64 builds take language_server_linux_arm; any other (or unset) TARGETARCH
# takes language_server_linux_x64.
RUN mkdir -p /app/ls/extensions/antigravity/bin /run/lsworker && \
if [ "${TARGETARCH:-amd64}" = "arm64" ]; then \
cp /tmp/ls-bin/language_server_linux_arm /app/ls/extensions/antigravity/bin/language_server_linux_arm; \
else \
cp /tmp/ls-bin/language_server_linux_x64 /app/ls/extensions/antigravity/bin/language_server_linux_x64; \
fi && \
chmod +x /app/lsworker /app/ls/extensions/antigravity/bin/language_server_linux_* && \
chown -R sub2api:sub2api /app /run/lsworker && \
rm -rf /tmp/ls-bin
COPY deploy/lsworker-entrypoint.sh /app/lsworker-entrypoint.sh
RUN chmod +x /app/lsworker-entrypoint.sh
# Default worker control API port.
EXPOSE 18081
ENTRYPOINT ["/app/lsworker-entrypoint.sh"]

View File

@ -510,6 +510,7 @@ const handleEvent = (event: {
addLine(streamingContent.value, 'text-green-300')
streamingContent.value = ''
}
addLine(`Error: ${errorMessage.value}`, 'text-red-400')
break
}
}

View File

@ -510,6 +510,7 @@ const handleEvent = (event: {
addLine(streamingContent.value, 'text-green-300')
streamingContent.value = ''
}
addLine(`Error: ${errorMessage.value}`, 'text-red-400')
break
}
}

View File

@ -144,4 +144,28 @@ describe('AccountTestModal', () => {
expect(preview.exists()).toBe(true)
expect(preview.attributes('src')).toBe('data:image/png;base64,QUJD')
})
it('收到 error 事件时会把错误内容显示在终端输出里', async () => {
;(global.fetch as any).mockResolvedValueOnce(
createStreamResponse([
'data: {"type":"test_start","model":"claude-opus-4-6"}\n',
'data: {"type":"error","error":"API returned 429: You have exhausted your capacity on this model."}\n'
])
)
const wrapper = mountModal()
await wrapper.setProps({ show: true })
await flushPromises()
const buttons = wrapper.findAll('button')
const startButton = buttons.find((button) => button.text().includes('admin.accounts.startTest'))
expect(startButton).toBeTruthy()
await startButton!.trigger('click')
await flushPromises()
await flushPromises()
expect(wrapper.text()).toContain('API returned 429: You have exhausted your capacity on this model.')
expect(wrapper.text()).toContain('Error: API returned 429: You have exhausted your capacity on this model.')
})
})