chore: remove LS pool implementation

Removing all LS (Language Server Pool) related code:
- backend/cmd/lsworker/
- backend/internal/pkg/lspool/
- backend/internal/service/lspool_bootstrap_service.*
- deploy/ls-bin/
- deploy/lsworker.Dockerfile
- deploy/lsworker-entrypoint.sh

Keeping:
- Claude custom fingerprint (immutable)
- Antigravity OAuth and telemetry improvements
- TLS fingerprint SOCKS5 Docker DNS fix
- Gemini OAuth security improvements

Co-Authored-By: Claude Haiku 4.5 <noreply@anthropic.com>
This commit is contained in:
win 2026-04-08 23:43:05 +08:00
parent 3ba3a17652
commit a3f2d4577e
31 changed files with 186 additions and 7977 deletions

View File

@ -7,7 +7,7 @@
# =============================================================================
ARG NODE_IMAGE=node:24-alpine
ARG GOLANG_IMAGE=golang:1.26.1-alpine
ARG GOLANG_IMAGE=golang:1.25-alpine
ARG ALPINE_IMAGE=alpine:3.21
ARG DEBIAN_IMAGE=debian:bookworm-slim
ARG POSTGRES_IMAGE=postgres:18-alpine

View File

@ -1,4 +1,4 @@
FROM golang:1.25.7-alpine
FROM golang:1.25-alpine
WORKDIR /app

View File

@ -1,49 +0,0 @@
package main
import (
"context"
"errors"
"log/slog"
"net/http"
"os"
"os/signal"
"syscall"
"github.com/Wei-Shaw/sub2api/internal/pkg/lspool"
)
func main() {
server, err := lspool.NewWorkerServerFromEnv()
if err != nil {
slog.Error("failed to initialize lsworker", "err", err)
os.Exit(1)
}
defer server.Close()
httpServer := &http.Server{
Addr: envOrDefault("LSWORKER_LISTEN_ADDR", "0.0.0.0:18081"),
Handler: server.Handler(),
ReadHeaderTimeout: 10 * 1e9,
}
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer stop()
go func() {
<-ctx.Done()
_ = httpServer.Shutdown(context.Background())
}()
slog.Info("lsworker listening", "addr", httpServer.Addr)
if err := httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
slog.Error("lsworker exited with error", "err", err)
os.Exit(1)
}
}
func envOrDefault(key, fallback string) string {
if value := os.Getenv(key); value != "" {
return value
}
return fallback
}

View File

@ -53,9 +53,8 @@ const (
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.107.0
var defaultUserAgentVersion = "1.107.0"
// defaultClientSecret 必须通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置
var defaultClientSecret string
// defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 覆盖
var defaultClientSecret = "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"
func init() {
// 从环境变量读取版本号,未设置则使用默认值

View File

@ -1,6 +1,8 @@
// Package claude provides constants and helpers for Claude API integration.
package claude
import "strings"
// Claude Code 客户端相关常量
// DefaultCLIVersion 是当前模拟的 Claude CLI 版本
@ -30,32 +32,64 @@ const (
// 这些 token 是客户端特有的,不应透传给上游 API。
var DroppedBetas = []string{}
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming + "," + BetaContext1M + "," + BetaRedactThinking + "," + BetaContextManagement + "," + BetaPromptCachingScope + "," + BetaEffort
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta headerOAuth 账号,不含 context-1m
// 使用 GetOAuthBetaHeader(modelID) 获取含 context-1m 的 model-aware 版本。
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming + "," + BetaRedactThinking + "," + BetaContextManagement + "," + BetaPromptCachingScope + "," + BetaEffort
// MessageBetaHeaderNoTools /v1/messages 在无工具时的 beta header
// MessageBetaHeaderNoTools /v1/messages 在无工具时的 beta headerOAuth不含 context-1m
//
// NOTE: Claude Code OAuth credentials are scoped to Claude Code. When we "mimic"
// Claude Code for non-Claude-Code clients, we must include the claude-code beta
// even if the request doesn't use tools, otherwise upstream may reject the
// request as a non-Claude-Code API request.
const MessageBetaHeaderNoTools = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaContext1M + "," + BetaRedactThinking + "," + BetaContextManagement + "," + BetaPromptCachingScope + "," + BetaEffort
const MessageBetaHeaderNoTools = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaRedactThinking + "," + BetaContextManagement + "," + BetaPromptCachingScope + "," + BetaEffort
// MessageBetaHeaderWithTools /v1/messages 在有工具时的 beta header
const MessageBetaHeaderWithTools = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaContext1M + "," + BetaRedactThinking + "," + BetaContextManagement + "," + BetaPromptCachingScope + "," + BetaEffort
// MessageBetaHeaderWithTools /v1/messages 在有工具时的 beta headerOAuth不含 context-1m
const MessageBetaHeaderWithTools = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaRedactThinking + "," + BetaContextManagement + "," + BetaPromptCachingScope + "," + BetaEffort
// CountTokensBetaHeader count_tokens 请求使用的 anthropic-beta header
const CountTokensBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaTokenCounting + "," + BetaContextManagement
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header不需要 claude-code beta
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta headerOAuth不含 claude-code / context-1m
const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking + "," + BetaEffort
// APIKeyBetaHeader API-key 账号建议使用的 anthropic-beta header不包含 oauth
const APIKeyBetaHeader = BetaClaudeCode + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming + "," + BetaContext1M + "," + BetaEffort + "," + BetaPromptCachingScope
// APIKeyBetaHeader API-key 账号使用的 anthropic-beta header不含 oauth / context-1m
// 使用 GetAPIKeyBetaHeader(modelID) 获取含 context-1m 的 model-aware 版本。
const APIKeyBetaHeader = BetaClaudeCode + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming + "," + BetaEffort + "," + BetaPromptCachingScope
// APIKeyHaikuBetaHeader Haiku 模型在 API-key 账号下使用的 anthropic-beta header含 oauth / claude-code
// APIKeyHaikuBetaHeader Haiku 模型在 API-key 账号下使用的 anthropic-beta header含 oauth / claude-code
const APIKeyHaikuBetaHeader = BetaInterleavedThinking + "," + BetaEffort
// ModelSupports1M 判断模型是否支持 1M context window。
// 与 claude-code-2.1.88 bundle 中 modelSupports1M 逻辑保持一致:
//
// claude-sonnet-4 系列 和 claude-opus-4-6 支持 1M context。
func ModelSupports1M(modelID string) bool {
lower := strings.ToLower(strings.TrimSpace(modelID))
return strings.Contains(lower, "claude-sonnet-4") || strings.Contains(lower, "opus-4-6")
}
// GetOAuthBetaHeader 返回 OAuth 账号的 beta header。
// 仅当模型支持 1M context 时才包含 context-1m-2025-08-07。
func GetOAuthBetaHeader(modelID string) string {
if ModelSupports1M(modelID) {
return BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming + "," + BetaContext1M + "," + BetaRedactThinking + "," + BetaContextManagement + "," + BetaPromptCachingScope + "," + BetaEffort
}
return DefaultBetaHeader
}
// GetAPIKeyBetaHeader 返回 API-key 账号的 beta header。
// 仅当模型支持 1M context 时才包含 context-1m-2025-08-07。
func GetAPIKeyBetaHeader(modelID string) string {
if strings.Contains(strings.ToLower(modelID), "haiku") {
return APIKeyHaikuBetaHeader
}
if ModelSupports1M(modelID) {
return BetaClaudeCode + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming + "," + BetaContext1M + "," + BetaEffort + "," + BetaPromptCachingScope
}
return APIKeyBetaHeader
}
// DefaultHeaders 是 Claude Code 客户端默认请求头。
var DefaultHeaders = map[string]string{
// Keep these in sync with recent Claude CLI traffic to reduce the chance
@ -70,7 +104,7 @@ var DefaultHeaders = map[string]string{
"X-Stainless-Retry-Count": "0",
"X-Stainless-Timeout": "600",
"X-App": "cli",
"Anthropic-Dangerous-Direct-Browser-Access": "true",
"anthropic-version": "2023-06-01",
}
// ApplyFingerprintOverrides 用配置覆盖默认指纹值(每个实例可设不同值)

View File

@ -1,13 +0,0 @@
package lspool
import "time"
// Backend is the control-plane abstraction used by the HTTP upstream wrapper.
// It may be backed by a local in-process Pool or by remote LS workers.
type Backend interface {
GetOrCreate(accountID, routingKey string, proxyURL ...string) (*Instance, error)
SetAccountToken(accountID, accessToken, refreshToken string, expiresAt time.Time)
SetAccountModelCredits(accountID string, useAICredits bool, availableCredits, minimumCreditAmountForUsage *int32)
Stats() map[string]any
Close()
}

View File

@ -1,94 +0,0 @@
// Package lspool provides LS-mode integration for the antigravity gateway.
//
// When LS mode is enabled (via ANTIGRAVITY_LS_MODE=true), requests to
// streamGenerateContent are routed through a real Language Server instance
// instead of directly to cloudcode-pa. This provides:
//
// - Authentic TLS fingerprint (Google's own Go binary)
// - Real session management and Heartbeat
// - Indistinguishable from a real IDE instance
//
// To enable: set environment variable ANTIGRAVITY_LS_MODE=true
// To configure: set ANTIGRAVITY_APP_ROOT to the AntiGravity.app path
package lspool
import (
"log/slog"
"os"
"strings"
"sync"
"github.com/Wei-Shaw/sub2api/internal/config"
)
var (
globalBackend Backend
globalPoolOnce sync.Once
lsModeEnabled bool
)
func init() {
lsModeEnabled = os.Getenv("ANTIGRAVITY_LS_MODE") == "true"
}
// IsLSModeEnabled returns whether LS mode is active
func IsLSModeEnabled() bool {
return lsModeEnabled
}
const (
LSStrategyDirect = "direct"
LSStrategyJSParity = "js-parity"
)
// CurrentLSStrategy returns the active LS routing strategy.
// Unknown values are treated as "direct" for safety.
func CurrentLSStrategy() string {
switch strings.ToLower(strings.TrimSpace(os.Getenv("ANTIGRAVITY_LS_STRATEGY"))) {
case "", LSStrategyDirect:
return LSStrategyDirect
case LSStrategyJSParity:
return LSStrategyJSParity
default:
return LSStrategyDirect
}
}
// GlobalPool returns the singleton LS pool instance
// Creates it on first call if LS mode is enabled
func GlobalPool(cfg *config.Config) Backend {
if !lsModeEnabled {
return nil
}
globalPoolOnce.Do(func() {
manager, err := NewWorkerManagerFromConfig(cfg)
if err != nil {
slog.Default().Error("failed to initialize LS worker manager", "err", err)
return
}
globalBackend = manager
})
return globalBackend
}
// Shutdown closes the global pool
func Shutdown() {
if globalBackend != nil {
globalBackend.Close()
}
}
// StatusInfo returns the current LS pool status for diagnostics
func StatusInfo() map[string]any {
info := map[string]any{
"ls_mode_enabled": lsModeEnabled,
"build": "enhanced",
"user_agent": "antigravity/1.107.0",
}
if lsModeEnabled && globalBackend != nil {
stats := globalBackend.Stats()
info["pool_total"] = stats["total"]
info["pool_active"] = stats["active"]
}
return info
}

View File

@ -1,864 +0,0 @@
package lspool
import (
"bytes"
"context"
"encoding/base64"
"encoding/binary"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func readConnectFrame(r io.Reader) ([]byte, error) {
header := make([]byte, 5)
if _, err := io.ReadFull(r, header); err != nil {
return nil, err
}
payloadLen := binary.BigEndian.Uint32(header[1:5])
payload := make([]byte, payloadLen)
if _, err := io.ReadFull(r, payload); err != nil {
return nil, err
}
return payload, nil
}
func decodeProtoBytesField(data []byte, targetField int) []byte {
i := 0
for i < len(data) {
tag, n := binary.Uvarint(data[i:])
if n <= 0 {
return nil
}
i += n
fieldNum := int(tag >> 3)
wireType := tag & 0x7
switch wireType {
case 0:
_, n = binary.Uvarint(data[i:])
if n <= 0 {
return nil
}
i += n
case 2:
length, n := binary.Uvarint(data[i:])
if n <= 0 {
return nil
}
i += n
if i+int(length) > len(data) {
return nil
}
if fieldNum == targetField {
return data[i : i+int(length)]
}
i += int(length)
case 1:
i += 8
case 5:
i += 4
default:
return nil
}
}
return nil
}
func decodeProtoBytesFields(data []byte, targetField int) [][]byte {
var values [][]byte
i := 0
for i < len(data) {
tag, n := binary.Uvarint(data[i:])
if n <= 0 {
return values
}
i += n
fieldNum := int(tag >> 3)
wireType := tag & 0x7
switch wireType {
case 0:
_, n = binary.Uvarint(data[i:])
if n <= 0 {
return values
}
i += n
case 2:
length, n := binary.Uvarint(data[i:])
if n <= 0 {
return values
}
i += n
if i+int(length) > len(data) {
return values
}
if fieldNum == targetField {
values = append(values, append([]byte(nil), data[i:i+int(length)]...))
}
i += int(length)
case 1:
i += 8
case 5:
i += 4
default:
return values
}
}
return values
}
func decodeTopicRows(topic []byte) map[string]string {
rows := make(map[string]string)
for _, entry := range decodeProtoBytesFields(topic, 1) {
key := decodeProtoString(entry, 1)
row := decodeProtoBytesField(entry, 2)
rows[key] = decodeProtoString(row, 1)
}
return rows
}
func requireBase64PrimitiveValue(t *testing.T, got string, want []byte) {
t.Helper()
decoded, err := base64.StdEncoding.DecodeString(got)
require.NoError(t, err)
require.Equal(t, want, decoded)
}
// TestMockExtensionServerTokenInjection verifies the token injection flow:
// Extension → MockExtensionServer → LS subscribes uss-oauth → gets OAuthTokenInfo
func TestMockExtensionServerTokenInjection(t *testing.T) {
csrf := "test-csrf-token"
srv, err := NewMockExtensionServer(csrf)
require.NoError(t, err)
defer srv.Close()
// 1. Set token for an account
srv.SetToken("account-1", &TokenInfo{
AccessToken: "ya29.test-access-token",
RefreshToken: "1//test-refresh-token",
ExpiresAt: time.Now().Add(1 * time.Hour),
})
// 2. Verify token is stored
srv.mu.RLock()
info, ok := srv.tokens["account-1"]
srv.mu.RUnlock()
require.True(t, ok)
require.Equal(t, "ya29.test-access-token", info.AccessToken)
require.Equal(t, "1//test-refresh-token", info.RefreshToken)
require.False(t, info.ExpiresAt.IsZero())
// 3. Simulate LS subscribing to uss-oauth (HTTP request to mock server)
req, _ := http.NewRequest("POST",
fmt.Sprintf("http://127.0.0.1:%d/exa.extension_server_pb.ExtensionServerService/SubscribeToUnifiedStateSyncTopic", srv.Port()),
bytes.NewReader(frameConnectMessage(encodeProtoString(1, "uss-oauth"))))
req.Header.Set("x-codeium-csrf-token", csrf)
req.Header.Set("Content-Type", "application/connect+proto")
// The stream handler will block, so run in background and cancel after we confirm connection
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
req = req.WithContext(ctx)
client := &http.Client{}
resp, err := client.Do(req)
if err == nil {
defer resp.Body.Close()
require.Equal(t, 200, resp.StatusCode)
require.Equal(t, "application/connect+proto", resp.Header.Get("Content-Type"))
// Read the first envelope frame (initial state)
header := make([]byte, 5)
n, readErr := resp.Body.Read(header)
if readErr == nil && n == 5 {
require.Equal(t, byte(0x00), header[0], "first byte should be 0x00 (data frame)")
t.Logf("Received initial state frame: flags=%d, payload_len=%d", header[0], header[1:5])
}
}
}
// TestMockExtensionServerCSRF verifies CSRF token validation
func TestMockExtensionServerCSRF(t *testing.T) {
csrf := "correct-csrf"
srv, err := NewMockExtensionServer(csrf)
require.NoError(t, err)
defer srv.Close()
base := fmt.Sprintf("http://127.0.0.1:%d/exa.extension_server_pb.ExtensionServerService/Heartbeat", srv.Port())
// Wrong CSRF → 403
req, _ := http.NewRequest("POST", base, nil)
req.Header.Set("x-codeium-csrf-token", "wrong-csrf")
req.Header.Set("Content-Type", "application/proto")
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, 403, resp.StatusCode)
// Correct CSRF → 200
req2, _ := http.NewRequest("POST", base, nil)
req2.Header.Set("x-codeium-csrf-token", csrf)
req2.Header.Set("Content-Type", "application/proto")
resp2, err := http.DefaultClient.Do(req2)
require.NoError(t, err)
defer resp2.Body.Close()
require.Equal(t, 200, resp2.StatusCode)
}
// TestMockExtensionServerGetSecretValue verifies the fallback token path
func TestMockExtensionServerGetSecretValue(t *testing.T) {
csrf := "test-csrf"
srv, err := NewMockExtensionServer(csrf)
require.NoError(t, err)
defer srv.Close()
srv.SetToken("acc", &TokenInfo{AccessToken: "ya29.secret-token"})
// GetSecretValue should return the token
req, _ := http.NewRequest("POST",
fmt.Sprintf("http://127.0.0.1:%d/exa.extension_server_pb.ExtensionServerService/GetSecretValue", srv.Port()),
nil)
req.Header.Set("x-codeium-csrf-token", csrf)
req.Header.Set("Content-Type", "application/proto")
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, 200, resp.StatusCode)
}
// TestOAuthTokenInfoProto verifies the proto encoding matches real IDE format
func TestOAuthTokenInfoProto(t *testing.T) {
expiry := time.Date(2026, 3, 29, 19, 0, 0, 0, time.UTC)
bin := buildOAuthTokenInfoBinary("ya29.test", "1//refresh", expiry)
// Verify fields are present by checking proto wire format
require.True(t, len(bin) > 0, "proto should not be empty")
// Field 1 (access_token): tag=0x0a, value="ya29.test"
require.Contains(t, string(bin), "ya29.test")
// Field 2 (token_type): tag=0x12, value="Bearer"
require.Contains(t, string(bin), "Bearer")
// Field 3 (refresh_token): tag=0x1a, value="1//refresh"
require.Contains(t, string(bin), "1//refresh")
// Without refresh_token
binNoRefresh := buildOAuthTokenInfoBinary("ya29.test", "", expiry)
require.NotContains(t, string(binNoRefresh), "1//refresh")
}
// TestOAuthTokenInfoWithRealExpiry verifies expiry uses real time, not hardcoded
func TestOAuthTokenInfoWithRealExpiry(t *testing.T) {
future := time.Now().Add(2 * time.Hour)
bin := buildOAuthTokenInfoBinary("token", "refresh", future)
// Zero expiry should default to ~1h
binZero := buildOAuthTokenInfoBinary("token", "refresh", time.Time{})
// They should be different lengths or content (different expiry timestamps)
// Both should be valid (non-empty)
require.True(t, len(bin) > 0)
require.True(t, len(binZero) > 0)
}
// TestUSSTopicWithOAuth verifies the full USS topic proto structure
func TestUSSTopicWithOAuth(t *testing.T) {
expiry := time.Now().Add(1 * time.Hour)
topic := buildUSSTopicWithOAuth("ya29.access", "1//refresh", expiry)
require.True(t, len(topic) > 0)
// The topic should contain the sentinel key
require.Contains(t, string(topic), "oauthTokenInfoSentinelKey")
}
func TestUSSTopicWithModelCredits(t *testing.T) {
available := int32(123)
minimum := int32(50)
topic := buildUSSTopicWithModelCredits(&ModelCreditsInfo{
UseAICredits: true,
AvailableCredits: &available,
MinimumCreditAmountForUsage: &minimum,
})
require.True(t, len(topic) > 0)
require.Contains(t, string(topic), useAICreditsSentinelKey)
require.Contains(t, string(topic), availableCreditsSentinelKey)
require.Contains(t, string(topic), minimumCreditAmountForUsageKey)
rows := decodeTopicRows(topic)
requireBase64PrimitiveValue(t, rows[useAICreditsSentinelKey], buildPrimitiveBoolBinary(true))
requireBase64PrimitiveValue(t, rows[availableCreditsSentinelKey], buildPrimitiveInt32Binary(available))
requireBase64PrimitiveValue(t, rows[minimumCreditAmountForUsageKey], buildPrimitiveInt32Binary(minimum))
}
func TestMockExtensionServerModelCreditsDynamicUpdate(t *testing.T) {
csrf := "test-csrf-token"
srv, err := NewMockExtensionServer(csrf)
require.NoError(t, err)
defer srv.Close()
srv.SetModelCredits("account-1", &ModelCreditsInfo{})
req, _ := http.NewRequest("POST",
fmt.Sprintf("http://127.0.0.1:%d/exa.extension_server_pb.ExtensionServerService/SubscribeToUnifiedStateSyncTopic", srv.Port()),
bytes.NewReader(frameConnectMessage(encodeProtoString(1, "uss-modelCredits"))))
req.Header.Set("x-codeium-csrf-token", csrf)
req.Header.Set("Content-Type", "application/connect+proto")
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
req = req.WithContext(ctx)
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, 200, resp.StatusCode)
// Drain the initial_state frame first.
_, err = readConnectFrame(resp.Body)
require.NoError(t, err)
available := int32(77)
minimum := int32(25)
srv.SetModelCredits("account-1", &ModelCreditsInfo{
UseAICredits: true,
AvailableCredits: &available,
MinimumCreditAmountForUsage: &minimum,
})
values := make(map[string]string, 3)
for len(values) < 3 {
frame, readErr := readConnectFrame(resp.Body)
require.NoError(t, readErr)
applied := decodeProtoBytesField(frame, 2)
require.NotEmpty(t, applied)
key := decodeProtoString(applied, 1)
row := decodeProtoBytesField(applied, 2)
values[key] = decodeProtoString(row, 1)
}
require.Contains(t, values, useAICreditsSentinelKey)
require.Contains(t, values, availableCreditsSentinelKey)
require.Contains(t, values, minimumCreditAmountForUsageKey)
requireBase64PrimitiveValue(t, values[useAICreditsSentinelKey], buildPrimitiveBoolBinary(true))
requireBase64PrimitiveValue(t, values[availableCreditsSentinelKey], buildPrimitiveInt32Binary(available))
requireBase64PrimitiveValue(t, values[minimumCreditAmountForUsageKey], buildPrimitiveInt32Binary(minimum))
}
// TestBuildInitialStateUpdate verifies the USS update wrapper
func TestBuildInitialStateUpdate(t *testing.T) {
topicData := buildEmptyTopic()
update := buildInitialStateUpdate(topicData)
// Should be a valid proto bytes field (field 1 = initial_state)
require.True(t, len(update) >= 0) // empty topic is valid
topicData2 := buildUSSTopicWithOAuth("token", "refresh", time.Now().Add(1*time.Hour))
update2 := buildInitialStateUpdate(topicData2)
require.True(t, len(update2) > len(update), "non-empty topic should produce larger update")
}
// TestPoolSetAccountTokenComplete verifies pool accepts full credential set
func TestPoolSetAccountTokenComplete(t *testing.T) {
csrf := "pool-csrf"
srv, err := NewMockExtensionServer(csrf)
require.NoError(t, err)
defer srv.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
pool := &Pool{
config: DefaultConfig(),
instances: make(map[string][]*Instance),
extServer: srv,
ctx: ctx,
cancel: cancel,
}
expiry := time.Now().Add(1 * time.Hour)
pool.SetAccountToken("acc-1", "ya29.full-token", "1//full-refresh", expiry)
srv.mu.RLock()
info := srv.tokens["acc-1"]
srv.mu.RUnlock()
require.NotNil(t, info)
require.Equal(t, "ya29.full-token", info.AccessToken)
require.Equal(t, "1//full-refresh", info.RefreshToken)
require.False(t, info.ExpiresAt.IsZero())
require.WithinDuration(t, expiry, info.ExpiresAt, time.Second)
}
func TestPoolSetAccountModelCreditsComplete(t *testing.T) {
csrf := "pool-csrf"
srv, err := NewMockExtensionServer(csrf)
require.NoError(t, err)
defer srv.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
pool := &Pool{
config: DefaultConfig(),
instances: make(map[string][]*Instance),
extServer: srv,
ctx: ctx,
cancel: cancel,
}
available := int32(77)
minimum := int32(25)
pool.SetAccountModelCredits("acc-1", true, &available, &minimum)
srv.mu.RLock()
info := srv.credits["acc-1"]
srv.mu.RUnlock()
require.NotNil(t, info)
require.True(t, info.UseAICredits)
require.NotNil(t, info.AvailableCredits)
require.Equal(t, available, *info.AvailableCredits)
require.NotNil(t, info.MinimumCreditAmountForUsage)
require.Equal(t, minimum, *info.MinimumCreditAmountForUsage)
}
// TestUpstreamAdapterExtractsCredentials verifies internal LS headers are extracted and stripped.
func TestUpstreamAdapterExtractsCredentials(t *testing.T) {
// Create a mock upstream that records what it receives
var receivedHeaders http.Header
var mu sync.Mutex
fallback := &recordingUpstreamWithCallback{}
fallback.onDo = func(req *http.Request) {
mu.Lock()
receivedHeaders = req.Header.Clone()
mu.Unlock()
}
csrf := "test-csrf"
srv, err := NewMockExtensionServer(csrf)
require.NoError(t, err)
defer srv.Close()
pool := &Pool{
config: DefaultConfig(),
instances: make(map[string][]*Instance),
extServer: srv,
}
upstream := NewLSPoolUpstream(pool, fallback)
// Non-streamGenerateContent request → should pass through to fallback
req, _ := http.NewRequest("POST", "https://example.com/v1beta/models/gemini:generateContent", nil)
req.Header.Set("Authorization", "Bearer ya29.test")
req.Header.Set("X-Antigravity-Refresh-Token", "1//secret-refresh")
req.Header.Set("X-Antigravity-Token-Expiry", "2026-03-29T19:00:00Z")
req.Header.Set(useAICreditsHeader, "true")
req.Header.Set(availableCreditsHeader, "42")
req.Header.Set(minimumCreditAmountHeader, "50")
resp, err := upstream.Do(req, "", 1, 1)
require.NoError(t, err)
require.NotNil(t, resp)
// Internal headers should never leak to the direct upstream.
mu.Lock()
require.Empty(t, receivedHeaders.Get("X-Antigravity-Refresh-Token"))
require.Empty(t, receivedHeaders.Get("X-Antigravity-Token-Expiry"))
require.Empty(t, receivedHeaders.Get(useAICreditsHeader))
require.Empty(t, receivedHeaders.Get(availableCreditsHeader))
require.Empty(t, receivedHeaders.Get(minimumCreditAmountHeader))
mu.Unlock()
srv.mu.RLock()
tokenInfo := srv.tokens["1"]
creditsInfo := srv.credits["1"]
srv.mu.RUnlock()
require.NotNil(t, tokenInfo)
require.Equal(t, "ya29.test", tokenInfo.AccessToken)
require.NotNil(t, creditsInfo)
require.True(t, creditsInfo.UseAICredits)
require.NotNil(t, creditsInfo.AvailableCredits)
require.Equal(t, int32(42), *creditsInfo.AvailableCredits)
require.NotNil(t, creditsInfo.MinimumCreditAmountForUsage)
require.Equal(t, int32(50), *creditsInfo.MinimumCreditAmountForUsage)
}
// TestExtractPromptAndModelMultiTurn verifies multi-turn prompt extraction
func TestExtractPromptAndModelMultiTurn(t *testing.T) {
body := `{
"model": "claude-sonnet-4-6",
"request": {
"systemInstruction": {"parts": [{"text": "You are helpful"}]},
"contents": [
{"role": "user", "parts": [{"text": "Hello"}]},
{"role": "model", "parts": [{"text": "Hi there!"}]},
{"role": "user", "parts": [{"text": "How are you?"}]}
]
}
}`
prompt, model := extractPromptAndModel([]byte(body))
require.Equal(t, "claude-sonnet-4-6", model)
require.Contains(t, prompt, "You are helpful")
require.Contains(t, prompt, "Hello")
require.Contains(t, prompt, "Hi there!")
require.Contains(t, prompt, "How are you?")
}
// TestExtractUsageFromTrajectory verifies token usage extraction
func TestExtractUsageFromTrajectory(t *testing.T) {
resp := `{
"trajectory": {
"steps": [{
"type": "CORTEX_STEP_TYPE_PLANNER_RESPONSE",
"status": "CORTEX_STEP_STATUS_DONE",
"plannerResponse": {"response": "OK"},
"metadata": {
"modelUsage": {
"inputTokens": "150",
"outputTokens": "5"
}
}
}]
}
}`
usage := extractUsageFromTrajectory([]byte(resp))
require.NotNil(t, usage)
require.Equal(t, 150, usage["promptTokenCount"])
require.Equal(t, 5, usage["candidatesTokenCount"])
require.Equal(t, 155, usage["totalTokenCount"])
}
// TestSSEChunkFormat verifies the Gemini SSE output format
func TestSSEChunkFormat(t *testing.T) {
chunk := buildGeminiSSEChunk("Hello world")
require.True(t, len(chunk) > 0)
require.Contains(t, chunk, "data: ")
require.Contains(t, chunk, `"text":"Hello world"`)
require.Contains(t, chunk, `"role":"model"`)
require.True(t, chunk[len(chunk)-2:] == "\n\n")
// Verify it's valid JSON after stripping "data: " prefix
jsonStr := chunk[len("data: ") : len(chunk)-2]
var parsed map[string]any
err := json.Unmarshal([]byte(jsonStr), &parsed)
require.NoError(t, err)
response := parsed["response"].(map[string]any)
candidates := response["candidates"].([]any)
require.Len(t, candidates, 1)
}
// TestSSEFinalChunkFormat verifies the final SSE chunk with usage
func TestSSEFinalChunkFormat(t *testing.T) {
usage := map[string]any{
"promptTokenCount": 100,
"candidatesTokenCount": 50,
"totalTokenCount": 150,
}
chunk := buildGeminiSSEFinalChunk(usage)
require.Contains(t, chunk, "data: ")
require.Contains(t, chunk, `"finishReason":"STOP"`)
require.Contains(t, chunk, `"usageMetadata"`)
}
func TestStreamCascadeResponsePollsImmediately(t *testing.T) {
var getCalls atomic.Int32
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "test-csrf", r.Header.Get("x-codeium-csrf-token"))
if strings.HasSuffix(r.URL.Path, "/GetCascadeTrajectory") {
getCalls.Add(1)
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"status":"CASCADE_RUN_STATUS_IDLE","trajectory":{"steps":[{"type":"CORTEX_STEP_TYPE_PLANNER_RESPONSE","status":"CORTEX_STEP_STATUS_DONE","plannerResponse":{"response":"hello from ls"}}]}}`))
return
}
http.NotFound(w, r)
}))
defer server.Close()
inst := &Instance{
AccountID: "42",
CSRF: "test-csrf",
Address: strings.TrimPrefix(server.URL, "https://"),
client: server.Client(),
healthy: true,
lastUsed: time.Now(),
}
upstream := NewLSPoolUpstream(&Pool{}, &recordingUpstream{})
pr, pw := io.Pipe()
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
done := make(chan struct{})
go func() {
upstream.streamCascadeResponse(ctx, inst, "cid-1", pw, nil, nil)
close(done)
}()
body, err := io.ReadAll(pr)
require.NoError(t, err)
<-done
require.GreaterOrEqual(t, getCalls.Load(), int32(1))
require.Contains(t, string(body), "hello from ls")
}
// TestRequestHasToolsEdgeCases verifies tool detection edge cases
func TestRequestHasToolsEdgeCases(t *testing.T) {
// null tools
require.False(t, requestHasTools([]byte(`{"contents":[],"tools":null}`)))
// tools with empty function declarations
require.True(t, requestHasTools([]byte(`{"contents":[],"tools":[{"functionDeclarations":[]}]}`)))
// deeply nested wrapped format
require.True(t, requestHasTools([]byte(`{"model":"m","project":"p","request":{"contents":[],"tools":[{"codeExecution":{}}]}}`)))
}
func TestJSParityRouteReusesCascadeSession(t *testing.T) {
t.Setenv("ANTIGRAVITY_LS_STRATEGY", LSStrategyJSParity)
var startCalls atomic.Int32
var sendCalls atomic.Int32
var getCalls atomic.Int32
var sendBodiesMu sync.Mutex
var sendBodies []map[string]any
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "test-csrf", r.Header.Get("x-codeium-csrf-token"))
switch {
case strings.HasSuffix(r.URL.Path, "/StartCascade"):
startCalls.Add(1)
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"cascadeId":"cid-1"}`))
case strings.HasSuffix(r.URL.Path, "/SendUserCascadeMessage"):
sendCalls.Add(1)
var payload map[string]any
err := json.NewDecoder(r.Body).Decode(&payload)
require.NoError(t, err)
sendBodiesMu.Lock()
sendBodies = append(sendBodies, payload)
sendBodiesMu.Unlock()
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"queued":false}`))
case strings.HasSuffix(r.URL.Path, "/GetCascadeTrajectory"):
call := getCalls.Add(1)
w.Header().Set("Content-Type", "application/json")
text := "hello from ls"
if call > 1 {
text = "follow up from ls"
}
_, _ = w.Write([]byte(fmt.Sprintf(`{"status":"CASCADE_RUN_STATUS_IDLE","trajectory":{"steps":[{"type":"CORTEX_STEP_TYPE_PLANNER_RESPONSE","status":"CORTEX_STEP_STATUS_DONE","plannerResponse":{"response":"%s"}}]}}`, text)))
default:
http.NotFound(w, r)
}
}))
defer server.Close()
inst := &Instance{
AccountID: "42",
CSRF: "test-csrf",
Address: strings.TrimPrefix(server.URL, "https://"),
client: server.Client(),
healthy: true,
lastUsed: time.Now(),
}
inst.SetModelMappingReady(true)
pool := &Pool{
config: Config{ReplicasPerAccount: 1},
instances: map[string][]*Instance{"42": []*Instance{inst}},
}
upstream := NewLSPoolUpstream(pool, &recordingUpstream{})
req1Body := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"session-a","contents":[{"role":"user","parts":[{"text":"hello"}]}]}}`)
req1, err := http.NewRequest(http.MethodPost, "https://example.com/v1internal:streamGenerateContent?alt=sse", bytes.NewReader(req1Body))
require.NoError(t, err)
req1.Header.Set("Authorization", "Bearer downstream-a")
resp1, err := upstream.Do(req1, "", 42, 1)
require.NoError(t, err)
body1, err := io.ReadAll(resp1.Body)
require.NoError(t, err)
require.Contains(t, string(body1), `"text":"hello from ls"`)
req2Body := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"session-a","contents":[{"role":"user","parts":[{"text":"hello"}]},{"role":"model","parts":[{"text":"hello from ls"}]},{"role":"user","parts":[{"text":"follow up"}]}]}}`)
req2, err := http.NewRequest(http.MethodPost, "https://example.com/v1internal:streamGenerateContent?alt=sse", bytes.NewReader(req2Body))
require.NoError(t, err)
req2.Header.Set("Authorization", "Bearer downstream-a")
resp2, err := upstream.Do(req2, "", 42, 1)
require.NoError(t, err)
body2, err := io.ReadAll(resp2.Body)
require.NoError(t, err)
require.Contains(t, string(body2), `"text":"follow up from ls"`)
require.Equal(t, int32(1), startCalls.Load(), "cascade should be reused for append-only transcript")
require.Equal(t, int32(2), sendCalls.Load())
sendBodiesMu.Lock()
require.Len(t, sendBodies, 2)
firstSend := sendBodies[0]
sendBodiesMu.Unlock()
require.Equal(t, "cid-1", firstSend["cascadeId"])
require.Equal(t, false, firstSend["blocking"])
metadata, ok := firstSend["metadata"].(map[string]any)
require.True(t, ok)
require.Equal(t, "antigravity", metadata["ideName"])
require.Equal(t, "1.107.0", metadata["ideVersion"])
require.NotContains(t, firstSend, "clientType")
require.NotContains(t, firstSend, "messageOrigin")
cascadeConfig, ok := firstSend["cascadeConfig"].(map[string]any)
require.True(t, ok)
plannerConfig, ok := cascadeConfig["plannerConfig"].(map[string]any)
require.True(t, ok)
requestedModel, ok := plannerConfig["requestedModel"].(map[string]any)
require.True(t, ok)
require.NotEmpty(t, requestedModel["model"])
require.Len(t, plannerConfig, 1)
require.Len(t, cascadeConfig, 1)
}
func TestJSParityRouteFallsBackOnSystemInstructionDrift(t *testing.T) {
t.Setenv("ANTIGRAVITY_LS_STRATEGY", LSStrategyJSParity)
var startCalls atomic.Int32
var sendCalls atomic.Int32
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "test-csrf", r.Header.Get("x-codeium-csrf-token"))
switch {
case strings.HasSuffix(r.URL.Path, "/StartCascade"):
startCalls.Add(1)
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"cascadeId":"cid-1"}`))
case strings.HasSuffix(r.URL.Path, "/SendUserCascadeMessage"):
sendCalls.Add(1)
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"queued":false}`))
case strings.HasSuffix(r.URL.Path, "/GetCascadeTrajectory"):
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"status":"CASCADE_RUN_STATUS_IDLE","trajectory":{"steps":[{"type":"CORTEX_STEP_TYPE_PLANNER_RESPONSE","status":"CORTEX_STEP_STATUS_DONE","plannerResponse":{"response":"hello from ls"}}]}}`))
default:
http.NotFound(w, r)
}
}))
defer server.Close()
inst := &Instance{
AccountID: "42",
CSRF: "test-csrf",
Address: strings.TrimPrefix(server.URL, "https://"),
client: server.Client(),
healthy: true,
lastUsed: time.Now(),
}
inst.SetModelMappingReady(true)
fallback := &recordingUpstream{}
pool := &Pool{
config: Config{ReplicasPerAccount: 1},
instances: map[string][]*Instance{"42": []*Instance{inst}},
}
upstream := NewLSPoolUpstream(pool, fallback)
req1Body := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"session-a","systemInstruction":{"parts":[{"text":"You are helpful"}]},"contents":[{"role":"user","parts":[{"text":"hello"}]}]}}`)
req1, err := http.NewRequest(http.MethodPost, "https://example.com/v1internal:streamGenerateContent?alt=sse", bytes.NewReader(req1Body))
require.NoError(t, err)
req1.Header.Set("Authorization", "Bearer downstream-a")
resp1, err := upstream.Do(req1, "", 42, 1)
require.NoError(t, err)
body1, err := io.ReadAll(resp1.Body)
require.NoError(t, err)
require.Contains(t, string(body1), `"text":"hello from ls"`)
req2Body := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"session-a","systemInstruction":{"parts":[{"text":"You are different"}]},"contents":[{"role":"user","parts":[{"text":"hello"}]},{"role":"model","parts":[{"text":"hello from ls"}]},{"role":"user","parts":[{"text":"follow up"}]}]}}`)
req2, err := http.NewRequest(http.MethodPost, "https://example.com/v1internal:streamGenerateContent?alt=sse", bytes.NewReader(req2Body))
require.NoError(t, err)
req2.Header.Set("Authorization", "Bearer downstream-a")
resp2, err := upstream.Do(req2, "", 42, 1)
require.NoError(t, err)
body2, err := io.ReadAll(resp2.Body)
require.NoError(t, err)
require.Equal(t, "ok", string(body2))
require.Equal(t, 1, fallback.doCalls)
require.Equal(t, int32(1), startCalls.Load())
require.Equal(t, int32(1), sendCalls.Load())
}
func TestJSParityRouteErrorsWhenModelMappingPending(t *testing.T) {
t.Setenv("ANTIGRAVITY_LS_STRATEGY", LSStrategyJSParity)
var startCalls atomic.Int32
var sendCalls atomic.Int32
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case strings.HasSuffix(r.URL.Path, "/StartCascade"):
startCalls.Add(1)
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"cascadeId":"cid-1"}`))
case strings.HasSuffix(r.URL.Path, "/SendUserCascadeMessage"):
sendCalls.Add(1)
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"queued":false}`))
default:
http.NotFound(w, r)
}
}))
defer server.Close()
inst := &Instance{
AccountID: "42",
CSRF: "test-csrf",
Address: strings.TrimPrefix(server.URL, "https://"),
client: server.Client(),
healthy: true,
lastUsed: time.Now(),
}
fallback := &recordingUpstream{}
pool := &Pool{
config: Config{ReplicasPerAccount: 1},
instances: map[string][]*Instance{"42": []*Instance{inst}},
}
upstream := NewLSPoolUpstream(pool, fallback)
reqBody := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"session-a","contents":[{"role":"user","parts":[{"text":"hello"}]}]}}`)
req, err := http.NewRequest(http.MethodPost, "https://example.com/v1internal:streamGenerateContent?alt=sse", bytes.NewReader(reqBody))
require.NoError(t, err)
req.Header.Set("Authorization", "Bearer downstream-a")
resp, err := upstream.Do(req, "", 42, 1)
require.Nil(t, resp)
require.ErrorIs(t, err, errLSModelMapPending)
require.Equal(t, int32(0), startCalls.Load())
require.Equal(t, int32(0), sendCalls.Load())
require.Equal(t, 0, fallback.doCalls)
}
// recordingUpstreamWithCallback extends the base recordingUpstream with a callback
type recordingUpstreamWithCallback struct {
recordingUpstream
onDo func(req *http.Request)
}
func (r *recordingUpstreamWithCallback) Do(req *http.Request, proxyURL string, accountID int64, c int) (*http.Response, error) {
if r.onDo != nil {
r.onDo(req)
}
return r.recordingUpstream.Do(req, proxyURL, accountID, c)
}

View File

@ -1,920 +0,0 @@
// Package lspool provides a mock Extension Server that the LS binary connects
// to at startup. The real IDE's extension.js runs a ConnectRPC HTTP/1.1 server
// using connectNodeAdapter. We replicate that protocol here.
//
// Protocol details (from extension.js source):
// - Transport: HTTP/1.1 on 127.0.0.1 (no TLS)
// - Auth: x-codeium-csrf-token header on every request
// - Unary request Content-Type: application/proto (binary protobuf, no envelope)
// OR application/connect+proto (with 5-byte envelope)
// - Unary response Content-Type: application/proto (raw binary protobuf, no envelope)
// - Stream request Content-Type: application/connect+proto (with 5-byte envelope)
// - Stream response Content-Type: application/connect+proto (envelope-framed messages)
//
// The LS sends requests with content-type "application/connect+proto" for BOTH
// unary and streaming RPCs. ConnectRPC's content-type regex:
//
// /^application\/(connect\+)?(?:(json)(?:; ?charset=utf-?8)?|(proto))$/i
//
// If "connect+" prefix is present → stream mode; otherwise → unary mode.
// However the LS Go client uses the Connect protocol client which always sends
// "application/proto" for unary and "application/connect+proto" for streaming.
//
// We detect the RPC kind from the URL path and respond accordingly.
package lspool
import (
"encoding/base64"
"encoding/binary"
"fmt"
"io"
"log/slog"
"net"
"net/http"
"strings"
"sync"
"time"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"
)
// ============================================================
// Proto helpers — hand-encode minimal proto messages so we don't
// need to import the full generated proto package.
// ============================================================
// encodeProtoString writes a proto string field (wire type 2) to a byte slice.
func encodeProtoString(fieldNum int, val string) []byte {
tag := encodeVarint(uint64(fieldNum<<3 | 2))
length := encodeVarint(uint64(len(val)))
out := make([]byte, 0, len(tag)+len(length)+len(val))
out = append(out, tag...)
out = append(out, length...)
out = append(out, []byte(val)...)
return out
}
// encodeProtoBytes writes a proto bytes/message field (wire type 2).
func encodeProtoBytes(fieldNum int, val []byte) []byte {
tag := encodeVarint(uint64(fieldNum<<3 | 2))
length := encodeVarint(uint64(len(val)))
out := make([]byte, 0, len(tag)+len(length)+len(val))
out = append(out, tag...)
out = append(out, length...)
out = append(out, val...)
return out
}
// encodeProtoVarint writes a proto varint field (wire type 0).
func encodeProtoVarint(fieldNum int, val uint64) []byte {
tag := encodeVarint(uint64(fieldNum<<3 | 0))
v := encodeVarint(val)
out := make([]byte, 0, len(tag)+len(v))
out = append(out, tag...)
out = append(out, v...)
return out
}
// encodeProtoBool writes a proto bool field.
func encodeProtoBool(fieldNum int, val bool) []byte {
v := uint64(0)
if val {
v = 1
}
return encodeProtoVarint(fieldNum, v)
}
func encodeVarint(v uint64) []byte {
buf := make([]byte, binary.MaxVarintLen64)
n := binary.PutUvarint(buf, v)
return buf[:n]
}
// decodeProtoString extracts a string field from raw proto bytes.
func decodeProtoString(data []byte, targetField int) string {
i := 0
for i < len(data) {
if i >= len(data) {
break
}
tag, n := binary.Uvarint(data[i:])
if n <= 0 {
break
}
i += n
fieldNum := int(tag >> 3)
wireType := tag & 0x7
switch wireType {
case 0: // varint
_, n = binary.Uvarint(data[i:])
if n <= 0 {
return ""
}
i += n
case 2: // length-delimited
length, n := binary.Uvarint(data[i:])
if n <= 0 {
return ""
}
i += n
if fieldNum == targetField {
end := i + int(length)
if end > len(data) {
return ""
}
return string(data[i:end])
}
i += int(length)
case 1: // 64-bit
i += 8
case 5: // 32-bit
i += 4
default:
return ""
}
}
return ""
}
// ============================================================
// ConnectRPC envelope helpers
// ============================================================
// connectEnvelope wraps a proto payload in a ConnectRPC streaming envelope:
// 1 byte flags + 4 byte big-endian length + payload
func connectEnvelope(flags byte, payload []byte) []byte {
frame := make([]byte, 5+len(payload))
frame[0] = flags
binary.BigEndian.PutUint32(frame[1:5], uint32(len(payload)))
copy(frame[5:], payload)
return frame
}
// connectEndOfStream returns the end-of-stream trailer frame for ConnectRPC.
// flags=0x02 signals end of stream. The payload is a JSON object with empty metadata.
func connectEndOfStream() []byte {
trailer := []byte("{}")
return connectEnvelope(0x02, trailer)
}
// unwrapConnectEnvelope strips the 5-byte envelope header from a ConnectRPC message.
// Returns the raw proto payload. If the input is shorter than 5 bytes, returns as-is.
func unwrapConnectEnvelope(body []byte) []byte {
if len(body) < 5 {
return body
}
// Check if it looks like an envelope: first byte should be 0x00 or 0x01
if body[0] > 0x02 {
return body // Not envelope-framed, return raw
}
plen := binary.BigEndian.Uint32(body[1:5])
if int(plen)+5 > len(body) {
return body // Length mismatch, return raw
}
return body[5 : 5+plen]
}
// ============================================================
// OAuthTokenInfo proto builder
// ============================================================
// buildOAuthTokenInfoBinary creates binary-encoded OAuthTokenInfo proto.
//
// message OAuthTokenInfo {
// string access_token = 1;
// string token_type = 2;
// string refresh_token = 3;
// google.protobuf.Timestamp expiry = 4;
// bool is_gcp_tos = 6;
// }
func buildOAuthTokenInfoBinary(accessToken, refreshToken string, expiresAt time.Time) []byte {
var buf []byte
buf = append(buf, encodeProtoString(1, accessToken)...)
buf = append(buf, encodeProtoString(2, "Bearer")...)
if refreshToken != "" {
buf = append(buf, encodeProtoString(3, refreshToken)...)
}
// Use real expiry if provided, otherwise default to 1 hour from now
expiry := expiresAt
if expiry.IsZero() {
expiry = time.Now().Add(1 * time.Hour)
}
ts := &timestamppb.Timestamp{
Seconds: expiry.Unix(),
}
tsBytes, _ := proto.Marshal(ts)
buf = append(buf, encodeProtoBytes(4, tsBytes)...)
buf = append(buf, encodeProtoBool(6, true)...)
return buf
}
// buildUSSTopicWithOAuth creates a USS Topic proto with the OAuth token.
//
// message Topic { map<string, Row> data = 1; }
// message Row { string value = 1; int64 e_tag = 2; }
//
// The key in the map is "oauthTokenInfoSentinelKey" and the Row.value is
// base64(toBinary(OAuthTokenInfo)).
func buildUSSTopicWithOAuth(accessToken, refreshToken string, expiresAt time.Time) []byte {
tokenBin := buildOAuthTokenInfoBinary(accessToken, refreshToken, expiresAt)
tokenB64 := base64.StdEncoding.EncodeToString(tokenBin)
// Row: value=tokenB64 (field 1), e_tag=1 (field 2)
var row []byte
row = append(row, encodeProtoString(1, tokenB64)...)
row = append(row, encodeProtoVarint(2, 1)...)
// Map entry: key="oauthTokenInfoSentinelKey" (field 1), value=row (field 2)
var entry []byte
entry = append(entry, encodeProtoString(1, "oauthTokenInfoSentinelKey")...)
entry = append(entry, encodeProtoBytes(2, row)...)
// Topic: data map entries use field 1
var topic []byte
topic = append(topic, encodeProtoBytes(1, entry)...)
return topic
}
func buildPrimitiveBoolBinary(val bool) []byte {
// Primitive.bool_value is field 13 in the proto definition
return encodeProtoBool(13, val)
}
func buildPrimitiveInt32Binary(val int32) []byte {
// Primitive.int32_value is field 3 in the proto definition
return encodeProtoVarint(3, uint64(uint32(val)))
}
func encodeUSSBinaryValue(value []byte) string {
return base64.StdEncoding.EncodeToString(value)
}
func encodeUSSPrimitiveBoolValue(val bool) string {
return encodeUSSBinaryValue(buildPrimitiveBoolBinary(val))
}
func encodeUSSPrimitiveInt32Value(val int32) string {
return encodeUSSBinaryValue(buildPrimitiveInt32Binary(val))
}
func buildUSSTopicRow(key string, value string) []byte {
row := buildUSSRowBinary(value)
var entry []byte
entry = append(entry, encodeProtoString(1, key)...)
entry = append(entry, encodeProtoBytes(2, row)...)
return entry
}
func buildUSSRowBinary(value string) []byte {
var row []byte
row = append(row, encodeProtoString(1, value)...)
row = append(row, encodeProtoVarint(2, 1)...)
return row
}
func buildUSSTopicWithModelCredits(info *ModelCreditsInfo) []byte {
if info == nil {
info = &ModelCreditsInfo{}
}
minimum := defaultMinimumCreditAmountForUsage
if info.MinimumCreditAmountForUsage != nil {
minimum = *info.MinimumCreditAmountForUsage
}
entries := make([][]byte, 0, 3)
entries = append(entries, buildUSSTopicRow(
useAICreditsSentinelKey,
encodeUSSPrimitiveBoolValue(info.UseAICredits),
))
// JS protocol: useAICreditsSentinelKey carries the toggle state.
// availableCreditsSentinelKey is only present when credits are enabled.
if info.UseAICredits {
credits := int32(9999)
if info.AvailableCredits != nil {
credits = *info.AvailableCredits
}
entries = append(entries, buildUSSTopicRow(availableCreditsSentinelKey, encodeUSSPrimitiveInt32Value(credits)))
}
entries = append(entries, buildUSSTopicRow(minimumCreditAmountForUsageKey, encodeUSSPrimitiveInt32Value(minimum)))
var topic []byte
for _, entry := range entries {
topic = append(topic, encodeProtoBytes(1, entry)...)
}
return topic
}
// buildEmptyTopic returns an empty USS Topic proto (for non-oauth topics).
func buildEmptyTopic() []byte {
return []byte{} // Empty message = no map entries
}
// ============================================================
// UnifiedStateSyncUpdate builder
// ============================================================
// buildInitialStateUpdate creates a UnifiedStateSyncUpdate with initial_state set.
//
// message UnifiedStateSyncUpdate {
// oneof update_type {
// Topic initial_state = 1;
// AppliedUpdate applied_update = 2;
// }
// }
func buildInitialStateUpdate(topicData []byte) []byte {
return encodeProtoBytes(1, topicData)
}
func buildAppliedUpdate(key string, row []byte) []byte {
var applied []byte
applied = append(applied, encodeProtoString(1, key)...)
if len(row) > 0 {
applied = append(applied, encodeProtoBytes(2, row)...)
}
return encodeProtoBytes(2, applied)
}
// ============================================================
// MockExtensionServer
// ============================================================
// MockExtensionServer provides a ConnectRPC-compatible HTTP server that the
// Language Server binary connects to. It implements just enough of the
// ExtensionServerService to keep the LS operational.
type MockExtensionServer struct {
listener net.Listener
server *http.Server
port int
csrf string
mu sync.RWMutex
tokens map[string]*TokenInfo // account_id -> token info
credits map[string]*ModelCreditsInfo // account_id -> model credits info
subscribers map[string]map[int]*stateSubscriber
nextSubID int
lastAccountID string
logger *slog.Logger
// Trajectory callback — when LS pushes trajectory updates, we forward them
onTrajectoryUpdate func(topic, key string, data []byte)
}
// TokenInfo holds OAuth token details for an account.
type TokenInfo struct {
AccessToken string
RefreshToken string
ExpiresAt time.Time // zero value means unknown; defaults to now+1h
}
// ModelCreditsInfo mirrors the JS uss-modelCredits topic state.
type ModelCreditsInfo struct {
UseAICredits bool
AvailableCredits *int32
MinimumCreditAmountForUsage *int32
}
type stateSubscriber struct {
id int
accountID string
topic string
updates chan []byte
}
const (
useAICreditsSentinelKey = "useAICreditsSentinelKey"
availableCreditsSentinelKey = "availableCreditsSentinelKey"
minimumCreditAmountForUsageKey = "minimumCreditAmountForUsageKey"
defaultMinimumCreditAmountForUsage = int32(50)
)
// NewMockExtensionServer creates a mock extension server with proper ConnectRPC handling.
func NewMockExtensionServer(csrf string) (*MockExtensionServer, error) {
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
return nil, fmt.Errorf("listen: %w", err)
}
m := &MockExtensionServer{
listener: listener,
port: listener.Addr().(*net.TCPAddr).Port,
csrf: csrf,
tokens: make(map[string]*TokenInfo),
credits: make(map[string]*ModelCreditsInfo),
subscribers: make(map[string]map[int]*stateSubscriber),
logger: slog.Default().With("component", "mock-ext-server"),
}
mux := http.NewServeMux()
extService := "/exa.extension_server_pb.ExtensionServerService/"
// Register all RPCs the LS calls on the Extension Server.
// Unary RPCs — return application/proto
mux.HandleFunc(extService+"LanguageServerStarted", m.handleUnary(m.onLanguageServerStarted))
mux.HandleFunc(extService+"Heartbeat", m.handleUnary(m.onHeartbeat))
mux.HandleFunc(extService+"GetSecretValue", m.handleUnary(m.onGetSecretValue))
mux.HandleFunc(extService+"StoreSecretValue", m.handleUnary(m.onStoreSecretValue))
mux.HandleFunc(extService+"IsAgentManagerEnabled", m.handleUnary(m.onIsAgentManagerEnabled))
mux.HandleFunc(extService+"PushUnifiedStateSyncUpdate", m.handleUnary(m.onPushUnifiedStateSyncUpdate))
mux.HandleFunc(extService+"RecordError", m.handleUnary(m.onRecordError))
mux.HandleFunc(extService+"LogEvent", m.handleUnary(m.onLogEvent))
mux.HandleFunc(extService+"UpdateCascadeTrajectorySummaries", m.handleUnary(m.onUpdateTrajectorySummaries))
mux.HandleFunc(extService+"BroadcastConversationDeletion", m.handleUnary(m.onDefault))
mux.HandleFunc(extService+"WriteCascadeEdit", m.handleUnary(m.onDefault))
mux.HandleFunc(extService+"OpenDiffZones", m.handleUnary(m.onDefault))
mux.HandleFunc(extService+"HandleAsyncPostMessage", m.handleUnary(m.onDefault))
mux.HandleFunc(extService+"OpenFilePointer", m.handleUnary(m.onDefault))
mux.HandleFunc(extService+"OpenVirtualFile", m.handleUnary(m.onDefault))
mux.HandleFunc(extService+"SaveDocument", m.handleUnary(m.onDefault))
mux.HandleFunc(extService+"RestartUserStatusUpdater", m.handleUnary(m.onDefault))
mux.HandleFunc(extService+"FocusIDEWindow", m.handleUnary(m.onDefault))
mux.HandleFunc(extService+"SmartFocusConversation", m.handleUnary(m.onDefault))
mux.HandleFunc(extService+"RunExtensionCode", m.handleUnary(m.onDefault))
mux.HandleFunc(extService+"UpdateDetailedViewWithCascadeInput", m.handleUnary(m.onDefault))
mux.HandleFunc(extService+"FindAllReferences", m.handleUnary(m.onDefault))
mux.HandleFunc(extService+"GetDefinition", m.handleUnary(m.onDefault))
mux.HandleFunc(extService+"GetLintErrors", m.handleUnary(m.onDefault))
// Server-streaming RPCs — return application/connect+proto
mux.HandleFunc(extService+"SubscribeToUnifiedStateSyncTopic", m.handleStream(m.onSubscribeStateSyncTopic))
mux.HandleFunc(extService+"ExecuteCommand", m.handleStream(m.onExecuteCommand))
// Catch-all for any unregistered RPCs
mux.HandleFunc("/", m.handleCatchAll)
m.server = &http.Server{Handler: mux}
go func() {
if err := m.server.Serve(listener); err != http.ErrServerClosed {
m.logger.Error("extension server error", "err", err)
}
}()
m.logger.Info("mock extension server started", "port", m.port, "csrf_len", len(csrf))
return m, nil
}
// Port returns the listening port.
func (m *MockExtensionServer) Port() int {
return m.port
}
// SetToken sets the OAuth token for an account.
func (m *MockExtensionServer) SetToken(accountID string, info *TokenInfo) {
m.mu.Lock()
m.tokens[accountID] = info
m.lastAccountID = accountID
subscribers := m.snapshotSubscribersLocked("uss-oauth", accountID)
m.mu.Unlock()
if info == nil {
return
}
tokenBin := buildOAuthTokenInfoBinary(info.AccessToken, info.RefreshToken, info.ExpiresAt)
tokenB64 := base64.StdEncoding.EncodeToString(tokenBin)
m.publishTopicUpdate(subscribers, buildAppliedUpdate("oauthTokenInfoSentinelKey", buildUSSRowBinary(tokenB64)))
}
// SetModelCredits sets the uss-modelCredits state for an account.
func (m *MockExtensionServer) SetModelCredits(accountID string, info *ModelCreditsInfo) {
if info == nil {
info = &ModelCreditsInfo{}
}
copyInfo := *info
m.mu.Lock()
m.credits[accountID] = &copyInfo
m.lastAccountID = accountID
subscribers := m.snapshotSubscribersLocked("uss-modelCredits", accountID)
m.mu.Unlock()
m.publishTopicUpdate(subscribers, buildModelCreditsAppliedUpdates(&copyInfo)...)
}
// SetTrajectoryCallback registers a callback for when the LS pushes trajectory data.
func (m *MockExtensionServer) SetTrajectoryCallback(fn func(topic, key string, data []byte)) {
m.onTrajectoryUpdate = fn
}
func (m *MockExtensionServer) currentTokenLocked() *TokenInfo {
if m.lastAccountID != "" {
if info := m.tokens[m.lastAccountID]; info != nil {
return info
}
}
for _, info := range m.tokens {
return info
}
return nil
}
func (m *MockExtensionServer) currentModelCreditsLocked() *ModelCreditsInfo {
if m.lastAccountID != "" {
if info := m.credits[m.lastAccountID]; info != nil {
return info
}
}
for _, info := range m.credits {
return info
}
return nil
}
func (m *MockExtensionServer) tokenForAccountLocked(accountID string) *TokenInfo {
if accountID != "" {
if info := m.tokens[accountID]; info != nil {
return info
}
}
return m.currentTokenLocked()
}
func (m *MockExtensionServer) creditsForAccountLocked(accountID string) *ModelCreditsInfo {
if accountID != "" {
if info := m.credits[accountID]; info != nil {
return info
}
}
return m.currentModelCreditsLocked()
}
func (m *MockExtensionServer) snapshotSubscribersLocked(topic, accountID string) []*stateSubscriber {
topicSubs := m.subscribers[topic]
if len(topicSubs) == 0 {
return nil
}
out := make([]*stateSubscriber, 0, len(topicSubs))
for _, sub := range topicSubs {
if sub == nil {
continue
}
if accountID != "" && sub.accountID != "" && sub.accountID != accountID {
continue
}
out = append(out, sub)
}
return out
}
func (m *MockExtensionServer) publishTopicUpdate(subscribers []*stateSubscriber, updates ...[]byte) {
for _, sub := range subscribers {
if sub == nil {
continue
}
for _, update := range updates {
if len(update) == 0 {
continue
}
payload := append([]byte(nil), update...)
select {
case sub.updates <- payload:
default:
m.logger.Warn("dropping USS update", "topic", sub.topic, "account", sub.accountID)
}
}
}
}
func buildModelCreditsAppliedUpdates(info *ModelCreditsInfo) [][]byte {
if info == nil {
info = &ModelCreditsInfo{}
}
minimum := defaultMinimumCreditAmountForUsage
if info.MinimumCreditAmountForUsage != nil {
minimum = *info.MinimumCreditAmountForUsage
}
updates := make([][]byte, 0, 3)
updates = append(updates, buildAppliedUpdate(
useAICreditsSentinelKey,
buildUSSRowBinary(encodeUSSPrimitiveBoolValue(info.UseAICredits)),
))
if info.UseAICredits {
credits := int32(9999)
if info.AvailableCredits != nil {
credits = *info.AvailableCredits
}
updates = append(updates, buildAppliedUpdate(
availableCreditsSentinelKey,
buildUSSRowBinary(encodeUSSPrimitiveInt32Value(credits)),
))
} else {
updates = append(updates, buildAppliedUpdate(availableCreditsSentinelKey, nil))
}
updates = append(updates, buildAppliedUpdate(
minimumCreditAmountForUsageKey,
buildUSSRowBinary(encodeUSSPrimitiveInt32Value(minimum)),
))
return updates
}
// Close shuts down the server.
func (m *MockExtensionServer) Close() error {
return m.server.Close()
}
// ============================================================
// Middleware
// ============================================================
type unaryHandler func(body []byte) []byte
type streamHandler func(body []byte, w http.ResponseWriter, r *http.Request)
// handleUnary wraps a unary RPC handler with CSRF check and proper content-type.
func (m *MockExtensionServer) handleUnary(handler unaryHandler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// CSRF check
if !m.checkCSRF(w, r) {
return
}
body, err := io.ReadAll(r.Body)
if err != nil {
m.logger.Error("read body", "err", err, "path", r.URL.Path)
w.Header().Set("Content-Type", "application/proto")
w.WriteHeader(200)
return
}
// The LS might send with envelope framing (application/connect+proto)
// or without (application/proto). Detect and unwrap.
ct := r.Header.Get("Content-Type")
protoBody := body
if strings.Contains(ct, "connect+proto") && len(body) >= 5 {
protoBody = unwrapConnectEnvelope(body)
}
m.logger.Debug("unary RPC", "path", r.URL.Path, "body_len", len(protoBody), "content_type", ct)
responseProto := handler(protoBody)
// Respond with proper unary ConnectRPC content-type.
// If the request used "connect+proto", the response should be "application/proto"
// for unary RPCs (ConnectRPC spec: unary uses application/proto, not connect+proto).
w.Header().Set("Content-Type", "application/proto")
w.WriteHeader(200)
if len(responseProto) > 0 {
w.Write(responseProto)
}
}
}
// handleStream wraps a server-streaming RPC handler with CSRF and content-type.
func (m *MockExtensionServer) handleStream(handler streamHandler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if !m.checkCSRF(w, r) {
return
}
body, err := io.ReadAll(r.Body)
if err != nil {
m.logger.Error("read body", "err", err, "path", r.URL.Path)
return
}
// Unwrap envelope from request
ct := r.Header.Get("Content-Type")
if strings.Contains(ct, "connect+proto") || strings.Contains(ct, "connect+json") {
body = unwrapConnectEnvelope(body)
}
m.logger.Debug("stream RPC", "path", r.URL.Path, "body_len", len(body))
// Set streaming response content-type
w.Header().Set("Content-Type", "application/connect+proto")
w.WriteHeader(200)
handler(body, w, r)
}
}
func (m *MockExtensionServer) checkCSRF(w http.ResponseWriter, r *http.Request) bool {
token := r.Header.Get("x-codeium-csrf-token")
if m.csrf != "" && token != m.csrf {
m.logger.Warn("CSRF mismatch", "path", r.URL.Path, "got", token[:min(8, len(token))])
w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(403)
w.Write([]byte("Invalid CSRF token"))
return false
}
return true
}
func min(a, b int) int {
if a < b {
return a
}
return b
}
// ============================================================
// Unary RPC Handlers — each receives raw proto request body,
// returns raw proto response body.
// ============================================================
func (m *MockExtensionServer) onLanguageServerStarted(body []byte) []byte {
// LanguageServerStartedRequest has: https_port(1), http_port(2), lsp_port(3), csrf_token(4)
// We just log the ports — they're informational.
m.logger.Info("LanguageServerStarted",
"body_len", len(body))
// Return empty LanguageServerStartedResponse
return nil
}
func (m *MockExtensionServer) onHeartbeat(body []byte) []byte {
// Return empty HeartbeatResponse
return nil
}
func (m *MockExtensionServer) onGetSecretValue(body []byte) []byte {
// GetSecretValueRequest: key = field 1
key := decodeProtoString(body, 1)
m.logger.Debug("GetSecretValue", "key", key)
m.mu.RLock()
var token string
if info := m.currentTokenLocked(); info != nil {
token = info.AccessToken
}
m.mu.RUnlock()
// GetSecretValueResponse: value = field 1
if token != "" {
return encodeProtoString(1, token)
}
return nil
}
func (m *MockExtensionServer) onStoreSecretValue(body []byte) []byte {
key := decodeProtoString(body, 1)
m.logger.Debug("StoreSecretValue", "key", key)
return nil
}
func (m *MockExtensionServer) onIsAgentManagerEnabled(body []byte) []byte {
// IsAgentManagerEnabledResponse: enabled = field 1 (bool)
return encodeProtoBool(1, false)
}
func (m *MockExtensionServer) onPushUnifiedStateSyncUpdate(body []byte) []byte {
// PushUnifiedStateSyncUpdateRequest: update = field 1 (UpdateRequest message)
// UpdateRequest: topic_name = field 1, applied_update = field 5, key = field 2
m.logger.Debug("PushUnifiedStateSyncUpdate", "body_len", len(body))
// Extract topic name from the embedded UpdateRequest
// The body is PushUnifiedStateSyncUpdateRequest, field 1 is the UpdateRequest
// We need to dig into the nested message to get topic_name
if m.onTrajectoryUpdate != nil {
// For now, just notify that an update was pushed
m.onTrajectoryUpdate("", "", body)
}
// Return empty PushUnifiedStateSyncUpdateResponse
return nil
}
func (m *MockExtensionServer) onRecordError(body []byte) []byte {
m.logger.Debug("RecordError", "body_len", len(body))
return nil
}
func (m *MockExtensionServer) onLogEvent(body []byte) []byte {
return nil
}
func (m *MockExtensionServer) onUpdateTrajectorySummaries(body []byte) []byte {
m.logger.Debug("UpdateCascadeTrajectorySummaries", "body_len", len(body))
return nil
}
func (m *MockExtensionServer) onDefault(body []byte) []byte {
return nil
}
// ============================================================
// Streaming RPC Handlers
// ============================================================
func (m *MockExtensionServer) onSubscribeStateSyncTopic(body []byte, w http.ResponseWriter, r *http.Request) {
// SubscribeToUnifiedStateSyncTopicRequest: topic = field 1
topic := decodeProtoString(body, 1)
m.logger.Info("SubscribeToUnifiedStateSyncTopic", "topic", topic)
flusher, ok := w.(http.Flusher)
if !ok {
m.logger.Error("ResponseWriter does not support Flush")
return
}
m.mu.Lock()
accountID := m.lastAccountID
subID := m.nextSubID
m.nextSubID++
sub := &stateSubscriber{
id: subID,
accountID: accountID,
topic: topic,
updates: make(chan []byte, 16),
}
if m.subscribers[topic] == nil {
m.subscribers[topic] = make(map[int]*stateSubscriber)
}
m.subscribers[topic][subID] = sub
// Build initial state based on topic
var topicData []byte
switch topic {
case "uss-oauth":
tokenInfo := m.tokenForAccountLocked(accountID)
if tokenInfo != nil {
topicData = buildUSSTopicWithOAuth(tokenInfo.AccessToken, tokenInfo.RefreshToken, tokenInfo.ExpiresAt)
} else {
topicData = buildEmptyTopic()
}
case "uss-modelCredits":
creditsInfo := m.creditsForAccountLocked(accountID)
if creditsInfo != nil {
topicData = buildUSSTopicWithModelCredits(creditsInfo)
} else {
topicData = buildEmptyTopic()
}
default:
// For all other topics (browserPreferences, enterprisePreferences, etc.),
// return empty topic data.
topicData = buildEmptyTopic()
}
m.mu.Unlock()
defer func() {
m.mu.Lock()
if topicSubs := m.subscribers[topic]; topicSubs != nil {
delete(topicSubs, subID)
if len(topicSubs) == 0 {
delete(m.subscribers, topic)
}
}
m.mu.Unlock()
}()
// Send initial state as envelope-framed message
initialUpdate := buildInitialStateUpdate(topicData)
frame := connectEnvelope(0x00, initialUpdate)
w.Write(frame)
flusher.Flush()
for {
select {
case <-r.Context().Done():
m.logger.Debug("SubscribeToUnifiedStateSyncTopic stream closed", "topic", topic)
return
case update := <-sub.updates:
if len(update) == 0 {
continue
}
if _, err := w.Write(connectEnvelope(0x00, update)); err != nil {
m.logger.Debug("SubscribeToUnifiedStateSyncTopic write failed", "topic", topic, "err", err)
return
}
flusher.Flush()
}
}
}
func (m *MockExtensionServer) onExecuteCommand(body []byte, w http.ResponseWriter, r *http.Request) {
m.logger.Debug("ExecuteCommand (mock)", "body_len", len(body))
// Send end-of-stream immediately — we don't execute commands
flusher, ok := w.(http.Flusher)
if !ok {
return
}
w.Write(connectEndOfStream())
flusher.Flush()
}
// ============================================================
// Catch-all handler
// ============================================================
func (m *MockExtensionServer) handleCatchAll(w http.ResponseWriter, r *http.Request) {
if !m.checkCSRF(w, r) {
return
}
m.logger.Debug("unhandled RPC (returning empty proto)", "path", r.URL.Path, "method", r.Method)
// Drain request body
io.ReadAll(r.Body)
// Determine if this is likely a unary or streaming request based on content-type.
ct := r.Header.Get("Content-Type")
if strings.Contains(ct, "connect+") {
// Could be streaming — respond with unary proto to be safe
// (unary Connect requests can also use connect+ prefix in some client impls)
w.Header().Set("Content-Type", "application/proto")
} else {
w.Header().Set("Content-Type", "application/proto")
}
w.WriteHeader(200)
}

File diff suppressed because it is too large Load Diff

View File

@ -1,376 +0,0 @@
package lspool
import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"net/http"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
"github.com/stretchr/testify/require"
)
func TestBuildLSEnvKeepsExistingSSLValues(t *testing.T) {
env := buildLSEnv([]string{
"SSL_CERT_FILE=/custom/ca.pem",
"SSL_CERT_DIR=/custom/certs",
}, "/opt/antigravity", "")
require.Contains(t, env, "ANTIGRAVITY_EDITOR_APP_ROOT=/opt/antigravity")
require.Contains(t, env, "SSL_CERT_FILE=/custom/ca.pem")
require.Contains(t, env, "SSL_CERT_DIR=/custom/certs")
}
func TestBuildLSEnvClearsInheritedProxyWhenUnset(t *testing.T) {
env := buildLSEnv([]string{
"HTTPS_PROXY=http://old-proxy:8080",
"HTTP_PROXY=http://old-proxy:8080",
"ALL_PROXY=socks5://old-proxy:1080",
"https_proxy=http://old-proxy:8080",
"http_proxy=http://old-proxy:8080",
"all_proxy=socks5://old-proxy:1080",
}, "/opt/antigravity", "")
require.Contains(t, env, "HTTPS_PROXY=")
require.Contains(t, env, "HTTP_PROXY=")
require.Contains(t, env, "ALL_PROXY=")
require.Contains(t, env, "https_proxy=")
require.Contains(t, env, "http_proxy=")
require.Contains(t, env, "all_proxy=")
}
func TestShortAccountID(t *testing.T) {
require.Equal(t, "9", shortAccountID("9"))
require.Equal(t, "12345678", shortAccountID("12345678"))
require.Equal(t, "12345678", shortAccountID("123456789"))
}
func TestFrameConnectMessage(t *testing.T) {
framed := frameConnectMessage([]byte(`{"x":1}`))
require.Len(t, framed, 5+len(`{"x":1}`))
require.Equal(t, byte(0), framed[0])
require.Equal(t, uint32(len(`{"x":1}`)), binary.BigEndian.Uint32(framed[1:5]))
require.Equal(t, `{"x":1}`, string(framed[5:]))
}
func TestConnectEnvelope(t *testing.T) {
payload := []byte("hello")
env := connectEnvelope(0x00, payload)
require.Len(t, env, 5+len(payload))
require.Equal(t, byte(0x00), env[0])
require.Equal(t, uint32(5), binary.BigEndian.Uint32(env[1:5]))
require.Equal(t, "hello", string(env[5:]))
}
func TestUnwrapConnectEnvelope(t *testing.T) {
payload := []byte("test data")
env := connectEnvelope(0x00, payload)
unwrapped := unwrapConnectEnvelope(env)
require.Equal(t, payload, unwrapped)
short := []byte{1, 2}
require.Equal(t, short, unwrapConnectEnvelope(short))
}
func TestExtractPromptAndModel(t *testing.T) {
body := `{"model":"gemini-2.5-pro","project":"p","request":{"contents":[{"role":"user","parts":[{"text":"hello world"}]}]}}`
prompt, model := extractPromptAndModel([]byte(body))
require.Equal(t, "hello world", prompt)
require.Equal(t, "gemini-2.5-pro", model)
body2 := `{"contents":[{"role":"user","parts":[{"text":"test prompt"}]}]}`
prompt2, _ := extractPromptAndModel([]byte(body2))
require.Equal(t, "test prompt", prompt2)
}
func TestResolveModelEnum(t *testing.T) {
// Without dynamic mapping loaded, should return fallback (312 = gemini-2.5-flash)
require.True(t, resolveModelEnum("gemini-2.5-flash") > 0)
require.True(t, resolveModelEnum("models/gemini-2.5-flash") > 0)
require.True(t, resolveModelEnum("claude-sonnet-4-6") > 0)
require.True(t, resolveModelEnum("unknown-model") > 0)
}
func TestBuildCascadeConfigIncludesRequestedModel(t *testing.T) {
cfg := buildCascadeConfig("models/gemini-2.5-flash")
require.NotNil(t, cfg)
plannerConfig, ok := cfg["plannerConfig"].(map[string]any)
require.True(t, ok)
requestedModel, ok := plannerConfig["requestedModel"].(map[string]any)
require.True(t, ok)
require.NotEmpty(t, requestedModel["model"])
require.Len(t, plannerConfig, 1)
}
func TestBuildCascadeConfigClaudeIncludesRequestedModel(t *testing.T) {
cfg := buildCascadeConfig("claude-sonnet-4-6")
require.NotNil(t, cfg)
plannerConfig, ok := cfg["plannerConfig"].(map[string]any)
require.True(t, ok)
requestedModel, ok := plannerConfig["requestedModel"].(map[string]any)
require.True(t, ok)
require.NotEmpty(t, requestedModel["model"])
require.Len(t, plannerConfig, 1)
}
func TestDoNonStreamGeneratePassesThrough(t *testing.T) {
fallback := &recordingUpstream{}
upstream := NewLSPoolUpstream(&Pool{}, fallback)
req, _ := http.NewRequest("POST", "https://example.com/v1beta/models/gemini:generateContent", bytes.NewReader([]byte(`{}`)))
resp, err := upstream.Do(req, "", 1, 1)
require.NoError(t, err)
require.NotNil(t, resp)
require.Equal(t, 1, fallback.doCalls)
}
func TestExtractPlannerResponseText(t *testing.T) {
resp := `{"status":"CASCADE_RUN_STATUS_IDLE","trajectory":{"steps":[
{"type":"CORTEX_STEP_TYPE_USER_INPUT","status":"CORTEX_STEP_STATUS_DONE"},
{"type":"CORTEX_STEP_TYPE_PLANNER_RESPONSE","status":"CORTEX_STEP_STATUS_DONE",
"plannerResponse":{"response":"Hello world"}}
]}}`
text, generating, status := extractPlannerResponseText([]byte(resp))
require.Equal(t, "Hello world", text)
require.False(t, generating)
require.Equal(t, "CASCADE_RUN_STATUS_IDLE", status)
}
func TestExtractPlannerResponseState_ErrorDetails(t *testing.T) {
resp := `{
"status":"CASCADE_RUN_STATUS_IDLE",
"trajectory":{
"steps":[
{"type":"CORTEX_STEP_TYPE_USER_INPUT","status":"CORTEX_STEP_STATUS_DONE"}
],
"executorMetadata":{
"terminationReason":"ERROR",
"errorDetails":{
"errorCode":429,
"shortError":"Model quota reached",
"details":"You have exhausted your capacity on this model. Your quota will reset after 1h59m40s."
}
}
}
}`
state := extractPlannerResponseState([]byte(resp))
require.Equal(t, "CASCADE_RUN_STATUS_IDLE", state.Status)
require.False(t, state.Generating)
require.Empty(t, state.Text)
require.Contains(t, state.ErrorMessage, "Model quota reached")
require.Contains(t, state.ErrorMessage, "quota will reset after")
}
func TestBuildGeminiSSEChunk(t *testing.T) {
sse := buildGeminiSSEChunk("hello")
require.Contains(t, sse, "data: ")
require.Contains(t, sse, `"text":"hello"`)
require.Contains(t, sse, `"role":"model"`)
require.True(t, strings.HasSuffix(sse, "\n\n"))
}
func TestRequestHasTools(t *testing.T) {
// Wrapped format with tools
require.True(t, requestHasTools([]byte(`{"model":"m","project":"p","request":{"contents":[],"tools":[{"functionDeclarations":[{"name":"get_weather"}]}]}}`)))
// Direct format with tools
require.True(t, requestHasTools([]byte(`{"contents":[],"tools":[{"functionDeclarations":[{"name":"f"}]}]}`)))
// No tools
require.False(t, requestHasTools([]byte(`{"model":"m","project":"p","request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`)))
// Empty tools array
require.False(t, requestHasTools([]byte(`{"contents":[],"tools":[]}`)))
}
func TestCurrentLSStrategy(t *testing.T) {
t.Setenv("ANTIGRAVITY_LS_STRATEGY", "js-parity")
require.Equal(t, LSStrategyJSParity, CurrentLSStrategy())
t.Setenv("ANTIGRAVITY_LS_STRATEGY", "unknown")
require.Equal(t, LSStrategyDirect, CurrentLSStrategy())
}
func TestIsPermanentModelMappingError(t *testing.T) {
require.True(t, isPermanentModelMappingError(errors.New(`oauth2: "unauthorized_client" "Unauthorized"`)))
require.False(t, isPermanentModelMappingError(errors.New("context deadline exceeded")))
}
func TestPoolSetAccountTokenClearsModelMappingUnavailable(t *testing.T) {
pool := &Pool{
instances: map[string][]*Instance{
"9": {
{AccountID: "9", Replica: 0},
},
},
}
inst := pool.instances["9"][0]
inst.SetModelMappingReady(true)
inst.SetModelMappingUnavailable(`oauth2: "unauthorized_client" "Unauthorized"`)
pool.SetAccountToken("9", "ya29.new", "refresh", time.Now().Add(time.Hour))
require.False(t, inst.HasModelMappingReady())
require.False(t, inst.HasModelMappingUnavailable())
require.Empty(t, inst.ModelMappingUnavailableReason())
}
func TestShouldFallbackDirectForModelMappingUnavailable(t *testing.T) {
require.True(t, shouldFallbackDirect(fmt.Errorf("%w: oauth2 unauthorized_client", errLSModelMapDenied)))
require.False(t, shouldFallbackDirect(errLSModelMapPending))
}
func TestParseLSReplicaCountDefaultAndEnv(t *testing.T) {
t.Setenv("ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT", "")
require.Equal(t, 5, parseLSReplicaCount())
t.Setenv("ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT", "3")
require.Equal(t, 3, parseLSReplicaCount())
t.Setenv("ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT", "0")
require.Equal(t, 5, parseLSReplicaCount())
}
func TestPoolGetUsesStickyReplicaSlot(t *testing.T) {
pool := &Pool{
config: Config{ReplicasPerAccount: 5},
instances: map[string][]*Instance{
"acc-1": {
{AccountID: "acc-1", Replica: 0, healthy: true},
{AccountID: "acc-1", Replica: 1, healthy: true},
{AccountID: "acc-1", Replica: 2, healthy: true},
{AccountID: "acc-1", Replica: 3, healthy: true},
{AccountID: "acc-1", Replica: 4, healthy: true},
},
},
}
routingKey := "acc-1:user-a:session-1"
slot := replicaSlotIndex(routingKey, pool.replicaCount())
inst := pool.Get("acc-1", routingKey)
require.NotNil(t, inst)
require.Equal(t, slot, inst.Replica)
}
func TestPoolGetWithoutRoutingKeyPrefersLeastBusyReplica(t *testing.T) {
busy := &Instance{AccountID: "acc-1", Replica: 0, healthy: true}
atomic.StoreInt64(&busy.inflight, 4)
idle := &Instance{AccountID: "acc-1", Replica: 1, healthy: true}
atomic.StoreInt64(&idle.inflight, 1)
pool := &Pool{
config: Config{ReplicasPerAccount: 5},
instances: map[string][]*Instance{
"acc-1": {busy, idle},
},
}
inst := pool.Get("acc-1", "")
require.NotNil(t, inst)
require.Equal(t, 1, inst.Replica)
}
func TestWaitForInstanceReadyProbesImmediately(t *testing.T) {
startedAt := time.Now()
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
attempts, err := waitForInstanceReady(ctx, 200*time.Millisecond, func(context.Context) error {
return nil
})
require.NoError(t, err)
require.Equal(t, 1, attempts)
require.Less(t, time.Since(startedAt), 100*time.Millisecond)
}
func TestWaitForInstanceReadyRetriesUntilSuccess(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
calls := 0
attempts, err := waitForInstanceReady(ctx, 10*time.Millisecond, func(context.Context) error {
calls++
if calls < 3 {
return errors.New("not ready")
}
return nil
})
require.NoError(t, err)
require.Equal(t, 3, attempts)
require.Equal(t, 3, calls)
}
func TestDecideJSParityRoute(t *testing.T) {
body := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"s1","contents":[{"role":"user","parts":[{"text":"hello"}]}]}}`)
parsed, err := parseGeminiRequest(body)
require.NoError(t, err)
decision := decideJSParityRoute(parsed, body)
require.True(t, decision.UseLS)
imageBody := []byte(`{"model":"gemini-2.5-flash-image","request":{"sessionId":"s1","contents":[{"role":"user","parts":[{"text":"draw"}]}],"generationConfig":{"responseModalities":["TEXT","IMAGE"]}}}`)
parsedImage, err := parseGeminiRequest(imageBody)
require.NoError(t, err)
decisionImage := decideJSParityRoute(parsedImage, imageBody)
require.False(t, decisionImage.UseLS)
require.Contains(t, strings.ToLower(decisionImage.Reason), "image")
noSessionBody := []byte(`{"model":"gemini-2.5-flash","request":{"contents":[{"role":"user","parts":[{"text":"hello"}]}]}}`)
parsedNoSession, err := parseGeminiRequest(noSessionBody)
require.NoError(t, err)
require.False(t, decideJSParityRoute(parsedNoSession, noSessionBody).UseLS)
}
func TestUserNamespacePrefersExplicitHeader(t *testing.T) {
req, err := http.NewRequest(http.MethodPost, "https://example.com", nil)
require.NoError(t, err)
req.Header.Set(userNamespaceHeader, "tenant-a")
req.Header.Set("Authorization", "Bearer oauth-token")
nsWithExplicit := userNamespace(req)
require.NotEqual(t, "anon", nsWithExplicit)
req.Header.Del(userNamespaceHeader)
nsWithAuth := userNamespace(req)
require.NotEqual(t, "anon", nsWithAuth)
require.NotEqual(t, nsWithExplicit, nsWithAuth)
}
func TestConversationPrefixEqual(t *testing.T) {
prefix := []geminiConversationTurn{
{Role: "user", Parts: []geminiConversationPart{{Kind: "text", Text: "hello"}}},
{Role: "model", Parts: []geminiConversationPart{{Kind: "text", Text: "world"}}},
}
full := append(cloneConversationTurns(prefix), geminiConversationTurn{
Role: "user",
Parts: []geminiConversationPart{{Kind: "text", Text: "follow up"}},
})
require.True(t, conversationPrefixEqual(full, prefix))
require.False(t, conversationPrefixEqual(prefix, full))
}
func TestSystemTextCompatible(t *testing.T) {
require.True(t, systemTextCompatible("You are helpful", ""))
require.True(t, systemTextCompatible("You are helpful", "You are helpful"))
require.False(t, systemTextCompatible("", "You are helpful"))
require.False(t, systemTextCompatible("You are helpful", "You are different"))
}
type recordingUpstream struct {
doCalls int
}
func (r *recordingUpstream) Do(req *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
r.doCalls++
return &http.Response{StatusCode: 200, Body: io.NopCloser(bytes.NewBufferString("ok")), Header: make(http.Header), Request: req}, nil
}
func (r *recordingUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, c int, _ *tlsfingerprint.Profile) (*http.Response, error) {
return r.Do(req, proxyURL, accountID, c)
}

View File

@ -1,268 +0,0 @@
package lspool
import (
"bufio"
"context"
"fmt"
"io"
"log/slog"
"net"
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
"golang.org/x/net/proxy"
)
type lsProxyBridge struct {
listener net.Listener
server *http.Server
url string
upstream string
}
type lsProxyBridgeManager struct {
mu sync.Mutex
bridges map[string]*lsProxyBridge
logger *slog.Logger
}
var globalLSProxyBridgeManager = &lsProxyBridgeManager{
bridges: make(map[string]*lsProxyBridge),
logger: slog.Default().With("component", "lspool-proxy-bridge"),
}
var (
lsProxyBridgeDialTimeout = 10 * time.Second
lsProxyBridgeProbeTargets = []string{
"cloudcode-pa.googleapis.com:443",
"oauthaccountmanager.googleapis.com:443",
}
)
func prepareLSProxyURL(raw string) (string, error) {
normalized, parsed, err := proxyurl.Parse(raw)
if err != nil {
return "", err
}
if parsed == nil {
return "", nil
}
switch strings.ToLower(parsed.Scheme) {
case "http", "https":
return normalized, nil
case "socks5", "socks5h":
return globalLSProxyBridgeManager.ensure(normalized, parsed)
default:
return "", fmt.Errorf("unsupported LS proxy scheme: %s", parsed.Scheme)
}
}
func (m *lsProxyBridgeManager) ensure(key string, upstream *url.URL) (string, error) {
m.mu.Lock()
defer m.mu.Unlock()
if bridge := m.bridges[key]; bridge != nil {
return bridge.url, nil
}
bridge, err := newLSProxyBridge(upstream, m.logger)
if err != nil {
return "", err
}
m.bridges[key] = bridge
return bridge.url, nil
}
func (m *lsProxyBridgeManager) closeAll() {
m.mu.Lock()
defer m.mu.Unlock()
for key, bridge := range m.bridges {
if bridge != nil {
_ = bridge.server.Close()
_ = bridge.listener.Close()
}
delete(m.bridges, key)
}
}
func closeAllLSProxyBridgesForTest() {
globalLSProxyBridgeManager.closeAll()
}
func newLSProxyBridge(upstream *url.URL, logger *slog.Logger) (*lsProxyBridge, error) {
dialer, err := proxy.FromURL(upstream, proxy.Direct)
if err != nil {
return nil, fmt.Errorf("create SOCKS dialer: %w", err)
}
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
return nil, fmt.Errorf("listen LS proxy bridge: %w", err)
}
bridge := &lsProxyBridge{
listener: listener,
url: "http://" + listener.Addr().String(),
upstream: upstream.Redacted(),
}
server := &http.Server{
Handler: http.HandlerFunc(bridge.connectHandler(dialer, logger)),
ReadHeaderTimeout: 10 * time.Second,
IdleTimeout: 2 * time.Minute,
}
bridge.server = server
go func() {
if err := server.Serve(listener); err != nil && err != http.ErrServerClosed {
logger.Error("LS proxy bridge serve failed", "upstream", bridge.upstream, "err", err)
}
}()
logger.Info("LS proxy bridge started", "upstream", bridge.upstream, "listen", bridge.url)
go bridge.probeConnectivity(dialer, logger)
return bridge, nil
}
func (b *lsProxyBridge) connectHandler(dialer proxy.Dialer, logger *slog.Logger) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodConnect {
http.Error(w, "CONNECT only", http.StatusMethodNotAllowed)
return
}
targetAddr := strings.TrimSpace(r.Host)
if targetAddr == "" {
targetAddr = strings.TrimSpace(r.URL.Host)
}
if targetAddr == "" {
http.Error(w, "missing target host", http.StatusBadRequest)
return
}
if _, _, err := net.SplitHostPort(targetAddr); err != nil {
targetAddr = net.JoinHostPort(targetAddr, "443")
}
startedAt := time.Now()
logger.Info("LS proxy bridge CONNECT", "upstream", b.upstream, "target", targetAddr)
dialCtx, cancel := context.WithTimeout(r.Context(), lsProxyBridgeDialTimeout)
defer cancel()
targetConn, err := dialViaProxy(dialCtx, dialer, targetAddr)
if err != nil {
logger.Warn("LS proxy bridge dial failed",
"upstream", b.upstream,
"target", targetAddr,
"elapsed", time.Since(startedAt).Truncate(time.Millisecond),
"err", err)
http.Error(w, "proxy dial failed", http.StatusBadGateway)
return
}
logger.Info("LS proxy bridge CONNECT established",
"upstream", b.upstream,
"target", targetAddr,
"elapsed", time.Since(startedAt).Truncate(time.Millisecond))
hijacker, ok := w.(http.Hijacker)
if !ok {
_ = targetConn.Close()
http.Error(w, "hijack unsupported", http.StatusInternalServerError)
return
}
clientConn, rw, err := hijacker.Hijack()
if err != nil {
_ = targetConn.Close()
return
}
if _, err := clientConn.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n")); err != nil {
_ = targetConn.Close()
_ = clientConn.Close()
return
}
if rw != nil && rw.Reader.Buffered() > 0 {
if _, err := io.CopyN(targetConn, rw, int64(rw.Reader.Buffered())); err != nil {
_ = targetConn.Close()
_ = clientConn.Close()
return
}
}
tunnelConns(clientConn, targetConn)
}
}
func dialViaProxy(ctx context.Context, dialer proxy.Dialer, targetAddr string) (net.Conn, error) {
if contextDialer, ok := dialer.(proxy.ContextDialer); ok {
return contextDialer.DialContext(ctx, "tcp", targetAddr)
}
type dialResult struct {
conn net.Conn
err error
}
resultCh := make(chan dialResult, 1)
go func() {
conn, err := dialer.Dial("tcp", targetAddr)
resultCh <- dialResult{conn: conn, err: err}
}()
select {
case <-ctx.Done():
return nil, ctx.Err()
case result := <-resultCh:
return result.conn, result.err
}
}
func (b *lsProxyBridge) probeConnectivity(dialer proxy.Dialer, logger *slog.Logger) {
for _, targetAddr := range lsProxyBridgeProbeTargets {
startedAt := time.Now()
ctx, cancel := context.WithTimeout(context.Background(), lsProxyBridgeDialTimeout)
conn, err := dialViaProxy(ctx, dialer, targetAddr)
cancel()
if err != nil {
logger.Warn("LS proxy bridge probe failed",
"upstream", b.upstream,
"target", targetAddr,
"elapsed", time.Since(startedAt).Truncate(time.Millisecond),
"err", err)
continue
}
_ = conn.Close()
logger.Info("LS proxy bridge probe succeeded",
"upstream", b.upstream,
"target", targetAddr,
"elapsed", time.Since(startedAt).Truncate(time.Millisecond))
}
}
func tunnelConns(clientConn net.Conn, targetConn net.Conn) {
var once sync.Once
closeBoth := func() {
_ = clientConn.Close()
_ = targetConn.Close()
}
go func() {
_, _ = io.Copy(targetConn, clientConn)
once.Do(closeBoth)
}()
go func() {
_, _ = io.Copy(clientConn, targetConn)
once.Do(closeBoth)
}()
}
func readConnectResponse(br *bufio.Reader) (*http.Response, error) {
return http.ReadResponse(br, &http.Request{Method: http.MethodConnect})
}

View File

@ -1,193 +0,0 @@
package lspool
import (
"bufio"
"encoding/binary"
"fmt"
"io"
"net"
"strings"
"testing"
"github.com/stretchr/testify/require"
)
func TestPrepareLSProxyURLPassesThroughHTTPProxy(t *testing.T) {
t.Cleanup(closeAllLSProxyBridgesForTest)
got, err := prepareLSProxyURL("http://proxy.example.com:8080")
require.NoError(t, err)
require.Equal(t, "http://proxy.example.com:8080", got)
}
func TestPrepareLSProxyURLBridgesSOCKS5ForLS(t *testing.T) {
t.Cleanup(closeAllLSProxyBridgesForTest)
targetAddr, closeTarget := startBridgeEchoServer(t)
defer closeTarget()
socksURL, closeSOCKS := startBridgeSOCKS5Server(t)
defer closeSOCKS()
bridgeURL, err := prepareLSProxyURL(socksURL)
require.NoError(t, err)
require.True(t, strings.HasPrefix(bridgeURL, "http://127.0.0.1:"))
// Same SOCKS upstream should reuse the same local bridge.
reusedURL, err := prepareLSProxyURL(socksURL)
require.NoError(t, err)
require.Equal(t, bridgeURL, reusedURL)
bridgeAddr := strings.TrimPrefix(bridgeURL, "http://")
conn, err := net.Dial("tcp", bridgeAddr)
require.NoError(t, err)
defer conn.Close()
_, err = fmt.Fprintf(conn, "CONNECT %s HTTP/1.1\r\nHost: %s\r\n\r\n", targetAddr, targetAddr)
require.NoError(t, err)
reader := bufio.NewReader(conn)
resp, err := readConnectResponse(reader)
require.NoError(t, err)
require.Equal(t, 200, resp.StatusCode)
_, err = conn.Write([]byte("ping"))
require.NoError(t, err)
reply := make([]byte, 4)
_, err = io.ReadFull(reader, reply)
require.NoError(t, err)
require.Equal(t, "pong", string(reply))
}
func startBridgeEchoServer(t *testing.T) (string, func()) {
t.Helper()
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
done := make(chan struct{})
go func() {
defer close(done)
for {
conn, err := ln.Accept()
if err != nil {
return
}
go func(c net.Conn) {
defer c.Close()
buf := make([]byte, 4)
if _, err := io.ReadFull(c, buf); err != nil {
return
}
if string(buf) == "ping" {
_, _ = c.Write([]byte("pong"))
}
}(conn)
}
}()
return ln.Addr().String(), func() {
_ = ln.Close()
<-done
}
}
func startBridgeSOCKS5Server(t *testing.T) (string, func()) {
t.Helper()
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
done := make(chan struct{})
go func() {
defer close(done)
for {
conn, err := ln.Accept()
if err != nil {
return
}
go handleBridgeSOCKS5Conn(conn)
}
}()
return "socks5://" + ln.Addr().String(), func() {
_ = ln.Close()
<-done
}
}
func handleBridgeSOCKS5Conn(conn net.Conn) {
header := make([]byte, 2)
if _, err := io.ReadFull(conn, header); err != nil {
_ = conn.Close()
return
}
methods := make([]byte, int(header[1]))
if _, err := io.ReadFull(conn, methods); err != nil {
_ = conn.Close()
return
}
_, _ = conn.Write([]byte{0x05, 0x00})
reqHeader := make([]byte, 4)
if _, err := io.ReadFull(conn, reqHeader); err != nil {
_ = conn.Close()
return
}
if reqHeader[0] != 0x05 || reqHeader[1] != 0x01 {
_ = conn.Close()
return
}
targetHost, ok := readSOCKS5Addr(conn, reqHeader[3])
if !ok {
_ = conn.Close()
return
}
portBuf := make([]byte, 2)
if _, err := io.ReadFull(conn, portBuf); err != nil {
_ = conn.Close()
return
}
targetAddr := fmt.Sprintf("%s:%d", targetHost, binary.BigEndian.Uint16(portBuf))
targetConn, err := net.Dial("tcp", targetAddr)
if err != nil {
_, _ = conn.Write([]byte{0x05, 0x01, 0x00, 0x01, 0, 0, 0, 0, 0, 0})
_ = conn.Close()
return
}
_, _ = conn.Write([]byte{0x05, 0x00, 0x00, 0x01, 0, 0, 0, 0, 0, 0})
tunnelConns(conn, targetConn)
}
func readSOCKS5Addr(conn net.Conn, atyp byte) (string, bool) {
switch atyp {
case 0x01:
buf := make([]byte, 4)
if _, err := io.ReadFull(conn, buf); err != nil {
return "", false
}
return net.IP(buf).String(), true
case 0x03:
lenBuf := make([]byte, 1)
if _, err := io.ReadFull(conn, lenBuf); err != nil {
return "", false
}
buf := make([]byte, int(lenBuf[0]))
if _, err := io.ReadFull(conn, buf); err != nil {
return "", false
}
return string(buf), true
case 0x04:
buf := make([]byte, 16)
if _, err := io.ReadFull(conn, buf); err != nil {
return "", false
}
return net.IP(buf).String(), true
default:
return "", false
}
}

View File

@ -1,138 +0,0 @@
package lspool
import (
"fmt"
"net/url"
"os"
"os/exec"
"strings"
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
)
type lsLaunchPlan struct {
cmd *exec.Cmd
effectiveProxyURL string
proxyMode string
cleanup func()
}
func prepareLSLaunchPlan(binPath string, args []string, rawProxyURL string) (*lsLaunchPlan, error) {
normalized, parsed, err := proxyurl.Parse(rawProxyURL)
if err != nil {
return nil, err
}
plan := &lsLaunchPlan{
cmd: exec.Command(binPath, args...),
proxyMode: "direct",
}
if parsed == nil {
return plan, nil
}
switch strings.ToLower(parsed.Scheme) {
case "http", "https":
plan.effectiveProxyURL = normalized
plan.proxyMode = "env-http-proxy"
return plan, nil
case "socks5", "socks5h":
if proxychainsPath, err := exec.LookPath("proxychains4"); err == nil {
cfgPath, err := writeProxychainsConfig(parsed)
if err != nil {
return nil, err
}
plan.cmd = exec.Command(proxychainsPath, append([]string{"-f", cfgPath, binPath}, args...)...)
plan.proxyMode = "proxychains4"
plan.cleanup = func() {
_ = os.Remove(cfgPath)
}
return plan, nil
}
effectiveProxyURL, err := prepareLSProxyURL(normalized)
if err != nil {
return nil, err
}
plan.effectiveProxyURL = effectiveProxyURL
plan.proxyMode = "http-connect-bridge"
return plan, nil
default:
return nil, fmt.Errorf("unsupported LS proxy scheme: %s", parsed.Scheme)
}
}
func writeProxychainsConfig(proxyURL *url.URL) (string, error) {
content, err := buildProxychainsConfig(proxyURL)
if err != nil {
return "", err
}
file, err := os.CreateTemp("", "sub2api-proxychains-*.conf")
if err != nil {
return "", fmt.Errorf("create proxychains config: %w", err)
}
if _, err := file.WriteString(content); err != nil {
_ = file.Close()
_ = os.Remove(file.Name())
return "", fmt.Errorf("write proxychains config: %w", err)
}
if err := file.Close(); err != nil {
_ = os.Remove(file.Name())
return "", fmt.Errorf("close proxychains config: %w", err)
}
return file.Name(), nil
}
func buildProxychainsConfig(proxyURL *url.URL) (string, error) {
if proxyURL == nil {
return "", fmt.Errorf("proxy url is nil")
}
if scheme := strings.ToLower(proxyURL.Scheme); scheme != "socks5" && scheme != "socks5h" {
return "", fmt.Errorf("proxychains only supports socks5/socks5h, got %s", proxyURL.Scheme)
}
host := strings.TrimSpace(proxyURL.Hostname())
port := strings.TrimSpace(proxyURL.Port())
if host == "" {
return "", fmt.Errorf("proxy host is empty")
}
if port == "" {
port = "1080"
}
username := proxyURL.User.Username()
password, _ := proxyURL.User.Password()
if strings.ContainsAny(username, " \t\r\n") || strings.ContainsAny(password, " \t\r\n") {
return "", fmt.Errorf("proxychains credentials cannot contain whitespace")
}
var builder strings.Builder
builder.WriteString("strict_chain\n")
builder.WriteString("proxy_dns\n")
builder.WriteString("remote_dns_subnet 224\n")
builder.WriteString("tcp_connect_time_out 8000\n")
builder.WriteString("tcp_read_time_out 15000\n")
builder.WriteString("localnet 127.0.0.0/255.0.0.0\n")
builder.WriteString("localnet ::1/128\n")
builder.WriteString("[ProxyList]\n")
builder.WriteString("socks5 ")
builder.WriteString(host)
builder.WriteString(" ")
builder.WriteString(port)
if username != "" {
builder.WriteString(" ")
builder.WriteString(username)
if password != "" {
builder.WriteString(" ")
builder.WriteString(password)
}
}
builder.WriteString("\n")
return builder.String(), nil
}

View File

@ -1,31 +0,0 @@
package lspool
import (
"net/url"
"strings"
"testing"
"github.com/stretchr/testify/require"
)
func TestBuildProxychainsConfigIncludesAuthAndLocalBypass(t *testing.T) {
proxyURL, err := url.Parse("socks5h://testuser:testpass@192.0.2.1:1080")
require.NoError(t, err)
cfg, err := buildProxychainsConfig(proxyURL)
require.NoError(t, err)
require.Contains(t, cfg, "proxy_dns\n")
require.Contains(t, cfg, "localnet 127.0.0.0/255.0.0.0\n")
require.Contains(t, cfg, "localnet ::1/128\n")
require.Contains(t, cfg, "[ProxyList]\n")
require.Contains(t, cfg, "socks5 192.0.2.1 1080 testuser testpass\n")
}
func TestBuildProxychainsConfigRejectsWhitespaceCredentials(t *testing.T) {
proxyURL, err := url.Parse("socks5h://user:bad%20pass@127.0.0.1:1080")
require.NoError(t, err)
_, err = buildProxychainsConfig(proxyURL)
require.Error(t, err)
require.True(t, strings.Contains(err.Error(), "whitespace"))
}

View File

@ -1,99 +0,0 @@
package lspool
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
)
func (i *Instance) callWorkerUnary(ctx context.Context, service, method, mode string, body []byte) ([]byte, error) {
endpoint, err := i.workerEndpoint("/rpc/unary", service, method, mode)
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
if err != nil {
return nil, err
}
req.Header.Set("X-Worker-Token", i.workerToken)
if mode == "json" {
req.Header.Set("Content-Type", "application/json")
} else {
req.Header.Set("Content-Type", "application/octet-stream")
}
resp, err := i.client.Do(req)
if err != nil {
return nil, fmt.Errorf("worker rpc %s/%s: %w", service, method, err)
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("worker rpc read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return respBody, fmt.Errorf("worker rpc %s/%s HTTP %d: %s", service, method, resp.StatusCode, truncate(string(respBody), 200))
}
return respBody, nil
}
func (i *Instance) callWorkerStream(ctx context.Context, service, method, mode string, body []byte) (*http.Response, error) {
endpoint, err := i.workerEndpoint("/rpc/stream", service, method, mode)
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
if err != nil {
return nil, err
}
req.Header.Set("X-Worker-Token", i.workerToken)
if mode == "json" {
req.Header.Set("Content-Type", "application/json")
} else {
req.Header.Set("Content-Type", "application/octet-stream")
}
resp, err := i.client.Do(req)
if err != nil {
return nil, fmt.Errorf("worker stream rpc %s/%s: %w", service, method, err)
}
if resp.StatusCode != http.StatusOK {
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("worker stream rpc %s/%s HTTP %d: %s", service, method, resp.StatusCode, truncate(string(body), 200))
}
return resp, nil
}
func (i *Instance) workerEndpoint(path, service, method, mode string) (string, error) {
base := url.URL{
Scheme: "http",
Host: i.Address,
Path: path,
}
values := url.Values{}
values.Set("service", service)
values.Set("method", method)
values.Set("mode", mode)
if i.routingKey != "" {
values.Set("routing_key", i.routingKey)
}
base.RawQuery = values.Encode()
return base.String(), nil
}
func marshalWorkerJSONBody(input any) ([]byte, error) {
if input == nil {
return []byte("{}"), nil
}
body, err := json.Marshal(input)
if err != nil {
return nil, err
}
return body, nil
}

File diff suppressed because it is too large Load Diff

View File

@ -1,680 +0,0 @@
package lspool
import (
"bytes"
"context"
"crypto/sha256"
"encoding/json"
"fmt"
"io"
"log/slog"
"net"
"net/http"
"net/url"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/filters"
"github.com/docker/docker/api/types/network"
"github.com/docker/docker/client"
ocispec "github.com/opencontainers/image-spec/specs-go/v1"
)
const (
lsWorkerManagedByLabel = "managed-by"
lsWorkerManagedByValue = "sub2api"
lsWorkerAccountLabel = "account_id"
lsWorkerProxyHashLabel = "proxy_hash"
lsWorkerImageTagLabel = "image_tag"
lsWorkerControlPort = 18081
)
type workerManagerConfig struct {
Image string
Network string
DockerSocket string
IdleTTL time.Duration
MaxActive int
StartupTimeout time.Duration
RequestTimeout time.Duration
}
type dockerClient interface {
ContainerList(ctx context.Context, options container.ListOptions) ([]container.Summary, error)
ContainerCreate(ctx context.Context, config *container.Config, hostConfig *container.HostConfig, networkingConfig *network.NetworkingConfig, platform *ocispec.Platform, containerName string) (container.CreateResponse, error)
ContainerStart(ctx context.Context, containerID string, options container.StartOptions) error
ContainerInspect(ctx context.Context, containerID string) (container.InspectResponse, error)
ContainerStop(ctx context.Context, containerID string, options container.StopOptions) error
ContainerRemove(ctx context.Context, containerID string, options container.RemoveOptions) error
Close() error
}
type workerManager struct {
cfg workerManagerConfig
docker dockerClient
http *http.Client
mu sync.Mutex
workers map[string]*workerHandle
state map[string]*workerAccountState
ctx context.Context
cancel context.CancelFunc
logger *slog.Logger
}
type workerHandle struct {
Key string
AccountID string
ProxyURL string
ProxyHash string
ContainerID string
Container string
Address string
AuthToken string
LastUsed time.Time
LastStateSHA string
}
type workerAccountState struct {
HasToken bool `json:"has_token"`
AccessToken string `json:"access_token,omitempty"`
RefreshToken string `json:"refresh_token,omitempty"`
ExpiresAt *time.Time `json:"expires_at,omitempty"`
HasModelCredits bool `json:"has_model_credits"`
UseAICredits bool `json:"use_ai_credits"`
AvailableCredits *int32 `json:"available_credits,omitempty"`
MinimumCreditAmount *int32 `json:"minimum_credit_amount,omitempty"`
}
func NewWorkerManagerFromConfig(cfg *config.Config) (Backend, error) {
if cfg == nil {
return nil, fmt.Errorf("config is nil")
}
managerCfg := workerManagerConfig{
Image: strings.TrimSpace(cfg.Gateway.AntigravityLSWorker.Image),
Network: strings.TrimSpace(cfg.Gateway.AntigravityLSWorker.Network),
DockerSocket: strings.TrimSpace(cfg.Gateway.AntigravityLSWorker.DockerSocket),
IdleTTL: cfg.Gateway.AntigravityLSWorker.IdleTTL,
MaxActive: cfg.Gateway.AntigravityLSWorker.MaxActive,
StartupTimeout: cfg.Gateway.AntigravityLSWorker.StartupTimeout,
RequestTimeout: cfg.Gateway.AntigravityLSWorker.RequestTimeout,
}
if managerCfg.Image == "" {
managerCfg.Image = "weishaw/sub2api-lsworker:latest"
}
if managerCfg.Network == "" {
managerCfg.Network = "sub2api-network"
}
if managerCfg.DockerSocket == "" {
managerCfg.DockerSocket = "unix:///var/run/docker.sock"
}
if managerCfg.IdleTTL <= 0 {
managerCfg.IdleTTL = 15 * time.Minute
}
if managerCfg.MaxActive < 1 {
managerCfg.MaxActive = 50
}
if managerCfg.StartupTimeout <= 0 {
managerCfg.StartupTimeout = 45 * time.Second
}
if managerCfg.RequestTimeout <= 0 {
managerCfg.RequestTimeout = 60 * time.Second
}
opts := []client.Opt{client.WithAPIVersionNegotiation()}
if managerCfg.DockerSocket != "" {
opts = append(opts, client.WithHost(managerCfg.DockerSocket))
} else {
opts = append(opts, client.FromEnv)
}
dockerClient, err := client.NewClientWithOpts(opts...)
if err != nil {
return nil, fmt.Errorf("create docker client: %w", err)
}
return newWorkerManager(managerCfg, dockerClient)
}
func newWorkerManager(cfg workerManagerConfig, docker dockerClient) (*workerManager, error) {
ctx, cancel := context.WithCancel(context.Background())
mgr := &workerManager{
cfg: cfg,
docker: docker,
http: &http.Client{
Timeout: cfg.RequestTimeout,
Transport: &http.Transport{
Proxy: nil,
DialContext: (&net.Dialer{
Timeout: 5 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
MaxIdleConnsPerHost: 8,
},
},
workers: make(map[string]*workerHandle),
state: make(map[string]*workerAccountState),
ctx: ctx,
cancel: cancel,
logger: slog.Default().With("component", "lspool-worker-manager"),
}
if err := mgr.reconcileManagedContainers(ctx); err != nil {
cancel()
_ = docker.Close()
return nil, err
}
go mgr.cleanupLoop()
return mgr, nil
}
func (m *workerManager) Close() {
m.cancel()
m.mu.Lock()
workers := make([]*workerHandle, 0, len(m.workers))
for _, handle := range m.workers {
workers = append(workers, handle)
}
m.workers = make(map[string]*workerHandle)
m.mu.Unlock()
for _, handle := range workers {
m.removeWorkerContainer(context.Background(), handle)
}
if m.docker != nil {
_ = m.docker.Close()
}
}
func (m *workerManager) Stats() map[string]any {
m.mu.Lock()
defer m.mu.Unlock()
return map[string]any{
"accounts": len(m.state),
"total": len(m.workers),
"active": len(m.workers),
}
}
func (m *workerManager) SetAccountToken(accountID, accessToken, refreshToken string, expiresAt time.Time) {
m.mu.Lock()
defer m.mu.Unlock()
state := m.ensureStateLocked(accountID)
state.HasToken = true
state.AccessToken = accessToken
state.RefreshToken = refreshToken
if expiresAt.IsZero() {
state.ExpiresAt = nil
} else {
ts := expiresAt.UTC()
state.ExpiresAt = &ts
}
}
func (m *workerManager) SetAccountModelCredits(accountID string, useAICredits bool, availableCredits, minimumCreditAmountForUsage *int32) {
m.mu.Lock()
defer m.mu.Unlock()
state := m.ensureStateLocked(accountID)
state.HasModelCredits = true
state.UseAICredits = useAICredits
state.AvailableCredits = cloneInt32Ptr(availableCredits)
state.MinimumCreditAmount = cloneInt32Ptr(minimumCreditAmountForUsage)
}
func (m *workerManager) GetOrCreate(accountID, routingKey string, proxyURL ...string) (*Instance, error) {
rawProxy := ""
if len(proxyURL) > 0 {
rawProxy = proxyURL[0]
}
normalizedProxy, parsedProxy, err := resolveWorkerProxy(rawProxy)
if err != nil {
return nil, err
}
if parsedProxy == nil {
return nil, fmt.Errorf("ls worker requires a socks5/socks5h proxy for account %s", accountID)
}
replica := replicaSlotIndex(routingKey, parseLSReplicaCount())
proxyHash := proxyHash(normalizedProxy)
workerKey := buildWorkerKey(accountID, proxyHash)
m.mu.Lock()
state := cloneWorkerAccountState(m.state[accountID])
if state == nil || !state.HasToken || strings.TrimSpace(state.AccessToken) == "" {
m.mu.Unlock()
return nil, fmt.Errorf("ls worker missing access token for account %s", accountID)
}
handle := m.workers[workerKey]
if handle == nil {
if len(m.workers) >= m.cfg.MaxActive {
m.mu.Unlock()
return nil, fmt.Errorf("ls worker limit reached (%d active)", m.cfg.MaxActive)
}
handle, err = m.createWorkerLocked(accountID, normalizedProxy, proxyHash, parsedProxy)
if err != nil {
m.mu.Unlock()
return nil, err
}
m.workers[workerKey] = handle
}
handle.LastUsed = time.Now()
m.mu.Unlock()
if err := m.waitForWorkerHealthy(handle); err != nil {
return nil, err
}
if err := m.syncWorkerState(handle, state); err != nil {
return nil, err
}
if err := m.waitForWorkerReady(handle, routingKey); err != nil {
return nil, err
}
inst := &Instance{
AccountID: accountID,
Replica: replica,
Address: handle.Address,
client: m.http,
healthy: true,
lastUsed: time.Now(),
modelMapReady: 1,
remote: true,
workerToken: handle.AuthToken,
routingKey: routingKey,
}
return inst, nil
}
func (m *workerManager) cleanupLoop() {
ticker := time.NewTicker(time.Minute)
defer ticker.Stop()
for {
select {
case <-m.ctx.Done():
return
case <-ticker.C:
m.collectIdleWorkers()
}
}
}
func (m *workerManager) collectIdleWorkers() {
now := time.Now()
var expired []*workerHandle
m.mu.Lock()
for key, handle := range m.workers {
if handle == nil {
delete(m.workers, key)
continue
}
if now.Sub(handle.LastUsed) <= m.cfg.IdleTTL {
continue
}
expired = append(expired, handle)
delete(m.workers, key)
}
m.mu.Unlock()
for _, handle := range expired {
m.removeWorkerContainer(context.Background(), handle)
}
}
func (m *workerManager) reconcileManagedContainers(ctx context.Context) error {
args := filters.NewArgs()
args.Add("label", fmt.Sprintf("%s=%s", lsWorkerManagedByLabel, lsWorkerManagedByValue))
containers, err := m.docker.ContainerList(ctx, container.ListOptions{
All: true,
Filters: args,
})
if err != nil {
return fmt.Errorf("list managed ls workers: %w", err)
}
for _, summary := range containers {
handle := &workerHandle{
ContainerID: summary.ID,
Container: strings.TrimPrefix(firstContainerName(summary.Names), "/"),
}
if err := m.removeWorkerContainer(ctx, handle); err != nil {
return err
}
}
return nil
}
func (m *workerManager) createWorkerLocked(accountID, proxyURL, proxyHash string, parsedProxy *url.URL) (*workerHandle, error) {
containerName := fmt.Sprintf("sub2api-ls-%s-%s", accountID, proxyHash[:8])
authToken := generateUUID()
proxyHost := parsedProxy.Hostname()
proxyPort := parsedProxy.Port()
if proxyPort == "" {
proxyPort = "1080"
}
proxyUser := parsedProxy.User.Username()
proxyPass, _ := parsedProxy.User.Password()
labels := map[string]string{
lsWorkerManagedByLabel: lsWorkerManagedByValue,
lsWorkerAccountLabel: accountID,
lsWorkerProxyHashLabel: proxyHash,
lsWorkerImageTagLabel: m.cfg.Image,
}
env := []string{
"ANTIGRAVITY_APP_ROOT=/app/ls",
fmt.Sprintf("LSWORKER_ACCOUNT_ID=%s", accountID),
fmt.Sprintf("LSWORKER_AUTH_TOKEN=%s", authToken),
fmt.Sprintf("LSWORKER_LISTEN_ADDR=0.0.0.0:%d", lsWorkerControlPort),
fmt.Sprintf("LSWORKER_NETWORK_READY_FILE=%s", "/run/lsworker/network-ready"),
fmt.Sprintf("LSWORKER_PROXY_URL=%s", proxyURL),
fmt.Sprintf("LSWORKER_PROXY_HOST=%s", proxyHost),
fmt.Sprintf("LSWORKER_PROXY_PORT=%s", proxyPort),
fmt.Sprintf("LSWORKER_PROXY_USER=%s", proxyUser),
fmt.Sprintf("LSWORKER_PROXY_PASS=%s", proxyPass),
fmt.Sprintf("LSWORKER_CONTROL_PORT=%d", lsWorkerControlPort),
fmt.Sprintf("ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT=%d", parseLSReplicaCount()),
}
if tz := strings.TrimSpace(os.Getenv("TZ")); tz != "" {
env = append(env, "TZ="+tz)
}
createResp, err := m.docker.ContainerCreate(m.ctx, &container.Config{
Image: m.cfg.Image,
Labels: labels,
Env: env,
}, &container.HostConfig{
CapAdd: []string{"NET_ADMIN"},
}, &network.NetworkingConfig{
EndpointsConfig: map[string]*network.EndpointSettings{
m.cfg.Network: {},
},
}, nil, containerName)
if err != nil {
return nil, fmt.Errorf("create ls worker container: %w", err)
}
if err := m.docker.ContainerStart(m.ctx, createResp.ID, container.StartOptions{}); err != nil {
_ = m.docker.ContainerRemove(m.ctx, createResp.ID, container.RemoveOptions{Force: true})
return nil, fmt.Errorf("start ls worker container: %w", err)
}
inspect, err := m.docker.ContainerInspect(m.ctx, createResp.ID)
if err != nil {
_ = m.docker.ContainerRemove(m.ctx, createResp.ID, container.RemoveOptions{Force: true})
return nil, fmt.Errorf("inspect ls worker container: %w", err)
}
address, err := workerAddressFromInspect(inspect, m.cfg.Network)
if err != nil {
_ = m.docker.ContainerRemove(m.ctx, createResp.ID, container.RemoveOptions{Force: true})
return nil, err
}
m.logger.Info("created ls worker",
"account", shortAccountID(accountID),
"container", containerName,
"address", address,
"proxy_hash", proxyHash[:8])
return &workerHandle{
Key: buildWorkerKey(accountID, proxyHash),
AccountID: accountID,
ProxyURL: proxyURL,
ProxyHash: proxyHash,
ContainerID: createResp.ID,
Container: containerName,
Address: address,
AuthToken: authToken,
LastUsed: time.Now(),
}, nil
}
func workerAddressFromInspect(inspect container.InspectResponse, networkName string) (string, error) {
if inspect.NetworkSettings == nil {
return "", fmt.Errorf("ls worker inspect missing network settings")
}
if endpoint, ok := inspect.NetworkSettings.Networks[networkName]; ok && endpoint != nil && strings.TrimSpace(endpoint.IPAddress) != "" {
return net.JoinHostPort(strings.TrimSpace(endpoint.IPAddress), strconv.Itoa(lsWorkerControlPort)), nil
}
for _, endpoint := range inspect.NetworkSettings.Networks {
if endpoint != nil && strings.TrimSpace(endpoint.IPAddress) != "" {
return net.JoinHostPort(strings.TrimSpace(endpoint.IPAddress), strconv.Itoa(lsWorkerControlPort)), nil
}
}
return "", fmt.Errorf("ls worker missing IP address on network %s", networkName)
}
func firstContainerName(names []string) string {
if len(names) == 0 {
return ""
}
return names[0]
}
func (m *workerManager) waitForWorkerHealthy(handle *workerHandle) error {
ctx, cancel := context.WithTimeout(context.Background(), m.cfg.StartupTimeout)
defer cancel()
for {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, workerURL(handle, "/healthz", nil), nil)
if err != nil {
return err
}
req.Header.Set("X-Worker-Token", handle.AuthToken)
resp, err := m.http.Do(req)
if err == nil {
_ = resp.Body.Close()
if resp.StatusCode == http.StatusOK {
return nil
}
}
select {
case <-ctx.Done():
return fmt.Errorf("worker %s failed health check: %w", handle.Container, ctx.Err())
case <-time.After(500 * time.Millisecond):
}
}
}
func (m *workerManager) waitForWorkerReady(handle *workerHandle, routingKey string) error {
ctx, cancel := context.WithTimeout(context.Background(), m.cfg.StartupTimeout)
defer cancel()
values := url.Values{}
if strings.TrimSpace(routingKey) != "" {
values.Set("routing_key", routingKey)
}
var (
lastStatus int
lastBody string
)
for {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, workerURL(handle, "/readyz", values), nil)
if err != nil {
return err
}
req.Header.Set("X-Worker-Token", handle.AuthToken)
resp, err := m.http.Do(req)
if err == nil {
body, _ := io.ReadAll(resp.Body)
_ = resp.Body.Close()
lastStatus = resp.StatusCode
lastBody = truncate(string(body), 200)
if resp.StatusCode == http.StatusOK {
return nil
}
if isWorkerModelMappingUnavailable(resp.StatusCode, lastBody) {
return fmt.Errorf("%w: worker %s %s", errLSModelMapDenied, handle.Container, strings.TrimSpace(lastBody))
}
if len(body) > 0 && shouldWarnWorkerNotReady(resp.StatusCode, lastBody) {
m.logger.Warn("ls worker not ready yet", "container", handle.Container, "status", resp.StatusCode, "body", truncate(string(body), 200))
}
}
select {
case <-ctx.Done():
if lastStatus > 0 || lastBody != "" {
return fmt.Errorf("worker %s not ready for routing key %q (last_status=%d last_body=%q): %w", handle.Container, routingKey, lastStatus, lastBody, ctx.Err())
}
return fmt.Errorf("worker %s not ready for routing key %q: %w", handle.Container, routingKey, ctx.Err())
case <-time.After(500 * time.Millisecond):
}
}
}
func shouldWarnWorkerNotReady(status int, body string) bool {
if status == http.StatusServiceUnavailable {
normalized := strings.ToLower(strings.TrimSpace(body))
if strings.Contains(normalized, "model mapping not ready") {
return false
}
}
return true
}
func isWorkerModelMappingUnavailable(status int, body string) bool {
if status != http.StatusServiceUnavailable {
return false
}
normalized := strings.ToLower(strings.TrimSpace(body))
return strings.Contains(normalized, errLSModelMapDenied.Error())
}
func (m *workerManager) syncWorkerState(handle *workerHandle, state *workerAccountState) error {
if state == nil {
return fmt.Errorf("ls worker state is nil")
}
body, err := json.Marshal(state)
if err != nil {
return fmt.Errorf("marshal worker state: %w", err)
}
sum := fmt.Sprintf("%x", sha256.Sum256(body))
if handle.LastStateSHA == sum {
return nil
}
ctx, cancel := context.WithTimeout(context.Background(), m.cfg.RequestTimeout)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodPost, workerURL(handle, "/account/state", nil), bytes.NewReader(body))
if err != nil {
return err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("X-Worker-Token", handle.AuthToken)
resp, err := m.http.Do(req)
if err != nil {
return fmt.Errorf("sync worker state: %w", err)
}
defer resp.Body.Close()
respBody, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("sync worker state HTTP %d: %s", resp.StatusCode, truncate(string(respBody), 200))
}
handle.LastStateSHA = sum
return nil
}
func workerURL(handle *workerHandle, path string, values url.Values) string {
base := url.URL{
Scheme: "http",
Host: handle.Address,
Path: path,
}
if values != nil {
base.RawQuery = values.Encode()
}
return base.String()
}
func (m *workerManager) removeWorkerContainer(ctx context.Context, handle *workerHandle) error {
if handle == nil || strings.TrimSpace(handle.ContainerID) == "" {
return nil
}
timeout := 5
_ = m.docker.ContainerStop(ctx, handle.ContainerID, container.StopOptions{Timeout: &timeout})
if err := m.docker.ContainerRemove(ctx, handle.ContainerID, container.RemoveOptions{Force: true}); err != nil {
return fmt.Errorf("remove ls worker container %s: %w", handle.ContainerID, err)
}
return nil
}
func (m *workerManager) ensureStateLocked(accountID string) *workerAccountState {
state := m.state[accountID]
if state == nil {
state = &workerAccountState{}
m.state[accountID] = state
}
return state
}
func resolveWorkerProxy(proxyURL string) (string, *url.URL, error) {
resolved := resolveLSProxy(proxyURL)
normalized, parsed, err := proxyurl.Parse(resolved)
if err != nil {
return "", nil, err
}
if parsed == nil {
return "", nil, nil
}
switch strings.ToLower(parsed.Scheme) {
case "socks5", "socks5h":
return normalized, parsed, nil
default:
return "", nil, fmt.Errorf("ls worker only supports socks5/socks5h proxies, got %s", parsed.Scheme)
}
}
func proxyHash(proxyURL string) string {
if strings.TrimSpace(proxyURL) == "" {
return "direct"
}
sum := sha256.Sum256([]byte(strings.TrimSpace(proxyURL)))
return fmt.Sprintf("%x", sum[:])
}
func buildWorkerKey(accountID, proxyHash string) string {
return accountID + ":" + proxyHash
}
func cloneInt32Ptr(v *int32) *int32 {
if v == nil {
return nil
}
cp := *v
return &cp
}
func cloneWorkerAccountState(state *workerAccountState) *workerAccountState {
if state == nil {
return nil
}
cp := *state
cp.AvailableCredits = cloneInt32Ptr(state.AvailableCredits)
cp.MinimumCreditAmount = cloneInt32Ptr(state.MinimumCreditAmount)
if state.ExpiresAt != nil {
ts := *state.ExpiresAt
cp.ExpiresAt = &ts
}
return &cp
}

View File

@ -1,335 +0,0 @@
package lspool
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/filters"
"github.com/docker/docker/api/types/network"
ocispec "github.com/opencontainers/image-spec/specs-go/v1"
"github.com/stretchr/testify/require"
)
type fakeDockerClient struct {
mu sync.Mutex
listResp []container.Summary
listCalls int
createCalls int
startCalls int
stopCalls int
removeCalls int
inspectCalls int
removedIDs []string
createdConfigs []*container.Config
inspectResp container.InspectResponse
}
func (f *fakeDockerClient) ContainerList(ctx context.Context, options container.ListOptions) ([]container.Summary, error) {
f.mu.Lock()
defer f.mu.Unlock()
f.listCalls++
return append([]container.Summary(nil), f.listResp...), nil
}
func (f *fakeDockerClient) ContainerCreate(ctx context.Context, cfg *container.Config, hostConfig *container.HostConfig, networkingConfig *network.NetworkingConfig, platform *ocispec.Platform, containerName string) (container.CreateResponse, error) {
f.mu.Lock()
defer f.mu.Unlock()
f.createCalls++
f.createdConfigs = append(f.createdConfigs, cfg)
return container.CreateResponse{ID: "worker-created"}, nil
}
func (f *fakeDockerClient) ContainerStart(ctx context.Context, containerID string, options container.StartOptions) error {
f.mu.Lock()
defer f.mu.Unlock()
f.startCalls++
return nil
}
func (f *fakeDockerClient) ContainerInspect(ctx context.Context, containerID string) (container.InspectResponse, error) {
f.mu.Lock()
defer f.mu.Unlock()
f.inspectCalls++
return f.inspectResp, nil
}
func (f *fakeDockerClient) ContainerStop(ctx context.Context, containerID string, options container.StopOptions) error {
f.mu.Lock()
defer f.mu.Unlock()
f.stopCalls++
return nil
}
func (f *fakeDockerClient) ContainerRemove(ctx context.Context, containerID string, options container.RemoveOptions) error {
f.mu.Lock()
defer f.mu.Unlock()
f.removeCalls++
f.removedIDs = append(f.removedIDs, containerID)
return nil
}
func (f *fakeDockerClient) Close() error { return nil }
func TestResolveWorkerProxyRejectsHTTP(t *testing.T) {
_, _, err := resolveWorkerProxy("http://127.0.0.1:7890")
require.Error(t, err)
require.Contains(t, err.Error(), "only supports socks5/socks5h")
}
func TestProxyHashUsesNormalizedProxy(t *testing.T) {
normalized, _, err := resolveWorkerProxy("socks5://user:pass@127.0.0.1:1080")
require.NoError(t, err)
require.Equal(t, "socks5h://user:pass@127.0.0.1:1080", normalized)
hash1 := proxyHash(normalized)
hash2 := proxyHash("socks5h://user:pass@127.0.0.1:1080")
require.Equal(t, hash1, hash2)
}
func TestWorkerManagerRequiresToken(t *testing.T) {
fakeDocker := &fakeDockerClient{}
manager, err := newWorkerManager(workerManagerConfig{
Image: "worker:latest",
Network: "sub2api-network",
DockerSocket: "unix:///var/run/docker.sock",
IdleTTL: time.Minute,
MaxActive: 2,
StartupTimeout: time.Second,
RequestTimeout: time.Second,
}, fakeDocker)
require.NoError(t, err)
defer manager.Close()
_, err = manager.GetOrCreate("9", "rk-1", "socks5h://user:pass@127.0.0.1:1080")
require.Error(t, err)
require.Contains(t, err.Error(), "missing access token")
}
func TestWorkerManagerReusesExistingHandleAndDedupesStateSync(t *testing.T) {
var mu sync.Mutex
var healthCalls int
var readyCalls int
var stateCalls int
var stateBodies [][]byte
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/healthz":
mu.Lock()
healthCalls++
mu.Unlock()
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
case "/readyz":
mu.Lock()
readyCalls++
mu.Unlock()
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ready"))
case "/account/state":
body, _ := io.ReadAll(r.Body)
mu.Lock()
stateCalls++
stateBodies = append(stateBodies, body)
mu.Unlock()
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
default:
http.NotFound(w, r)
}
}))
defer server.Close()
fakeDocker := &fakeDockerClient{}
manager, err := newWorkerManager(workerManagerConfig{
Image: "worker:latest",
Network: "sub2api-network",
DockerSocket: "unix:///var/run/docker.sock",
IdleTTL: time.Minute,
MaxActive: 4,
StartupTimeout: time.Second,
RequestTimeout: time.Second,
}, fakeDocker)
require.NoError(t, err)
defer manager.Close()
accountID := "9"
proxyURL := "socks5h://user:pass@127.0.0.1:1080"
hash := proxyHash(proxyURL)
key := buildWorkerKey(accountID, hash)
manager.SetAccountToken(accountID, "ya29.test", "refresh", time.Now().Add(time.Hour))
manager.mu.Lock()
manager.workers[key] = &workerHandle{
Key: key,
AccountID: accountID,
ProxyURL: proxyURL,
ProxyHash: hash,
ContainerID: "existing-worker",
Container: "sub2api-ls-9-test",
Address: strings.TrimPrefix(server.URL, "http://"),
AuthToken: "worker-token",
LastUsed: time.Now(),
}
manager.mu.Unlock()
inst1, err := manager.GetOrCreate(accountID, "rk-1", proxyURL)
require.NoError(t, err)
require.True(t, inst1.remote)
require.Equal(t, replicaSlotIndex("rk-1", parseLSReplicaCount()), inst1.Replica)
inst2, err := manager.GetOrCreate(accountID, "rk-1", proxyURL)
require.NoError(t, err)
require.True(t, inst2.remote)
mu.Lock()
defer mu.Unlock()
require.GreaterOrEqual(t, healthCalls, 2)
require.GreaterOrEqual(t, readyCalls, 2)
require.Equal(t, 1, stateCalls, "state sync should be skipped when the payload hash is unchanged")
require.Len(t, stateBodies, 1)
var synced workerAccountState
require.NoError(t, json.Unmarshal(stateBodies[0], &synced))
require.True(t, synced.HasToken)
require.Equal(t, "ya29.test", synced.AccessToken)
}
func TestWorkerManagerMaxActiveStopsNewWorkerCreation(t *testing.T) {
fakeDocker := &fakeDockerClient{}
manager, err := newWorkerManager(workerManagerConfig{
Image: "worker:latest",
Network: "sub2api-network",
DockerSocket: "unix:///var/run/docker.sock",
IdleTTL: time.Minute,
MaxActive: 1,
StartupTimeout: time.Second,
RequestTimeout: time.Second,
}, fakeDocker)
require.NoError(t, err)
defer manager.Close()
manager.SetAccountToken("9", "ya29.test", "refresh", time.Now().Add(time.Hour))
manager.mu.Lock()
manager.workers["existing"] = &workerHandle{ContainerID: "existing", Container: "existing", LastUsed: time.Now()}
manager.mu.Unlock()
_, err = manager.GetOrCreate("9", "rk-new", "socks5h://user:pass@127.0.0.1:1080")
require.Error(t, err)
require.Contains(t, err.Error(), "limit reached")
require.Equal(t, 0, fakeDocker.createCalls)
}
func TestWorkerManagerReconcileRemovesManagedContainers(t *testing.T) {
fakeDocker := &fakeDockerClient{
listResp: []container.Summary{
{
ID: "old-worker-1",
Names: []string{"/sub2api-ls-9-deadbeef"},
},
{
ID: "old-worker-2",
Names: []string{"/sub2api-ls-10-beadfeed"},
},
},
}
manager, err := newWorkerManager(workerManagerConfig{
Image: "worker:latest",
Network: "sub2api-network",
DockerSocket: "unix:///var/run/docker.sock",
IdleTTL: time.Minute,
MaxActive: 4,
StartupTimeout: time.Second,
RequestTimeout: time.Second,
}, fakeDocker)
require.NoError(t, err)
defer manager.Close()
require.Equal(t, 1, fakeDocker.listCalls)
require.ElementsMatch(t, []string{"old-worker-1", "old-worker-2"}, fakeDocker.removedIDs)
}
func TestFakeDockerClientImplementsFilterAwareList(t *testing.T) {
fakeDocker := &fakeDockerClient{}
_, err := fakeDocker.ContainerList(context.Background(), container.ListOptions{Filters: filters.NewArgs()})
require.NoError(t, err)
}
func TestShouldWarnWorkerNotReadySuppressesModelMappingPending(t *testing.T) {
require.False(t, shouldWarnWorkerNotReady(http.StatusServiceUnavailable, "worker model mapping not ready for replica 0"))
require.True(t, shouldWarnWorkerNotReady(http.StatusServiceUnavailable, "worker access token not configured"))
require.True(t, shouldWarnWorkerNotReady(http.StatusBadGateway, "upstream failed"))
}
func TestWorkerManagerWaitForWorkerReadyStopsOnModelMappingUnavailable(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "/readyz", r.URL.Path)
w.WriteHeader(http.StatusServiceUnavailable)
_, _ = w.Write([]byte(`model mapping unavailable for replica 0: oauth2: "unauthorized_client" "Unauthorized"`))
}))
defer server.Close()
manager, err := newWorkerManager(workerManagerConfig{
Image: "worker:latest",
Network: "sub2api-network",
DockerSocket: "unix:///var/run/docker.sock",
IdleTTL: time.Minute,
MaxActive: 1,
StartupTimeout: time.Second,
RequestTimeout: time.Second,
}, &fakeDockerClient{})
require.NoError(t, err)
defer manager.Close()
handle := &workerHandle{
Container: "sub2api-ls-test",
Address: strings.TrimPrefix(server.URL, "http://"),
AuthToken: "worker-token",
}
err = manager.waitForWorkerReady(handle, "")
require.Error(t, err)
require.ErrorIs(t, err, errLSModelMapDenied)
}
func TestWorkerManagerWaitForWorkerReadyIncludesLastBodyOnTimeout(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "/readyz", r.URL.Path)
w.WriteHeader(http.StatusServiceUnavailable)
_, _ = w.Write([]byte("worker model mapping not ready for replica 0\n"))
}))
defer server.Close()
manager, err := newWorkerManager(workerManagerConfig{
Image: "worker:latest",
Network: "sub2api-network",
DockerSocket: "unix:///var/run/docker.sock",
IdleTTL: time.Minute,
MaxActive: 1,
StartupTimeout: 100 * time.Millisecond,
RequestTimeout: time.Second,
}, &fakeDockerClient{})
require.NoError(t, err)
defer manager.Close()
handle := &workerHandle{
Container: "sub2api-ls-test",
Address: strings.TrimPrefix(server.URL, "http://"),
AuthToken: "worker-token",
}
err = manager.waitForWorkerReady(handle, "")
require.Error(t, err)
require.Contains(t, err.Error(), `last_status=503`)
require.Contains(t, err.Error(), `last_body="worker model mapping not ready for replica 0`)
}

View File

@ -1,374 +0,0 @@
package lspool
import (
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"os"
"strconv"
"strings"
"sync"
"time"
)
type WorkerServerConfig struct {
AccountID string
AuthToken string
ListenAddr string
AppRoot string
NetworkReadyFile string
MaxIdleTime time.Duration
HealthInterval time.Duration
}
type WorkerServer struct {
cfg WorkerServerConfig
pool *Pool
logger *slog.Logger
mu sync.RWMutex
state workerAccountState
}
func NewWorkerServer(cfg WorkerServerConfig) (*WorkerServer, error) {
if strings.TrimSpace(cfg.AccountID) == "" {
return nil, fmt.Errorf("worker account id is required")
}
if strings.TrimSpace(cfg.AuthToken) == "" {
return nil, fmt.Errorf("worker auth token is required")
}
if strings.TrimSpace(cfg.ListenAddr) == "" {
cfg.ListenAddr = fmt.Sprintf("0.0.0.0:%d", lsWorkerControlPort)
}
if strings.TrimSpace(cfg.AppRoot) == "" {
cfg.AppRoot = "/app/ls"
}
if cfg.MaxIdleTime <= 0 {
cfg.MaxIdleTime = 15 * time.Minute
}
if cfg.HealthInterval <= 0 {
cfg.HealthInterval = 30 * time.Second
}
poolCfg := DefaultConfig()
poolCfg.AppRoot = cfg.AppRoot
poolCfg.MaxIdleTime = cfg.MaxIdleTime
poolCfg.HealthCheckInterval = cfg.HealthInterval
return &WorkerServer{
cfg: cfg,
pool: NewPool(poolCfg),
logger: slog.Default().With("component", "lsworker"),
}, nil
}
func NewWorkerServerFromEnv() (*WorkerServer, error) {
maxIdleTime := 15 * time.Minute
if raw := strings.TrimSpace(os.Getenv("LSWORKER_POOL_MAX_IDLE_TIME")); raw != "" {
if parsed, err := time.ParseDuration(raw); err == nil {
maxIdleTime = parsed
}
}
healthInterval := 30 * time.Second
if raw := strings.TrimSpace(os.Getenv("LSWORKER_POOL_HEALTH_INTERVAL")); raw != "" {
if parsed, err := time.ParseDuration(raw); err == nil {
healthInterval = parsed
}
}
return NewWorkerServer(WorkerServerConfig{
AccountID: strings.TrimSpace(os.Getenv("LSWORKER_ACCOUNT_ID")),
AuthToken: strings.TrimSpace(os.Getenv("LSWORKER_AUTH_TOKEN")),
ListenAddr: strings.TrimSpace(os.Getenv("LSWORKER_LISTEN_ADDR")),
AppRoot: strings.TrimSpace(os.Getenv("ANTIGRAVITY_APP_ROOT")),
NetworkReadyFile: strings.TrimSpace(os.Getenv("LSWORKER_NETWORK_READY_FILE")),
MaxIdleTime: maxIdleTime,
HealthInterval: healthInterval,
})
}
func (s *WorkerServer) Close() {
if s.pool != nil {
s.pool.Close()
}
}
func (s *WorkerServer) Handler() http.Handler {
mux := http.NewServeMux()
mux.HandleFunc("/healthz", s.handleHealthz)
mux.HandleFunc("/readyz", s.handleReadyz)
mux.HandleFunc("/account/state", s.handleAccountState)
mux.HandleFunc("/rpc/unary", s.handleRPCUnary)
mux.HandleFunc("/rpc/stream", s.handleRPCStream)
return mux
}
func (s *WorkerServer) handleHealthz(w http.ResponseWriter, r *http.Request) {
if !s.authorize(w, r) {
return
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}
func (s *WorkerServer) handleReadyz(w http.ResponseWriter, r *http.Request) {
if !s.authorize(w, r) {
return
}
routingKey := strings.TrimSpace(r.URL.Query().Get("routing_key"))
inst, err := s.ensureReady(r.Context(), routingKey)
if err != nil {
http.Error(w, err.Error(), http.StatusServiceUnavailable)
return
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(fmt.Sprintf("ready replica=%d", inst.Replica)))
}
func (s *WorkerServer) handleAccountState(w http.ResponseWriter, r *http.Request) {
if !s.authorize(w, r) {
return
}
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
defer r.Body.Close()
var payload workerAccountState
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
http.Error(w, "invalid account state payload", http.StatusBadRequest)
return
}
s.mu.Lock()
s.state = *cloneWorkerAccountState(&payload)
s.mu.Unlock()
s.applyState()
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}
func (s *WorkerServer) handleRPCUnary(w http.ResponseWriter, r *http.Request) {
if !s.authorize(w, r) {
return
}
service, method, mode, routingKey, ok := parseRPCRequest(w, r)
if !ok {
return
}
inst, err := s.ensureReady(r.Context(), routingKey)
if err != nil {
http.Error(w, err.Error(), http.StatusServiceUnavailable)
return
}
body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, "read request body failed", http.StatusBadRequest)
return
}
if len(body) == 0 {
body = []byte("{}")
}
var respBody []byte
switch mode {
case "json":
var input any
if err := json.Unmarshal(body, &input); err != nil {
http.Error(w, "invalid json rpc body", http.StatusBadRequest)
return
}
respBody, err = inst.CallUnaryJSON(r.Context(), service, method, input)
case "proto":
respBody, err = inst.CallRPC(r.Context(), service, method, body)
default:
http.Error(w, "unsupported rpc mode", http.StatusBadRequest)
return
}
if err != nil {
http.Error(w, err.Error(), http.StatusBadGateway)
return
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write(respBody)
}
func (s *WorkerServer) handleRPCStream(w http.ResponseWriter, r *http.Request) {
if !s.authorize(w, r) {
return
}
service, method, mode, routingKey, ok := parseRPCRequest(w, r)
if !ok {
return
}
inst, err := s.ensureReady(r.Context(), routingKey)
if err != nil {
http.Error(w, err.Error(), http.StatusServiceUnavailable)
return
}
body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, "read request body failed", http.StatusBadRequest)
return
}
var resp *http.Response
switch mode {
case "json":
var input any
if len(body) == 0 {
body = []byte("{}")
}
if err := json.Unmarshal(body, &input); err != nil {
http.Error(w, "invalid json rpc body", http.StatusBadRequest)
return
}
resp, err = inst.StreamRPCJSON(r.Context(), service, method, input)
case "proto":
resp, err = inst.StreamRPC(r.Context(), service, method, body)
default:
http.Error(w, "unsupported rpc mode", http.StatusBadRequest)
return
}
if err != nil {
http.Error(w, err.Error(), http.StatusBadGateway)
return
}
defer resp.Body.Close()
for key, values := range resp.Header {
for _, value := range values {
w.Header().Add(key, value)
}
}
w.WriteHeader(resp.StatusCode)
_, _ = io.Copy(w, resp.Body)
}
func (s *WorkerServer) authorize(w http.ResponseWriter, r *http.Request) bool {
if subtleHeaderEqual(r.Header.Get("X-Worker-Token"), s.cfg.AuthToken) {
return true
}
http.Error(w, "unauthorized", http.StatusUnauthorized)
return false
}
func subtleHeaderEqual(left, right string) bool {
if left == "" || right == "" {
return false
}
return left == right
}
func parseRPCRequest(w http.ResponseWriter, r *http.Request) (service, method, mode, routingKey string, ok bool) {
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return "", "", "", "", false
}
query := r.URL.Query()
service = strings.TrimSpace(query.Get("service"))
method = strings.TrimSpace(query.Get("method"))
mode = strings.ToLower(strings.TrimSpace(query.Get("mode")))
routingKey = strings.TrimSpace(query.Get("routing_key"))
if service == "" || method == "" {
http.Error(w, "missing rpc target", http.StatusBadRequest)
return "", "", "", "", false
}
if mode == "" {
mode = "proto"
}
return service, method, mode, routingKey, true
}
func (s *WorkerServer) ensureReady(ctx context.Context, routingKey string) (*Instance, error) {
if path := strings.TrimSpace(s.cfg.NetworkReadyFile); path != "" {
if _, err := os.Stat(path); err != nil {
return nil, fmt.Errorf("worker network not ready: %w", err)
}
}
s.applyState()
s.mu.RLock()
state := cloneWorkerAccountState(&s.state)
s.mu.RUnlock()
if state == nil || !state.HasToken || strings.TrimSpace(state.AccessToken) == "" {
return nil, fmt.Errorf("worker access token not configured")
}
inst, err := s.pool.GetOrCreate(s.cfg.AccountID, routingKey, "")
if err != nil {
return nil, err
}
if inst.HasModelMappingUnavailable() {
return nil, fmt.Errorf("%w for replica %d: %s", errLSModelMapDenied, inst.Replica, inst.ModelMappingUnavailableReason())
}
if inst.HasModelMappingReady() {
return inst, nil
}
modelCtx, cancel := context.WithTimeout(ctx, lsModelConfigTimeout)
defer cancel()
_ = modelCtx
if !RefreshModelMapping(inst) {
if inst.HasModelMappingUnavailable() {
return nil, fmt.Errorf("%w for replica %d: %s", errLSModelMapDenied, inst.Replica, inst.ModelMappingUnavailableReason())
}
return nil, fmt.Errorf("worker model mapping not ready for replica %d", inst.Replica)
}
return inst, nil
}
func (s *WorkerServer) applyState() {
s.mu.RLock()
state := cloneWorkerAccountState(&s.state)
s.mu.RUnlock()
if state == nil {
return
}
if state.HasToken {
expiresAt := time.Time{}
if state.ExpiresAt != nil {
expiresAt = state.ExpiresAt.UTC()
}
s.pool.SetAccountToken(s.cfg.AccountID, state.AccessToken, state.RefreshToken, expiresAt)
}
if state.HasModelCredits {
s.pool.SetAccountModelCredits(s.cfg.AccountID, state.UseAICredits, state.AvailableCredits, state.MinimumCreditAmount)
}
}
func workerHTTPServer(listenAddr string, handler http.Handler) *http.Server {
return &http.Server{
Addr: listenAddr,
Handler: handler,
ReadHeaderTimeout: 10 * time.Second,
}
}
func workerExitCode(err error) int {
if err == nil {
return 0
}
return 1
}
func parseWorkerControlPort() int {
raw := strings.TrimSpace(os.Getenv("LSWORKER_CONTROL_PORT"))
if raw == "" {
return lsWorkerControlPort
}
port, err := strconv.Atoi(raw)
if err != nil || port < 1 {
return lsWorkerControlPort
}
return port
}

View File

@ -1,11 +1,20 @@
package routes
import (
"bytes"
"context"
"io"
"net/http"
"time"
"github.com/gin-gonic/gin"
)
const (
anthropicEventLoggingURL = "https://api.anthropic.com/api/event_logging/batch"
eventLoggingForwardTimeout = 8 * time.Second
)
// RegisterCommonRoutes 注册通用路由(健康检查、状态等)
func RegisterCommonRoutes(r *gin.Engine) {
// 健康检查
@ -13,8 +22,36 @@ func RegisterCommonRoutes(r *gin.Engine) {
c.JSON(http.StatusOK, gin.H{"status": "ok"})
})
// Claude Code 遥测日志忽略直接返回200
// Claude Code 遥测日志:清理敏感字段后转发给 Anthropic。
// 删除 baseUrl/gateway 字段防止网关地址暴露(见 FINGERPRINT_SECURITY_REPORT.md §GAP-1/2
// 转发而非丢弃,避免"高流量零遥测"异常被检测。
r.POST("/api/event_logging/batch", func(c *gin.Context) {
body, err := io.ReadAll(c.Request.Body)
if err != nil || len(body) == 0 {
c.Status(http.StatusOK)
return
}
sanitized := sanitizeEventBatch(body)
ctx, cancel := context.WithTimeout(c.Request.Context(), eventLoggingForwardTimeout)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodPost, anthropicEventLoggingURL, bytes.NewReader(sanitized))
if err != nil {
c.Status(http.StatusOK)
return
}
req.Header.Set("Content-Type", "application/json")
// 透传客户端的 Authorization headerOAuth Bearer token
if auth := c.GetHeader("Authorization"); auth != "" {
req.Header.Set("Authorization", auth)
}
resp, err := http.DefaultClient.Do(req)
if err == nil {
resp.Body.Close()
}
c.Status(http.StatusOK)
})

View File

@ -4,6 +4,9 @@ import (
"crypto/sha256"
"encoding/hex"
"fmt"
"regexp"
"sync"
"time"
"github.com/google/uuid"
"github.com/tidwall/gjson"
@ -12,7 +15,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
// Attribution block constants matching real Claude Code 2.1.88.
// Attribution block constants matching real Claude Code 2.1.89.
// Source: src/constants/system.ts + src/utils/fingerprint.ts
const (
// fingerprintSalt must match the hardcoded salt in the real CLI.
@ -81,11 +84,10 @@ func extractFirstUserMessageText(body []byte) string {
// Source: extracted/src/constants/system.ts:73-95
func buildAttributionBlock(cliVersion, fingerprint string) string {
version := cliVersion + "." + fingerprint
// 注意cch 字段由 Bun 的 NATIVE_CLIENT_ATTESTATION 编译时 feature 控制。
// npm 安装版本(非原生二进制)此 feature 为 false所以不包含 cch 字段。
// 只有原生二进制安装Bun 打包)才会有 cch且其值会被 Bun 的 Zig 层替换为真实 hash。
// 我们模拟 npm 安装版本的行为:不包含 cch。
return fmt.Sprintf("x-anthropic-billing-header: cc_version=%s; cc_entrypoint=cli;", version)
// 2.1.89 起 cch=00000 出现在所有安装模式(含 npm 版),不再只限于原生二进制。
// 原生二进制由 Bun 的 Zig 层在运行时将 00000 替换为真实 attestation hash
// 普通安装版保持 00000 占位符不变。
return fmt.Sprintf("x-anthropic-billing-header: cc_version=%s; cc_entrypoint=cli; cch=00000;", version)
}
// injectAttributionBlock prepends the x-anthropic-billing-header attribution block
@ -163,20 +165,89 @@ func injectAttributionBlock(body []byte, cliVersion string) []byte {
}
}
// generateSessionIDForAccount generates a deterministic per-account session UUID
// that remains stable within a process-like timeframe.
// Uses instanceSalt + accountID to ensure uniqueness across sub2api instances.
func generateSessionIDForAccount(instanceSalt string, accountID int64) string {
// Use a per-account stable UUID (like real CLI's per-process UUID).
// We use accountID as the base — each account gets a different "session".
seed := fmt.Sprintf("session:%s:%d", instanceSalt, accountID)
hash := sha256.Sum256([]byte(seed))
sessionUUID, err := uuid.FromBytes(hash[:16])
if err != nil {
return uuid.New().String()
}
// Set UUID v4 variant
sessionUUID[6] = (sessionUUID[6] & 0x0f) | 0x40
sessionUUID[8] = (sessionUUID[8] & 0x3f) | 0x80
return sessionUUID.String()
// cliSessionEntry holds a cached session UUID with an expiration time.
type cliSessionEntry struct {
id string
expiresAt time.Time
}
// cliSessionCache stores per-account session UUIDs that rotate on a TTL.
// Real CLI creates a new random UUID per process invocation; we approximate
// this by rotating every 30-60 minutes (jittered per account).
var (
cliSessionCache = make(map[int64]cliSessionEntry)
cliSessionCacheMu sync.Mutex
)
// sessionTTLBase is the base TTL for session ID rotation.
const sessionTTLBase = 30 * time.Minute
// generateSessionIDForAccount returns a per-account session UUID that rotates
// periodically. Each account gets a random TTL jitter (0-30 min on top of
// the 30 min base) so accounts don't all rotate simultaneously.
func generateSessionIDForAccount(instanceSalt string, accountID int64) string {
cliSessionCacheMu.Lock()
defer cliSessionCacheMu.Unlock()
now := time.Now()
if entry, ok := cliSessionCache[accountID]; ok && now.Before(entry.expiresAt) {
return entry.id
}
// Compute per-account jitter from a hash so the same account always gets
// the same jitter within a process (avoids re-rolling on every rotation).
jitterSeed := fmt.Sprintf("jitter:%s:%d", instanceSalt, accountID)
h := sha256.Sum256([]byte(jitterSeed))
jitterMinutes := int(h[0]) % 31 // 0-30 minutes
ttl := sessionTTLBase + time.Duration(jitterMinutes)*time.Minute
newID := uuid.New().String()
cliSessionCache[accountID] = cliSessionEntry{
id: newID,
expiresAt: now.Add(ttl),
}
return newID
}
// reUserHome matches /Users/<username>/ or /home/<username>/ path segments.
// Captures the prefix (/Users/ or /home/) so we can preserve it while replacing the username.
var reUserHome = regexp.MustCompile(`(/(Users|home)/)[^/\s"']+/`)
// reEnvLine matches lines of the form "Key: value" for the environment block
// fields injected by Claude Code's CLAUDE.md / sysprompt machinery.
var reEnvLine = regexp.MustCompile(`(?m)^(Platform|Shell|OS Version|Working directory):.*$`)
// canonicalEnvValues maps environment block keys to their canonical replacements.
// Values mirror cc-gateway's prompt_env config and represent a stock macOS dev machine.
var canonicalEnvValues = map[string]string{
"Platform": "Platform: darwin",
"Shell": "Shell: zsh",
"OS Version": "OS Version: Darwin 24.4.0",
"Working directory": "Working directory: /Users/user/project",
}
// NormalizeSystemPromptEnv rewrites environment-specific fields in a system
// prompt text block to canonical values, preventing real machine fingerprinting.
//
// Handles two classes of leakage (matching cc-gateway rewriter.ts:rewritePromptText):
// 1. "Platform: Windows / Linux / Darwin 25.x" → canonical darwin/zsh/Darwin 24.4.0
// 2. "/Users/alice/" or "/home/bob/" → "/Users/user/"
//
// Only called on system prompt text blocks, never on user message content.
func NormalizeSystemPromptEnv(text string) string {
// Replace env-info lines with canonical values
text = reEnvLine.ReplaceAllStringFunc(text, func(line string) string {
for key, canonical := range canonicalEnvValues {
if len(line) >= len(key) && line[:len(key)] == key {
return canonical
}
}
return line
})
// Redact real usernames in home directory paths
// e.g. /Users/alice/project -> /Users/user/project
text = reUserHome.ReplaceAllString(text, "${1}user/")
return text
}

View File

@ -895,6 +895,9 @@ func sanitizeSystemText(text string) string {
"You are OpenCode, the best coding agent on the planet.",
strings.TrimSpace(claudeCodeSystemPrompt),
)
// Normalize environment block fields (Platform/Shell/OS Version/Working directory)
// to canonical values so different client machines don't create fingerprint divergence.
text = NormalizeSystemPromptEnv(text)
return text
}
@ -5773,7 +5776,7 @@ func (s *GatewayService) getBetaHeader(modelID string, clientBetaHeader string)
return claude.HaikuBetaHeader
}
return claude.DefaultBetaHeader
return claude.GetOAuthBetaHeader(modelID)
}
func requestNeedsBetaFeatures(body []byte) bool {
@ -5790,10 +5793,7 @@ func requestNeedsBetaFeatures(body []byte) bool {
func defaultAPIKeyBetaHeader(body []byte) string {
modelID := gjson.GetBytes(body, "model").String()
if strings.Contains(strings.ToLower(modelID), "haiku") {
return claude.APIKeyHaikuBetaHeader
}
return claude.APIKeyBetaHeader
return claude.GetAPIKeyBetaHeader(modelID)
}
func applyClaudeOAuthHeaderDefaults(req *http.Request) {

View File

@ -26,7 +26,7 @@ var (
// 默认指纹值(当客户端未提供时使用)
var defaultFingerprint = Fingerprint{
UserAgent: "claude-cli/2.1.88 (external, cli)",
UserAgent: "claude-cli/2.1.89 (external, cli)",
StainlessLang: "js",
StainlessPackageVersion: "0.74.0",
StainlessOS: "MacOS",

View File

@ -1,225 +0,0 @@
package service
import (
"context"
"fmt"
"log/slog"
"strconv"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/lspool"
)
const (
defaultLSPoolBootstrapConcurrency = 4
)
type lsBootstrapAccountReader interface {
ListByPlatform(ctx context.Context, platform string) ([]Account, error)
}
// LSPoolBootstrapService pre-creates LS workers for eligible Antigravity accounts on startup.
type LSPoolBootstrapService struct {
accountReader lsBootstrapAccountReader
backend lspool.Backend
cfg *config.Config
logger *slog.Logger
ctx context.Context
cancel context.CancelFunc
once sync.Once
wg sync.WaitGroup
}
func NewLSPoolBootstrapService(accountReader lsBootstrapAccountReader, backend lspool.Backend, cfg *config.Config) *LSPoolBootstrapService {
ctx, cancel := context.WithCancel(context.Background())
return &LSPoolBootstrapService{
accountReader: accountReader,
backend: backend,
cfg: cfg,
logger: slog.Default().With("component", "service.lspool_bootstrap"),
ctx: ctx,
cancel: cancel,
}
}
// ProvideLSPoolBootstrapService creates and starts the LS pool bootstrap worker.
func ProvideLSPoolBootstrapService(accountRepo AccountRepository, cfg *config.Config) *LSPoolBootstrapService {
svc := NewLSPoolBootstrapService(accountRepo, lspool.GlobalPool(cfg), cfg)
svc.Start()
return svc
}
func (s *LSPoolBootstrapService) Start() {
if s == nil {
return
}
s.once.Do(func() {
if s.backend == nil {
if lspool.IsLSModeEnabled() {
s.logger.Warn("startup bootstrap skipped: ls backend unavailable")
}
return
}
s.wg.Add(1)
go func() {
defer s.wg.Done()
s.bootstrap(s.ctx)
}()
})
}
func (s *LSPoolBootstrapService) Stop() {
if s == nil {
return
}
s.cancel()
s.wg.Wait()
}
func (s *LSPoolBootstrapService) bootstrap(ctx context.Context) {
if s.backend == nil || s.accountReader == nil {
return
}
accounts, err := s.accountReader.ListByPlatform(ctx, PlatformAntigravity)
if err != nil {
s.logger.Warn("load antigravity accounts for ls bootstrap failed", "error", err)
return
}
now := time.Now()
candidates := make([]Account, 0, len(accounts))
for i := range accounts {
if shouldBootstrapLSPoolAccount(&accounts[i], now) {
candidates = append(candidates, accounts[i])
}
}
if len(candidates) == 0 {
s.logger.Info("startup bootstrap skipped: no eligible antigravity accounts")
return
}
s.logger.Info("starting ls worker bootstrap",
"accounts_total", len(accounts),
"accounts_eligible", len(candidates),
"concurrency", s.bootstrapConcurrency())
var (
mu sync.Mutex
started int
failed int
)
sem := make(chan struct{}, s.bootstrapConcurrency())
var wg sync.WaitGroup
loop:
for i := range candidates {
account := candidates[i]
select {
case <-ctx.Done():
break loop
case sem <- struct{}{}:
}
wg.Add(1)
go func(account Account) {
defer wg.Done()
defer func() { <-sem }()
if err := s.bootstrapAccount(&account); err != nil {
mu.Lock()
failed++
mu.Unlock()
s.logger.Warn("bootstrap ls worker failed", "account_id", account.ID, "error", err)
return
}
mu.Lock()
started++
mu.Unlock()
s.logger.Info("bootstrap ls worker ready", "account_id", account.ID)
}(account)
}
wg.Wait()
s.logger.Info("ls worker bootstrap completed",
"accounts_total", len(accounts),
"accounts_eligible", len(candidates),
"workers_ready", started,
"workers_failed", failed,
"canceled", ctx.Err() != nil)
}
func (s *LSPoolBootstrapService) bootstrapAccount(account *Account) error {
if s.backend == nil {
return fmt.Errorf("ls backend unavailable")
}
if account == nil {
return fmt.Errorf("account is nil")
}
accountKey := strconv.FormatInt(account.ID, 10)
accessToken := strings.TrimSpace(account.GetCredential("access_token"))
if accessToken == "" {
return fmt.Errorf("missing access token")
}
refreshToken := strings.TrimSpace(account.GetCredential("refresh_token"))
expiresAt := time.Time{}
if ts := account.GetCredentialAsTime("expires_at"); ts != nil {
expiresAt = ts.UTC()
}
s.backend.SetAccountToken(accountKey, accessToken, refreshToken, expiresAt)
availableCredits, minimumCreditAmount := resolveLSPoolModelCreditsState(account)
s.backend.SetAccountModelCredits(accountKey, account.IsOveragesEnabled(), availableCredits, minimumCreditAmount)
proxyURL := ""
if account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
if _, err := s.backend.GetOrCreate(accountKey, "", proxyURL); err != nil {
return fmt.Errorf("get or create ls worker: %w", err)
}
return nil
}
func (s *LSPoolBootstrapService) bootstrapConcurrency() int {
parallelism := defaultLSPoolBootstrapConcurrency
if s.cfg != nil && s.cfg.Gateway.AntigravityLSWorker.MaxActive > 0 && s.cfg.Gateway.AntigravityLSWorker.MaxActive < parallelism {
parallelism = s.cfg.Gateway.AntigravityLSWorker.MaxActive
}
if parallelism < 1 {
return 1
}
return parallelism
}
func shouldBootstrapLSPoolAccount(account *Account, now time.Time) bool {
if account == nil {
return false
}
if account.Platform != PlatformAntigravity {
return false
}
if account.Type != AccountTypeOAuth {
return false
}
if account.Status != StatusActive || !account.Schedulable {
return false
}
if account.AutoPauseOnExpired && account.ExpiresAt != nil && !now.Before(*account.ExpiresAt) {
return false
}
if strings.TrimSpace(account.GetCredential("access_token")) == "" {
return false
}
return strings.TrimSpace(account.GetCredential("project_id")) != ""
}

View File

@ -1,262 +0,0 @@
package service
import (
"context"
"errors"
"sync"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/lspool"
"github.com/stretchr/testify/require"
)
type fakeLSBootstrapAccountReader struct {
mu sync.Mutex
accounts []Account
err error
platforms []string
}
func (f *fakeLSBootstrapAccountReader) ListByPlatform(_ context.Context, platform string) ([]Account, error) {
f.mu.Lock()
f.platforms = append(f.platforms, platform)
accounts := append([]Account(nil), f.accounts...)
err := f.err
f.mu.Unlock()
return accounts, err
}
type fakeLSPoolBackend struct {
mu sync.Mutex
tokenCalls map[string]fakeLSPoolTokenCall
creditCalls map[string]fakeLSPoolCreditCall
getCalls []fakeLSPoolGetCall
getErrs map[string]error
}
type fakeLSPoolTokenCall struct {
AccessToken string
RefreshToken string
ExpiresAt time.Time
}
type fakeLSPoolCreditCall struct {
UseAICredits bool
AvailableCredits *int32
MinimumCreditAmount *int32
}
type fakeLSPoolGetCall struct {
AccountID string
RoutingKey string
ProxyURL string
}
func newFakeLSPoolBackend() *fakeLSPoolBackend {
return &fakeLSPoolBackend{
tokenCalls: make(map[string]fakeLSPoolTokenCall),
creditCalls: make(map[string]fakeLSPoolCreditCall),
getErrs: make(map[string]error),
}
}
func (f *fakeLSPoolBackend) GetOrCreate(accountID, routingKey string, proxyURL ...string) (*lspool.Instance, error) {
rawProxy := ""
if len(proxyURL) > 0 {
rawProxy = proxyURL[0]
}
f.mu.Lock()
defer f.mu.Unlock()
f.getCalls = append(f.getCalls, fakeLSPoolGetCall{
AccountID: accountID,
RoutingKey: routingKey,
ProxyURL: rawProxy,
})
if err := f.getErrs[accountID]; err != nil {
return nil, err
}
return &lspool.Instance{AccountID: accountID}, nil
}
func (f *fakeLSPoolBackend) SetAccountToken(accountID, accessToken, refreshToken string, expiresAt time.Time) {
f.mu.Lock()
defer f.mu.Unlock()
f.tokenCalls[accountID] = fakeLSPoolTokenCall{
AccessToken: accessToken,
RefreshToken: refreshToken,
ExpiresAt: expiresAt,
}
}
func (f *fakeLSPoolBackend) SetAccountModelCredits(accountID string, useAICredits bool, availableCredits, minimumCreditAmountForUsage *int32) {
f.mu.Lock()
defer f.mu.Unlock()
f.creditCalls[accountID] = fakeLSPoolCreditCall{
UseAICredits: useAICredits,
AvailableCredits: copyInt32Ptr(availableCredits),
MinimumCreditAmount: copyInt32Ptr(minimumCreditAmountForUsage),
}
}
func (f *fakeLSPoolBackend) Stats() map[string]any { return nil }
func (f *fakeLSPoolBackend) Close() {}
func copyInt32Ptr(v *int32) *int32 {
if v == nil {
return nil
}
cp := *v
return &cp
}
func TestLSPoolBootstrapServiceBootstrapEligibleAccounts(t *testing.T) {
expiresAt := time.Now().Add(2 * time.Hour).UTC().Truncate(time.Second)
expiredAt := time.Now().Add(-2 * time.Hour)
reader := &fakeLSBootstrapAccountReader{
accounts: []Account{
{
ID: 101,
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
Status: StatusActive,
Schedulable: true,
Credentials: map[string]any{
"access_token": "token-101",
"refresh_token": "refresh-101",
"expires_at": expiresAt.Format(time.RFC3339),
"project_id": "proj-101",
},
Extra: map[string]any{
"allow_overages": true,
"ai_credits": []any{
map[string]any{
"credit_type": "GOOGLE_ONE_AI",
"amount": 120,
"minimum_balance": 55,
},
},
},
Proxy: &Proxy{
Protocol: "socks5h",
Host: "127.0.0.1",
Port: 1080,
Username: "alice",
Password: "secret",
},
},
{
ID: 102,
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
Status: StatusActive,
Schedulable: false,
Credentials: map[string]any{"access_token": "token-102", "project_id": "proj-102"},
},
{
ID: 103,
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
Status: StatusActive,
Schedulable: true,
Credentials: map[string]any{"access_token": "token-103"},
},
{
ID: 104,
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
Status: StatusActive,
Schedulable: true,
AutoPauseOnExpired: true,
ExpiresAt: &expiredAt,
Credentials: map[string]any{"access_token": "token-104", "project_id": "proj-104"},
},
{
ID: 106,
Platform: PlatformAntigravity,
Type: AccountTypeUpstream,
Status: StatusActive,
Schedulable: true,
Credentials: map[string]any{"access_token": "token-106", "project_id": "proj-106"},
},
{
ID: 105,
Platform: PlatformOpenAI,
Status: StatusActive,
Schedulable: true,
Credentials: map[string]any{"access_token": "token-105"},
},
},
}
backend := newFakeLSPoolBackend()
svc := NewLSPoolBootstrapService(reader, backend, &config.Config{
Gateway: config.GatewayConfig{
AntigravityLSWorker: config.GatewayAntigravityLSWorkerConfig{MaxActive: 3},
},
})
svc.bootstrap(context.Background())
require.Equal(t, []string{PlatformAntigravity}, reader.platforms)
require.Len(t, backend.getCalls, 1)
require.Equal(t, fakeLSPoolGetCall{
AccountID: "101",
RoutingKey: "",
ProxyURL: "socks5h://alice:secret@127.0.0.1:1080",
}, backend.getCalls[0])
tokenCall, ok := backend.tokenCalls["101"]
require.True(t, ok)
require.Equal(t, "token-101", tokenCall.AccessToken)
require.Equal(t, "refresh-101", tokenCall.RefreshToken)
require.Equal(t, expiresAt, tokenCall.ExpiresAt)
creditCall, ok := backend.creditCalls["101"]
require.True(t, ok)
require.True(t, creditCall.UseAICredits)
require.NotNil(t, creditCall.AvailableCredits)
require.Equal(t, int32(120), *creditCall.AvailableCredits)
require.NotNil(t, creditCall.MinimumCreditAmount)
require.Equal(t, int32(55), *creditCall.MinimumCreditAmount)
require.NotContains(t, backend.tokenCalls, "102")
require.NotContains(t, backend.tokenCalls, "103")
require.NotContains(t, backend.tokenCalls, "104")
require.NotContains(t, backend.tokenCalls, "106")
}
func TestLSPoolBootstrapServiceBootstrapContinuesOnWorkerFailure(t *testing.T) {
reader := &fakeLSBootstrapAccountReader{
accounts: []Account{
{
ID: 201,
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
Status: StatusActive,
Schedulable: true,
Credentials: map[string]any{"access_token": "token-201", "project_id": "proj-201"},
},
{
ID: 202,
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
Status: StatusActive,
Schedulable: true,
Credentials: map[string]any{"access_token": "token-202", "project_id": "proj-202"},
},
},
}
backend := newFakeLSPoolBackend()
backend.getErrs["201"] = errors.New("create failed")
svc := NewLSPoolBootstrapService(reader, backend, &config.Config{})
svc.bootstrap(context.Background())
require.Len(t, backend.getCalls, 2)
require.Contains(t, backend.tokenCalls, "201")
require.Contains(t, backend.tokenCalls, "202")
}

View File

@ -1,21 +0,0 @@
-----BEGIN CERTIFICATE-----
MIIDXTCCAkWgAwIBAgIUVoRddTlTFh3+shRe6g4kSLo2n0MwDQYJKoZIhvcNAQEL
BQAwSTESMBAGA1UEAwwJbG9jYWxob3N0MRYwFAYDVQQKDA1FTkFCTEVTIEhUVFAy
MRswGQYDVQQLDBJidW5kbGVkIG9uIHB1cnBvc2UwHhcNMjUwOTA0MjA1NTA0WhcN
MjYwOTA0MjA1NTA0WjBJMRIwEAYDVQQDDAlsb2NhbGhvc3QxFjAUBgNVBAoMDUVO
QUJMRVMgSFRUUDIxGzAZBgNVBAsMEmJ1bmRsZWQgb24gcHVycG9zZTCCASIwDQYJ
KoZIhvcNAQEBBQADggEPADCCAQoCggEBAJVpU6IyIMgwB6CJHkOeEAgYtzvyH6fM
lkZSbemTrD9RCWZ4Fati1/6vbbMyWsM2XNJQMhJo0JTEoLDddN1iV/xGJCO/3dgw
4+wLqqEeck4R1pHygCkb40TycmyygSWsidkEUH0xp51nCapIdPr/WL6O+Gbpl6DA
onerUmWIO39VG2SpV7x3iXZOSbIGMsOiNZBmGwBZcL8ZejBIDjwvNjnX/d2tejH5
/Mo4KVEXl5jsqaNbDIkhSs5BXtCMhoi1dqt75M8FyuNZd50AGFSa9Lj6pHTpwepD
k2x4h+czPcvscF7TQG31TK1VYFPUThDim+by0+LQKkpy/UGVWnbC4dsCAwEAAaM9
MDswGgYDVR0RBBMwEYIJbG9jYWxob3N0hwR/AAABMB0GA1UdDgQWBBSonSKmHCVt
yBoVH1xEb3vtCng80DANBgkqhkiG9w0BAQsFAAOCAQEAinBO/uYe8ExHeiskt2P/
Oxkd5sHSY9deLVuyX/TFnUEfktMfYKM2Juy+MfH4vfrcEhYkYJJcm25UGrtiT0Jh
bUooDkR53549Xzg/70HU/ls1eNIe0zYqmS12H5W4Q1LAWTVpePscB4dgOrps6xIk
Q4nlF7dst93E3swAe81rgCEd7VZEZy5VQcE9K+CIZXaAUJwUAsAtJbrP+5JMe9pt
q52Zq5ZVkBS+4xeaMrasN0iTgsS4Lxo2a0GFDIJ84V66oeX7a5SXfSNn7rMVIDai
KNZ2Cf2xNXUwq25Z6tjpQCqwYn3SE8b/Yi6fFZmy5D8kmY7dMh8ghVOc7rD+Vsk6
/Q==
-----END CERTIFICATE-----

View File

@ -1,70 +0,0 @@
#!/bin/sh
set -eu
PROXY_HOST="${LSWORKER_PROXY_HOST:-}"
PROXY_PORT="${LSWORKER_PROXY_PORT:-1080}"
PROXY_USER="${LSWORKER_PROXY_USER:-}"
PROXY_PASS="${LSWORKER_PROXY_PASS:-}"
CONTROL_PORT="${LSWORKER_CONTROL_PORT:-18081}"
REDSOCKS_PORT="${LSWORKER_REDSOCKS_PORT:-12345}"
NETWORK_READY_FILE="${LSWORKER_NETWORK_READY_FILE:-/run/lsworker/network-ready}"
mkdir -p "$(dirname "${NETWORK_READY_FILE}")"
if [ -z "${PROXY_HOST}" ]; then
echo "LSWORKER_PROXY_HOST is required" >&2
exit 1
fi
PROXY_IP="$(getent ahostsv4 "${PROXY_HOST}" | awk 'NR==1 {print $1}')"
if [ -z "${PROXY_IP}" ]; then
echo "failed to resolve proxy host: ${PROXY_HOST}" >&2
exit 1
fi
cat >/tmp/redsocks.conf <<EOF
base {
log_debug = off;
log_info = on;
daemon = off;
redirector = iptables;
}
redsocks {
local_ip = 0.0.0.0;
local_port = ${REDSOCKS_PORT};
ip = ${PROXY_IP};
port = ${PROXY_PORT};
type = socks5;
EOF
if [ -n "${PROXY_USER}" ]; then
printf ' login = "%s";\n' "${PROXY_USER}" >>/tmp/redsocks.conf
fi
if [ -n "${PROXY_PASS}" ]; then
printf ' password = "%s";\n' "${PROXY_PASS}" >>/tmp/redsocks.conf
fi
cat >>/tmp/redsocks.conf <<EOF
}
EOF
redsocks -c /tmp/redsocks.conf >/tmp/redsocks.log 2>&1 &
REDSOCKS_PID="$!"
trap 'kill "${REDSOCKS_PID}" >/dev/null 2>&1 || true' EXIT
sleep 1
iptables -t nat -N REDSOCKS 2>/dev/null || true
iptables -t nat -F REDSOCKS
iptables -t nat -A REDSOCKS -d 127.0.0.0/8 -j RETURN
iptables -t nat -A REDSOCKS -d 127.0.0.11/32 -j RETURN
iptables -t nat -A REDSOCKS -d "${PROXY_IP}/32" -j RETURN
iptables -t nat -A REDSOCKS -p tcp --dport "${CONTROL_PORT}" -j RETURN
iptables -t nat -A REDSOCKS -p tcp -j REDIRECT --to-ports "${REDSOCKS_PORT}"
iptables -t nat -D OUTPUT -p tcp -j REDSOCKS 2>/dev/null || true
iptables -t nat -A OUTPUT -p tcp -j REDSOCKS
touch "${NETWORK_READY_FILE}"
exec gosu sub2api /app/lsworker

View File

@ -1,52 +0,0 @@
ARG GOLANG_IMAGE=golang:1.26.1-alpine
ARG DEBIAN_IMAGE=debian:bookworm-slim
FROM ${GOLANG_IMAGE} AS builder
WORKDIR /app/backend
RUN apk add --no-cache git ca-certificates tzdata
COPY backend/go.mod backend/go.sum ./
RUN go mod download
COPY backend/ ./
RUN CGO_ENABLED=0 GOOS=linux go build -trimpath -ldflags="-s -w" -o /app/lsworker ./cmd/lsworker
FROM ${DEBIAN_IMAGE}
RUN apt-get update && apt-get install -y --no-install-recommends \
ca-certificates \
curl \
gosu \
iproute2 \
iptables \
redsocks \
tzdata \
&& rm -rf /var/lib/apt/lists/*
RUN groupadd -g 1000 sub2api && \
useradd -u 1000 -g sub2api -m -s /bin/sh sub2api
WORKDIR /app
COPY --from=builder /app/lsworker /app/lsworker
COPY deploy/ls-bin/language_server_linux_* /tmp/ls-bin/
COPY deploy/ls-bin/cert.pem /app/ls/extensions/antigravity/dist/languageServer/
ARG TARGETARCH
RUN mkdir -p /app/ls/extensions/antigravity/bin /run/lsworker && \
if [ "${TARGETARCH:-amd64}" = "arm64" ]; then \
cp /tmp/ls-bin/language_server_linux_arm /app/ls/extensions/antigravity/bin/language_server_linux_arm; \
else \
cp /tmp/ls-bin/language_server_linux_x64 /app/ls/extensions/antigravity/bin/language_server_linux_x64; \
fi && \
chmod +x /app/lsworker /app/ls/extensions/antigravity/bin/language_server_linux_* && \
chown -R sub2api:sub2api /app /run/lsworker && \
rm -rf /tmp/ls-bin
COPY deploy/lsworker-entrypoint.sh /app/lsworker-entrypoint.sh
RUN chmod +x /app/lsworker-entrypoint.sh
EXPOSE 18081
ENTRYPOINT ["/app/lsworker-entrypoint.sh"]