Merge branch 'codex/internal-sync-20260330'
This commit is contained in:
commit
860fc736bf
46
Dockerfile
46
Dockerfile
@ -9,6 +9,7 @@
|
||||
ARG NODE_IMAGE=node:24-alpine
|
||||
ARG GOLANG_IMAGE=golang:1.26.1-alpine
|
||||
ARG ALPINE_IMAGE=alpine:3.21
|
||||
ARG DEBIAN_IMAGE=debian:bookworm-slim
|
||||
ARG POSTGRES_IMAGE=postgres:18-alpine
|
||||
ARG GOPROXY=https://goproxy.cn,direct
|
||||
ARG GOSUMDB=sum.golang.google.cn
|
||||
@ -63,10 +64,12 @@ COPY --from=frontend-builder /app/backend/internal/web/dist ./internal/web/dist
|
||||
|
||||
# Build the binary (BuildType=release for CI builds, embed frontend)
|
||||
# Version precedence: build arg VERSION > cmd/server/VERSION
|
||||
ARG TARGETARCH
|
||||
ARG TARGETOS=linux
|
||||
RUN VERSION_VALUE="${VERSION}" && \
|
||||
if [ -z "${VERSION_VALUE}" ]; then VERSION_VALUE="$(tr -d '\r\n' < ./cmd/server/VERSION)"; fi && \
|
||||
DATE_VALUE="${DATE:-$(date -u +%Y-%m-%dT%H:%M:%SZ)}" && \
|
||||
CGO_ENABLED=0 GOOS=linux go build \
|
||||
CGO_ENABLED=0 GOOS=${TARGETOS} GOARCH=${TARGETARCH} go build \
|
||||
-tags embed \
|
||||
-ldflags="-s -w -X main.Version=${VERSION_VALUE} -X main.Commit=${COMMIT} -X main.Date=${DATE_VALUE} -X main.BuildType=release" \
|
||||
-trimpath \
|
||||
@ -79,9 +82,9 @@ RUN VERSION_VALUE="${VERSION}" && \
|
||||
FROM ${POSTGRES_IMAGE} AS pg-client
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Stage 4: Final Runtime Image
|
||||
# Stage 4: Final Runtime Image (Debian for glibc — LS binary requires it)
|
||||
# -----------------------------------------------------------------------------
|
||||
FROM ${ALPINE_IMAGE}
|
||||
FROM ${DEBIAN_IMAGE}
|
||||
|
||||
# Labels
|
||||
LABEL maintainer="Wei-Shaw <github.com/Wei-Shaw>"
|
||||
@ -89,27 +92,25 @@ LABEL description="Sub2API - AI API Gateway Platform"
|
||||
LABEL org.opencontainers.image.source="https://github.com/Wei-Shaw/sub2api"
|
||||
|
||||
# Install runtime dependencies
|
||||
RUN apk add --no-cache \
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
ca-certificates \
|
||||
curl \
|
||||
wget \
|
||||
gosu \
|
||||
proxychains4 \
|
||||
tzdata \
|
||||
su-exec \
|
||||
libpq \
|
||||
zstd-libs \
|
||||
lz4-libs \
|
||||
krb5-libs \
|
||||
libldap \
|
||||
libedit \
|
||||
&& rm -rf /var/cache/apk/*
|
||||
libpq5 \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Copy pg_dump and psql from the same postgres image used in docker-compose
|
||||
# This ensures version consistency between backup tools and the database server
|
||||
COPY --from=pg-client /usr/local/bin/pg_dump /usr/local/bin/pg_dump
|
||||
COPY --from=pg-client /usr/local/bin/psql /usr/local/bin/psql
|
||||
COPY --from=pg-client /usr/local/lib/libpq.so.5* /usr/local/lib/
|
||||
RUN ldconfig
|
||||
|
||||
# Create non-root user
|
||||
RUN addgroup -g 1000 sub2api && \
|
||||
adduser -u 1000 -G sub2api -s /bin/sh -D sub2api
|
||||
RUN groupadd -g 1000 sub2api && \
|
||||
useradd -u 1000 -g sub2api -m -s /bin/sh sub2api
|
||||
|
||||
# Set working directory
|
||||
WORKDIR /app
|
||||
@ -118,6 +119,21 @@ WORKDIR /app
|
||||
COPY --from=backend-builder --chown=sub2api:sub2api /app/sub2api /app/sub2api
|
||||
COPY --from=backend-builder --chown=sub2api:sub2api /app/backend/resources /app/resources
|
||||
|
||||
# Copy Language Server binary and cert (for LS pool mode)
|
||||
# Enable with: ANTIGRAVITY_LS_MODE=true ANTIGRAVITY_APP_ROOT=/app/ls
|
||||
# TARGETARCH is set automatically by buildx (amd64 or arm64)
|
||||
ARG TARGETARCH
|
||||
COPY --chown=sub2api:sub2api deploy/ls-bin/language_server_linux_* /tmp/ls-bin/
|
||||
COPY --chown=sub2api:sub2api deploy/ls-bin/cert.pem /app/ls/extensions/antigravity/dist/languageServer/
|
||||
RUN mkdir -p /app/ls/extensions/antigravity/bin && \
|
||||
if [ "$TARGETARCH" = "arm64" ]; then \
|
||||
cp /tmp/ls-bin/language_server_linux_arm /app/ls/extensions/antigravity/bin/language_server_linux_arm; \
|
||||
else \
|
||||
cp /tmp/ls-bin/language_server_linux_x64 /app/ls/extensions/antigravity/bin/language_server_linux_x64; \
|
||||
fi && \
|
||||
chmod +x /app/ls/extensions/antigravity/bin/language_server_linux_* && \
|
||||
rm -rf /tmp/ls-bin
|
||||
|
||||
# Create data directory
|
||||
RUN mkdir -p /app/data && chown sub2api:sub2api /app/data
|
||||
|
||||
|
||||
49
backend/cmd/lsworker/main.go
Normal file
49
backend/cmd/lsworker/main.go
Normal file
@ -0,0 +1,49 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/lspool"
|
||||
)
|
||||
|
||||
func main() {
|
||||
server, err := lspool.NewWorkerServerFromEnv()
|
||||
if err != nil {
|
||||
slog.Error("failed to initialize lsworker", "err", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer server.Close()
|
||||
|
||||
httpServer := &http.Server{
|
||||
Addr: envOrDefault("LSWORKER_LISTEN_ADDR", "0.0.0.0:18081"),
|
||||
Handler: server.Handler(),
|
||||
ReadHeaderTimeout: 10 * 1e9,
|
||||
}
|
||||
|
||||
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
|
||||
defer stop()
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
_ = httpServer.Shutdown(context.Background())
|
||||
}()
|
||||
|
||||
slog.Info("lsworker listening", "addr", httpServer.Addr)
|
||||
if err := httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
slog.Error("lsworker exited with error", "err", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func envOrDefault(key, fallback string) string {
|
||||
if value := os.Getenv(key); value != "" {
|
||||
return value
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
@ -379,6 +379,8 @@ type GatewayConfig struct {
|
||||
OpenAIPassthroughAllowTimeoutHeaders bool `mapstructure:"openai_passthrough_allow_timeout_headers"`
|
||||
// OpenAIWS: OpenAI Responses WebSocket 配置(默认开启,可按需回滚到 HTTP)
|
||||
OpenAIWS GatewayOpenAIWSConfig `mapstructure:"openai_ws"`
|
||||
// AntigravityLSWorker: LS worker 容器控制平面配置
|
||||
AntigravityLSWorker GatewayAntigravityLSWorkerConfig `mapstructure:"antigravity_ls_worker"`
|
||||
|
||||
// HTTP 上游连接池配置(性能优化:支持高并发场景调优)
|
||||
// MaxIdleConns: 所有主机的最大空闲连接总数
|
||||
@ -481,6 +483,16 @@ type GatewayConfig struct {
|
||||
UserMessageQueue UserMessageQueueConfig `mapstructure:"user_message_queue"`
|
||||
}
|
||||
|
||||
type GatewayAntigravityLSWorkerConfig struct {
|
||||
Image string `mapstructure:"image"`
|
||||
Network string `mapstructure:"network"`
|
||||
DockerSocket string `mapstructure:"docker_socket"`
|
||||
IdleTTL time.Duration `mapstructure:"idle_ttl"`
|
||||
MaxActive int `mapstructure:"max_active"`
|
||||
StartupTimeout time.Duration `mapstructure:"startup_timeout"`
|
||||
RequestTimeout time.Duration `mapstructure:"request_timeout"`
|
||||
}
|
||||
|
||||
// UserMessageQueueConfig 用户消息串行队列配置
|
||||
// 用于 Anthropic OAuth/SetupToken 账号的用户消息串行化发送
|
||||
type UserMessageQueueConfig struct {
|
||||
@ -1307,6 +1319,15 @@ func setDefaults() {
|
||||
|
||||
// RateLimit
|
||||
viper.SetDefault("rate_limit.overload_cooldown_minutes", 10)
|
||||
|
||||
// Gateway LS worker
|
||||
viper.SetDefault("gateway.antigravity_ls_worker.image", "weishaw/sub2api-lsworker:latest")
|
||||
viper.SetDefault("gateway.antigravity_ls_worker.network", "sub2api-network")
|
||||
viper.SetDefault("gateway.antigravity_ls_worker.docker_socket", "unix:///var/run/docker.sock")
|
||||
viper.SetDefault("gateway.antigravity_ls_worker.idle_ttl", 15*time.Minute)
|
||||
viper.SetDefault("gateway.antigravity_ls_worker.max_active", 50)
|
||||
viper.SetDefault("gateway.antigravity_ls_worker.startup_timeout", 45*time.Second)
|
||||
viper.SetDefault("gateway.antigravity_ls_worker.request_timeout", 60*time.Second)
|
||||
viper.SetDefault("rate_limit.oauth_401_cooldown_minutes", 10)
|
||||
|
||||
// Pricing - 从 model-price-repo 同步模型定价和上下文窗口数据(固定到 commit,避免分支漂移)
|
||||
|
||||
@ -440,7 +440,7 @@ func (c *Client) GetUserInfo(ctx context.Context, accessToken string) (*UserInfo
|
||||
func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadCodeAssistResponse, map[string]any, error) {
|
||||
reqBody := LoadCodeAssistRequest{}
|
||||
reqBody.Metadata.IDEType = "ANTIGRAVITY"
|
||||
reqBody.Metadata.IDEVersion = "1.20.6"
|
||||
reqBody.Metadata.IDEVersion = "1.107.0"
|
||||
reqBody.Metadata.IDEName = "antigravity"
|
||||
|
||||
bodyBytes, err := json.Marshal(reqBody)
|
||||
|
||||
@ -9,6 +9,7 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@ -49,8 +50,8 @@ const (
|
||||
antigravityDailyBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com"
|
||||
)
|
||||
|
||||
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.20.5
|
||||
var defaultUserAgentVersion = "1.20.5"
|
||||
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.107.0
|
||||
var defaultUserAgentVersion = "1.107.0"
|
||||
|
||||
// defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置
|
||||
var defaultClientSecret = "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"
|
||||
@ -66,9 +67,9 @@ func init() {
|
||||
}
|
||||
}
|
||||
|
||||
// GetUserAgent 返回当前配置的 User-Agent
|
||||
// GetUserAgent 返回当前配置的 User-Agent(自动检测平台,匹配真实 IDE 行为)
|
||||
func GetUserAgent() string {
|
||||
return fmt.Sprintf("antigravity/%s windows/amd64", defaultUserAgentVersion)
|
||||
return fmt.Sprintf("antigravity/%s %s/%s", defaultUserAgentVersion, runtime.GOOS, runtime.GOARCH)
|
||||
}
|
||||
|
||||
func getClientSecret() (string, error) {
|
||||
|
||||
@ -39,6 +39,34 @@ func generateStableSessionID(contents []GeminiContent) string {
|
||||
return "-" + strconv.FormatInt(n, 10)
|
||||
}
|
||||
|
||||
// EnsureGeminiRequestSessionID fills request.sessionId when the caller omitted it.
|
||||
// preferredSessionID wins; otherwise we derive a stable value from the first user turn.
|
||||
func EnsureGeminiRequestSessionID(body []byte, preferredSessionID string) ([]byte, error) {
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if raw, ok := payload["sessionId"].(string); ok && strings.TrimSpace(raw) != "" {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
sessionID := strings.TrimSpace(preferredSessionID)
|
||||
if sessionID == "" {
|
||||
var req GeminiRequest
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sessionID = generateStableSessionID(req.Contents)
|
||||
}
|
||||
if sessionID == "" {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
payload["sessionId"] = sessionID
|
||||
return json.Marshal(payload)
|
||||
}
|
||||
|
||||
type TransformOptions struct {
|
||||
EnableIdentityPatch bool
|
||||
// IdentityPatch 可选:自定义注入到 systemInstruction 开头的身份防护提示词;
|
||||
|
||||
@ -8,6 +8,43 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestEnsureGeminiRequestSessionID(t *testing.T) {
|
||||
t.Run("prefers provided session id", func(t *testing.T) {
|
||||
body := []byte(`{"contents":[{"role":"user","parts":[{"text":"hello"}]}]}`)
|
||||
updated, err := EnsureGeminiRequestSessionID(body, "session-from-header")
|
||||
require.NoError(t, err)
|
||||
|
||||
var payload map[string]any
|
||||
require.NoError(t, json.Unmarshal(updated, &payload))
|
||||
require.Equal(t, "session-from-header", payload["sessionId"])
|
||||
})
|
||||
|
||||
t.Run("keeps existing session id", func(t *testing.T) {
|
||||
body := []byte(`{"sessionId":"session-in-body","contents":[{"role":"user","parts":[{"text":"hello"}]}]}`)
|
||||
updated, err := EnsureGeminiRequestSessionID(body, "session-from-header")
|
||||
require.NoError(t, err)
|
||||
|
||||
var payload map[string]any
|
||||
require.NoError(t, json.Unmarshal(updated, &payload))
|
||||
require.Equal(t, "session-in-body", payload["sessionId"])
|
||||
})
|
||||
|
||||
t.Run("derives stable fallback from contents", func(t *testing.T) {
|
||||
body := []byte(`{"contents":[{"role":"user","parts":[{"text":"hello"}]}]}`)
|
||||
first, err := EnsureGeminiRequestSessionID(body, "")
|
||||
require.NoError(t, err)
|
||||
second, err := EnsureGeminiRequestSessionID(body, "")
|
||||
require.NoError(t, err)
|
||||
|
||||
var firstPayload map[string]any
|
||||
var secondPayload map[string]any
|
||||
require.NoError(t, json.Unmarshal(first, &firstPayload))
|
||||
require.NoError(t, json.Unmarshal(second, &secondPayload))
|
||||
require.NotEmpty(t, firstPayload["sessionId"])
|
||||
require.Equal(t, firstPayload["sessionId"], secondPayload["sessionId"])
|
||||
})
|
||||
}
|
||||
|
||||
// TestBuildParts_ThinkingBlockWithoutSignature 测试thinking block无signature时的处理
|
||||
func TestBuildParts_ThinkingBlockWithoutSignature(t *testing.T) {
|
||||
tests := []struct {
|
||||
|
||||
13
backend/internal/pkg/lspool/backend.go
Normal file
13
backend/internal/pkg/lspool/backend.go
Normal file
@ -0,0 +1,13 @@
|
||||
package lspool
|
||||
|
||||
import "time"
|
||||
|
||||
// Backend is the control-plane abstraction used by the HTTP upstream wrapper.
|
||||
// It may be backed by a local in-process Pool or by remote LS workers.
|
||||
type Backend interface {
|
||||
GetOrCreate(accountID, routingKey string, proxyURL ...string) (*Instance, error)
|
||||
SetAccountToken(accountID, accessToken, refreshToken string, expiresAt time.Time)
|
||||
SetAccountModelCredits(accountID string, useAICredits bool, availableCredits, minimumCreditAmountForUsage *int32)
|
||||
Stats() map[string]any
|
||||
Close()
|
||||
}
|
||||
94
backend/internal/pkg/lspool/global.go
Normal file
94
backend/internal/pkg/lspool/global.go
Normal file
@ -0,0 +1,94 @@
|
||||
// Package lspool provides LS-mode integration for the antigravity gateway.
|
||||
//
|
||||
// When LS mode is enabled (via ANTIGRAVITY_LS_MODE=true), requests to
|
||||
// streamGenerateContent are routed through a real Language Server instance
|
||||
// instead of directly to cloudcode-pa. This provides:
|
||||
//
|
||||
// - Authentic TLS fingerprint (Google's own Go binary)
|
||||
// - Real session management and Heartbeat
|
||||
// - Indistinguishable from a real IDE instance
|
||||
//
|
||||
// To enable: set environment variable ANTIGRAVITY_LS_MODE=true
|
||||
// To configure: set ANTIGRAVITY_APP_ROOT to the AntiGravity.app path
|
||||
package lspool
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
)
|
||||
|
||||
var (
|
||||
globalBackend Backend
|
||||
globalPoolOnce sync.Once
|
||||
lsModeEnabled bool
|
||||
)
|
||||
|
||||
func init() {
|
||||
lsModeEnabled = os.Getenv("ANTIGRAVITY_LS_MODE") == "true"
|
||||
}
|
||||
|
||||
// IsLSModeEnabled returns whether LS mode is active
|
||||
func IsLSModeEnabled() bool {
|
||||
return lsModeEnabled
|
||||
}
|
||||
|
||||
const (
|
||||
LSStrategyDirect = "direct"
|
||||
LSStrategyJSParity = "js-parity"
|
||||
)
|
||||
|
||||
// CurrentLSStrategy returns the active LS routing strategy.
|
||||
// Unknown values are treated as "direct" for safety.
|
||||
func CurrentLSStrategy() string {
|
||||
switch strings.ToLower(strings.TrimSpace(os.Getenv("ANTIGRAVITY_LS_STRATEGY"))) {
|
||||
case "", LSStrategyDirect:
|
||||
return LSStrategyDirect
|
||||
case LSStrategyJSParity:
|
||||
return LSStrategyJSParity
|
||||
default:
|
||||
return LSStrategyDirect
|
||||
}
|
||||
}
|
||||
|
||||
// GlobalPool returns the singleton LS pool instance
|
||||
// Creates it on first call if LS mode is enabled
|
||||
func GlobalPool(cfg *config.Config) Backend {
|
||||
if !lsModeEnabled {
|
||||
return nil
|
||||
}
|
||||
globalPoolOnce.Do(func() {
|
||||
manager, err := NewWorkerManagerFromConfig(cfg)
|
||||
if err != nil {
|
||||
slog.Default().Error("failed to initialize LS worker manager", "err", err)
|
||||
return
|
||||
}
|
||||
globalBackend = manager
|
||||
})
|
||||
return globalBackend
|
||||
}
|
||||
|
||||
// Shutdown closes the global pool
|
||||
func Shutdown() {
|
||||
if globalBackend != nil {
|
||||
globalBackend.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// StatusInfo returns the current LS pool status for diagnostics
|
||||
func StatusInfo() map[string]any {
|
||||
info := map[string]any{
|
||||
"ls_mode_enabled": lsModeEnabled,
|
||||
"build": "enhanced",
|
||||
"user_agent": "antigravity/1.107.0",
|
||||
}
|
||||
if lsModeEnabled && globalBackend != nil {
|
||||
stats := globalBackend.Stats()
|
||||
info["pool_total"] = stats["total"]
|
||||
info["pool_active"] = stats["active"]
|
||||
}
|
||||
return info
|
||||
}
|
||||
794
backend/internal/pkg/lspool/integration_test.go
Normal file
794
backend/internal/pkg/lspool/integration_test.go
Normal file
@ -0,0 +1,794 @@
|
||||
package lspool
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func readConnectFrame(r io.Reader) ([]byte, error) {
|
||||
header := make([]byte, 5)
|
||||
if _, err := io.ReadFull(r, header); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payloadLen := binary.BigEndian.Uint32(header[1:5])
|
||||
payload := make([]byte, payloadLen)
|
||||
if _, err := io.ReadFull(r, payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
func decodeProtoBytesField(data []byte, targetField int) []byte {
|
||||
i := 0
|
||||
for i < len(data) {
|
||||
tag, n := binary.Uvarint(data[i:])
|
||||
if n <= 0 {
|
||||
return nil
|
||||
}
|
||||
i += n
|
||||
fieldNum := int(tag >> 3)
|
||||
wireType := tag & 0x7
|
||||
switch wireType {
|
||||
case 0:
|
||||
_, n = binary.Uvarint(data[i:])
|
||||
if n <= 0 {
|
||||
return nil
|
||||
}
|
||||
i += n
|
||||
case 2:
|
||||
length, n := binary.Uvarint(data[i:])
|
||||
if n <= 0 {
|
||||
return nil
|
||||
}
|
||||
i += n
|
||||
if i+int(length) > len(data) {
|
||||
return nil
|
||||
}
|
||||
if fieldNum == targetField {
|
||||
return data[i : i+int(length)]
|
||||
}
|
||||
i += int(length)
|
||||
case 1:
|
||||
i += 8
|
||||
case 5:
|
||||
i += 4
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestMockExtensionServerTokenInjection verifies the token injection flow:
|
||||
// Extension → MockExtensionServer → LS subscribes uss-oauth → gets OAuthTokenInfo
|
||||
func TestMockExtensionServerTokenInjection(t *testing.T) {
|
||||
csrf := "test-csrf-token"
|
||||
srv, err := NewMockExtensionServer(csrf)
|
||||
require.NoError(t, err)
|
||||
defer srv.Close()
|
||||
|
||||
// 1. Set token for an account
|
||||
srv.SetToken("account-1", &TokenInfo{
|
||||
AccessToken: "ya29.test-access-token",
|
||||
RefreshToken: "1//test-refresh-token",
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
})
|
||||
|
||||
// 2. Verify token is stored
|
||||
srv.mu.RLock()
|
||||
info, ok := srv.tokens["account-1"]
|
||||
srv.mu.RUnlock()
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "ya29.test-access-token", info.AccessToken)
|
||||
require.Equal(t, "1//test-refresh-token", info.RefreshToken)
|
||||
require.False(t, info.ExpiresAt.IsZero())
|
||||
|
||||
// 3. Simulate LS subscribing to uss-oauth (HTTP request to mock server)
|
||||
req, _ := http.NewRequest("POST",
|
||||
fmt.Sprintf("http://127.0.0.1:%d/exa.extension_server_pb.ExtensionServerService/SubscribeToUnifiedStateSyncTopic", srv.Port()),
|
||||
bytes.NewReader(frameConnectMessage(encodeProtoString(1, "uss-oauth"))))
|
||||
req.Header.Set("x-codeium-csrf-token", csrf)
|
||||
req.Header.Set("Content-Type", "application/connect+proto")
|
||||
|
||||
// The stream handler will block, so run in background and cancel after we confirm connection
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err == nil {
|
||||
defer resp.Body.Close()
|
||||
require.Equal(t, 200, resp.StatusCode)
|
||||
require.Equal(t, "application/connect+proto", resp.Header.Get("Content-Type"))
|
||||
|
||||
// Read the first envelope frame (initial state)
|
||||
header := make([]byte, 5)
|
||||
n, readErr := resp.Body.Read(header)
|
||||
if readErr == nil && n == 5 {
|
||||
require.Equal(t, byte(0x00), header[0], "first byte should be 0x00 (data frame)")
|
||||
t.Logf("Received initial state frame: flags=%d, payload_len=%d", header[0], header[1:5])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestMockExtensionServerCSRF verifies CSRF token validation
|
||||
func TestMockExtensionServerCSRF(t *testing.T) {
|
||||
csrf := "correct-csrf"
|
||||
srv, err := NewMockExtensionServer(csrf)
|
||||
require.NoError(t, err)
|
||||
defer srv.Close()
|
||||
|
||||
base := fmt.Sprintf("http://127.0.0.1:%d/exa.extension_server_pb.ExtensionServerService/Heartbeat", srv.Port())
|
||||
|
||||
// Wrong CSRF → 403
|
||||
req, _ := http.NewRequest("POST", base, nil)
|
||||
req.Header.Set("x-codeium-csrf-token", "wrong-csrf")
|
||||
req.Header.Set("Content-Type", "application/proto")
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
require.Equal(t, 403, resp.StatusCode)
|
||||
|
||||
// Correct CSRF → 200
|
||||
req2, _ := http.NewRequest("POST", base, nil)
|
||||
req2.Header.Set("x-codeium-csrf-token", csrf)
|
||||
req2.Header.Set("Content-Type", "application/proto")
|
||||
resp2, err := http.DefaultClient.Do(req2)
|
||||
require.NoError(t, err)
|
||||
defer resp2.Body.Close()
|
||||
require.Equal(t, 200, resp2.StatusCode)
|
||||
}
|
||||
|
||||
// TestMockExtensionServerGetSecretValue verifies the fallback token path
|
||||
func TestMockExtensionServerGetSecretValue(t *testing.T) {
|
||||
csrf := "test-csrf"
|
||||
srv, err := NewMockExtensionServer(csrf)
|
||||
require.NoError(t, err)
|
||||
defer srv.Close()
|
||||
|
||||
srv.SetToken("acc", &TokenInfo{AccessToken: "ya29.secret-token"})
|
||||
|
||||
// GetSecretValue should return the token
|
||||
req, _ := http.NewRequest("POST",
|
||||
fmt.Sprintf("http://127.0.0.1:%d/exa.extension_server_pb.ExtensionServerService/GetSecretValue", srv.Port()),
|
||||
nil)
|
||||
req.Header.Set("x-codeium-csrf-token", csrf)
|
||||
req.Header.Set("Content-Type", "application/proto")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
require.Equal(t, 200, resp.StatusCode)
|
||||
}
|
||||
|
||||
// TestOAuthTokenInfoProto verifies the proto encoding matches real IDE format
|
||||
func TestOAuthTokenInfoProto(t *testing.T) {
|
||||
expiry := time.Date(2026, 3, 29, 19, 0, 0, 0, time.UTC)
|
||||
bin := buildOAuthTokenInfoBinary("ya29.test", "1//refresh", expiry)
|
||||
|
||||
// Verify fields are present by checking proto wire format
|
||||
require.True(t, len(bin) > 0, "proto should not be empty")
|
||||
|
||||
// Field 1 (access_token): tag=0x0a, value="ya29.test"
|
||||
require.Contains(t, string(bin), "ya29.test")
|
||||
// Field 2 (token_type): tag=0x12, value="Bearer"
|
||||
require.Contains(t, string(bin), "Bearer")
|
||||
// Field 3 (refresh_token): tag=0x1a, value="1//refresh"
|
||||
require.Contains(t, string(bin), "1//refresh")
|
||||
|
||||
// Without refresh_token
|
||||
binNoRefresh := buildOAuthTokenInfoBinary("ya29.test", "", expiry)
|
||||
require.NotContains(t, string(binNoRefresh), "1//refresh")
|
||||
}
|
||||
|
||||
// TestOAuthTokenInfoWithRealExpiry verifies expiry uses real time, not hardcoded
|
||||
func TestOAuthTokenInfoWithRealExpiry(t *testing.T) {
|
||||
future := time.Now().Add(2 * time.Hour)
|
||||
bin := buildOAuthTokenInfoBinary("token", "refresh", future)
|
||||
|
||||
// Zero expiry should default to ~1h
|
||||
binZero := buildOAuthTokenInfoBinary("token", "refresh", time.Time{})
|
||||
|
||||
// They should be different lengths or content (different expiry timestamps)
|
||||
// Both should be valid (non-empty)
|
||||
require.True(t, len(bin) > 0)
|
||||
require.True(t, len(binZero) > 0)
|
||||
}
|
||||
|
||||
// TestUSSTopicWithOAuth verifies the full USS topic proto structure
|
||||
func TestUSSTopicWithOAuth(t *testing.T) {
|
||||
expiry := time.Now().Add(1 * time.Hour)
|
||||
topic := buildUSSTopicWithOAuth("ya29.access", "1//refresh", expiry)
|
||||
|
||||
require.True(t, len(topic) > 0)
|
||||
// The topic should contain the sentinel key
|
||||
require.Contains(t, string(topic), "oauthTokenInfoSentinelKey")
|
||||
}
|
||||
|
||||
func TestUSSTopicWithModelCredits(t *testing.T) {
|
||||
available := int32(123)
|
||||
minimum := int32(50)
|
||||
topic := buildUSSTopicWithModelCredits(&ModelCreditsInfo{
|
||||
UseAICredits: true,
|
||||
AvailableCredits: &available,
|
||||
MinimumCreditAmountForUsage: &minimum,
|
||||
})
|
||||
|
||||
require.True(t, len(topic) > 0)
|
||||
require.Contains(t, string(topic), useAICreditsSentinelKey)
|
||||
require.Contains(t, string(topic), availableCreditsSentinelKey)
|
||||
require.Contains(t, string(topic), minimumCreditAmountForUsageKey)
|
||||
}
|
||||
|
||||
func TestMockExtensionServerModelCreditsDynamicUpdate(t *testing.T) {
|
||||
csrf := "test-csrf-token"
|
||||
srv, err := NewMockExtensionServer(csrf)
|
||||
require.NoError(t, err)
|
||||
defer srv.Close()
|
||||
|
||||
srv.SetModelCredits("account-1", &ModelCreditsInfo{})
|
||||
|
||||
req, _ := http.NewRequest("POST",
|
||||
fmt.Sprintf("http://127.0.0.1:%d/exa.extension_server_pb.ExtensionServerService/SubscribeToUnifiedStateSyncTopic", srv.Port()),
|
||||
bytes.NewReader(frameConnectMessage(encodeProtoString(1, "uss-modelCredits"))))
|
||||
req.Header.Set("x-codeium-csrf-token", csrf)
|
||||
req.Header.Set("Content-Type", "application/connect+proto")
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
require.Equal(t, 200, resp.StatusCode)
|
||||
|
||||
// Drain the initial_state frame first.
|
||||
_, err = readConnectFrame(resp.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
available := int32(77)
|
||||
minimum := int32(25)
|
||||
srv.SetModelCredits("account-1", &ModelCreditsInfo{
|
||||
UseAICredits: true,
|
||||
AvailableCredits: &available,
|
||||
MinimumCreditAmountForUsage: &minimum,
|
||||
})
|
||||
|
||||
keys := make([]string, 0, 3)
|
||||
for len(keys) < 3 {
|
||||
frame, readErr := readConnectFrame(resp.Body)
|
||||
require.NoError(t, readErr)
|
||||
applied := decodeProtoBytesField(frame, 2)
|
||||
require.NotEmpty(t, applied)
|
||||
keys = append(keys, decodeProtoString(applied, 1))
|
||||
}
|
||||
|
||||
require.Contains(t, keys, useAICreditsSentinelKey)
|
||||
require.Contains(t, keys, availableCreditsSentinelKey)
|
||||
require.Contains(t, keys, minimumCreditAmountForUsageKey)
|
||||
}
|
||||
|
||||
// TestBuildInitialStateUpdate verifies the USS update wrapper
|
||||
func TestBuildInitialStateUpdate(t *testing.T) {
|
||||
topicData := buildEmptyTopic()
|
||||
update := buildInitialStateUpdate(topicData)
|
||||
// Should be a valid proto bytes field (field 1 = initial_state)
|
||||
require.True(t, len(update) >= 0) // empty topic is valid
|
||||
|
||||
topicData2 := buildUSSTopicWithOAuth("token", "refresh", time.Now().Add(1*time.Hour))
|
||||
update2 := buildInitialStateUpdate(topicData2)
|
||||
require.True(t, len(update2) > len(update), "non-empty topic should produce larger update")
|
||||
}
|
||||
|
||||
// TestPoolSetAccountTokenComplete verifies pool accepts full credential set
|
||||
func TestPoolSetAccountTokenComplete(t *testing.T) {
|
||||
csrf := "pool-csrf"
|
||||
srv, err := NewMockExtensionServer(csrf)
|
||||
require.NoError(t, err)
|
||||
defer srv.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
pool := &Pool{
|
||||
config: DefaultConfig(),
|
||||
instances: make(map[string][]*Instance),
|
||||
extServer: srv,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
expiry := time.Now().Add(1 * time.Hour)
|
||||
pool.SetAccountToken("acc-1", "ya29.full-token", "1//full-refresh", expiry)
|
||||
|
||||
srv.mu.RLock()
|
||||
info := srv.tokens["acc-1"]
|
||||
srv.mu.RUnlock()
|
||||
|
||||
require.NotNil(t, info)
|
||||
require.Equal(t, "ya29.full-token", info.AccessToken)
|
||||
require.Equal(t, "1//full-refresh", info.RefreshToken)
|
||||
require.False(t, info.ExpiresAt.IsZero())
|
||||
require.WithinDuration(t, expiry, info.ExpiresAt, time.Second)
|
||||
}
|
||||
|
||||
func TestPoolSetAccountModelCreditsComplete(t *testing.T) {
|
||||
csrf := "pool-csrf"
|
||||
srv, err := NewMockExtensionServer(csrf)
|
||||
require.NoError(t, err)
|
||||
defer srv.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
pool := &Pool{
|
||||
config: DefaultConfig(),
|
||||
instances: make(map[string][]*Instance),
|
||||
extServer: srv,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
available := int32(77)
|
||||
minimum := int32(25)
|
||||
pool.SetAccountModelCredits("acc-1", true, &available, &minimum)
|
||||
|
||||
srv.mu.RLock()
|
||||
info := srv.credits["acc-1"]
|
||||
srv.mu.RUnlock()
|
||||
|
||||
require.NotNil(t, info)
|
||||
require.True(t, info.UseAICredits)
|
||||
require.NotNil(t, info.AvailableCredits)
|
||||
require.Equal(t, available, *info.AvailableCredits)
|
||||
require.NotNil(t, info.MinimumCreditAmountForUsage)
|
||||
require.Equal(t, minimum, *info.MinimumCreditAmountForUsage)
|
||||
}
|
||||
|
||||
// TestUpstreamAdapterExtractsCredentials verifies internal LS headers are extracted and stripped.
|
||||
func TestUpstreamAdapterExtractsCredentials(t *testing.T) {
|
||||
// Create a mock upstream that records what it receives
|
||||
var receivedHeaders http.Header
|
||||
var mu sync.Mutex
|
||||
fallback := &recordingUpstreamWithCallback{}
|
||||
fallback.onDo = func(req *http.Request) {
|
||||
mu.Lock()
|
||||
receivedHeaders = req.Header.Clone()
|
||||
mu.Unlock()
|
||||
}
|
||||
|
||||
csrf := "test-csrf"
|
||||
srv, err := NewMockExtensionServer(csrf)
|
||||
require.NoError(t, err)
|
||||
defer srv.Close()
|
||||
|
||||
pool := &Pool{
|
||||
config: DefaultConfig(),
|
||||
instances: make(map[string][]*Instance),
|
||||
extServer: srv,
|
||||
}
|
||||
|
||||
upstream := NewLSPoolUpstream(pool, fallback)
|
||||
|
||||
// Non-streamGenerateContent request → should pass through to fallback
|
||||
req, _ := http.NewRequest("POST", "https://example.com/v1beta/models/gemini:generateContent", nil)
|
||||
req.Header.Set("Authorization", "Bearer ya29.test")
|
||||
req.Header.Set("X-Antigravity-Refresh-Token", "1//secret-refresh")
|
||||
req.Header.Set("X-Antigravity-Token-Expiry", "2026-03-29T19:00:00Z")
|
||||
req.Header.Set(useAICreditsHeader, "true")
|
||||
req.Header.Set(availableCreditsHeader, "42")
|
||||
req.Header.Set(minimumCreditAmountHeader, "50")
|
||||
|
||||
resp, err := upstream.Do(req, "", 1, 1)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
|
||||
// Internal headers should never leak to the direct upstream.
|
||||
mu.Lock()
|
||||
require.Empty(t, receivedHeaders.Get("X-Antigravity-Refresh-Token"))
|
||||
require.Empty(t, receivedHeaders.Get("X-Antigravity-Token-Expiry"))
|
||||
require.Empty(t, receivedHeaders.Get(useAICreditsHeader))
|
||||
require.Empty(t, receivedHeaders.Get(availableCreditsHeader))
|
||||
require.Empty(t, receivedHeaders.Get(minimumCreditAmountHeader))
|
||||
mu.Unlock()
|
||||
|
||||
srv.mu.RLock()
|
||||
tokenInfo := srv.tokens["1"]
|
||||
creditsInfo := srv.credits["1"]
|
||||
srv.mu.RUnlock()
|
||||
|
||||
require.NotNil(t, tokenInfo)
|
||||
require.Equal(t, "ya29.test", tokenInfo.AccessToken)
|
||||
require.NotNil(t, creditsInfo)
|
||||
require.True(t, creditsInfo.UseAICredits)
|
||||
require.NotNil(t, creditsInfo.AvailableCredits)
|
||||
require.Equal(t, int32(42), *creditsInfo.AvailableCredits)
|
||||
require.NotNil(t, creditsInfo.MinimumCreditAmountForUsage)
|
||||
require.Equal(t, int32(50), *creditsInfo.MinimumCreditAmountForUsage)
|
||||
}
|
||||
|
||||
// TestExtractPromptAndModelMultiTurn verifies multi-turn prompt extraction
|
||||
func TestExtractPromptAndModelMultiTurn(t *testing.T) {
|
||||
body := `{
|
||||
"model": "claude-sonnet-4-6",
|
||||
"request": {
|
||||
"systemInstruction": {"parts": [{"text": "You are helpful"}]},
|
||||
"contents": [
|
||||
{"role": "user", "parts": [{"text": "Hello"}]},
|
||||
{"role": "model", "parts": [{"text": "Hi there!"}]},
|
||||
{"role": "user", "parts": [{"text": "How are you?"}]}
|
||||
]
|
||||
}
|
||||
}`
|
||||
prompt, model := extractPromptAndModel([]byte(body))
|
||||
require.Equal(t, "claude-sonnet-4-6", model)
|
||||
require.Contains(t, prompt, "You are helpful")
|
||||
require.Contains(t, prompt, "Hello")
|
||||
require.Contains(t, prompt, "Hi there!")
|
||||
require.Contains(t, prompt, "How are you?")
|
||||
}
|
||||
|
||||
// TestExtractUsageFromTrajectory verifies token usage extraction
|
||||
func TestExtractUsageFromTrajectory(t *testing.T) {
|
||||
resp := `{
|
||||
"trajectory": {
|
||||
"steps": [{
|
||||
"type": "CORTEX_STEP_TYPE_PLANNER_RESPONSE",
|
||||
"status": "CORTEX_STEP_STATUS_DONE",
|
||||
"plannerResponse": {"response": "OK"},
|
||||
"metadata": {
|
||||
"modelUsage": {
|
||||
"inputTokens": "150",
|
||||
"outputTokens": "5"
|
||||
}
|
||||
}
|
||||
}]
|
||||
}
|
||||
}`
|
||||
usage := extractUsageFromTrajectory([]byte(resp))
|
||||
require.NotNil(t, usage)
|
||||
require.Equal(t, 150, usage["promptTokenCount"])
|
||||
require.Equal(t, 5, usage["candidatesTokenCount"])
|
||||
require.Equal(t, 155, usage["totalTokenCount"])
|
||||
}
|
||||
|
||||
// TestSSEChunkFormat verifies the Gemini SSE output format
|
||||
func TestSSEChunkFormat(t *testing.T) {
|
||||
chunk := buildGeminiSSEChunk("Hello world")
|
||||
require.True(t, len(chunk) > 0)
|
||||
require.Contains(t, chunk, "data: ")
|
||||
require.Contains(t, chunk, `"text":"Hello world"`)
|
||||
require.Contains(t, chunk, `"role":"model"`)
|
||||
require.True(t, chunk[len(chunk)-2:] == "\n\n")
|
||||
|
||||
// Verify it's valid JSON after stripping "data: " prefix
|
||||
jsonStr := chunk[len("data: ") : len(chunk)-2]
|
||||
var parsed map[string]any
|
||||
err := json.Unmarshal([]byte(jsonStr), &parsed)
|
||||
require.NoError(t, err)
|
||||
response := parsed["response"].(map[string]any)
|
||||
candidates := response["candidates"].([]any)
|
||||
require.Len(t, candidates, 1)
|
||||
}
|
||||
|
||||
// TestSSEFinalChunkFormat verifies the final SSE chunk with usage
|
||||
func TestSSEFinalChunkFormat(t *testing.T) {
|
||||
usage := map[string]any{
|
||||
"promptTokenCount": 100,
|
||||
"candidatesTokenCount": 50,
|
||||
"totalTokenCount": 150,
|
||||
}
|
||||
chunk := buildGeminiSSEFinalChunk(usage)
|
||||
require.Contains(t, chunk, "data: ")
|
||||
require.Contains(t, chunk, `"finishReason":"STOP"`)
|
||||
require.Contains(t, chunk, `"usageMetadata"`)
|
||||
}
|
||||
|
||||
func TestStreamCascadeResponsePollsImmediately(t *testing.T) {
|
||||
var getCalls atomic.Int32
|
||||
|
||||
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "test-csrf", r.Header.Get("x-codeium-csrf-token"))
|
||||
|
||||
if strings.HasSuffix(r.URL.Path, "/GetCascadeTrajectory") {
|
||||
getCalls.Add(1)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"status":"CASCADE_RUN_STATUS_IDLE","trajectory":{"steps":[{"type":"CORTEX_STEP_TYPE_PLANNER_RESPONSE","status":"CORTEX_STEP_STATUS_DONE","plannerResponse":{"response":"hello from ls"}}]}}`))
|
||||
return
|
||||
}
|
||||
|
||||
http.NotFound(w, r)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
inst := &Instance{
|
||||
AccountID: "42",
|
||||
CSRF: "test-csrf",
|
||||
Address: strings.TrimPrefix(server.URL, "https://"),
|
||||
client: server.Client(),
|
||||
healthy: true,
|
||||
lastUsed: time.Now(),
|
||||
}
|
||||
upstream := NewLSPoolUpstream(&Pool{}, &recordingUpstream{})
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
upstream.streamCascadeResponse(ctx, inst, "cid-1", pw, nil, nil)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
body, err := io.ReadAll(pr)
|
||||
require.NoError(t, err)
|
||||
<-done
|
||||
|
||||
require.GreaterOrEqual(t, getCalls.Load(), int32(1))
|
||||
require.Contains(t, string(body), "hello from ls")
|
||||
}
|
||||
|
||||
// TestRequestHasToolsEdgeCases verifies tool detection edge cases
|
||||
func TestRequestHasToolsEdgeCases(t *testing.T) {
|
||||
// null tools
|
||||
require.False(t, requestHasTools([]byte(`{"contents":[],"tools":null}`)))
|
||||
// tools with empty function declarations
|
||||
require.True(t, requestHasTools([]byte(`{"contents":[],"tools":[{"functionDeclarations":[]}]}`)))
|
||||
// deeply nested wrapped format
|
||||
require.True(t, requestHasTools([]byte(`{"model":"m","project":"p","request":{"contents":[],"tools":[{"codeExecution":{}}]}}`)))
|
||||
}
|
||||
|
||||
func TestJSParityRouteReusesCascadeSession(t *testing.T) {
|
||||
t.Setenv("ANTIGRAVITY_LS_STRATEGY", LSStrategyJSParity)
|
||||
|
||||
var startCalls atomic.Int32
|
||||
var sendCalls atomic.Int32
|
||||
var getCalls atomic.Int32
|
||||
var sendBodiesMu sync.Mutex
|
||||
var sendBodies []map[string]any
|
||||
|
||||
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "test-csrf", r.Header.Get("x-codeium-csrf-token"))
|
||||
|
||||
switch {
|
||||
case strings.HasSuffix(r.URL.Path, "/StartCascade"):
|
||||
startCalls.Add(1)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"cascadeId":"cid-1"}`))
|
||||
case strings.HasSuffix(r.URL.Path, "/SendUserCascadeMessage"):
|
||||
sendCalls.Add(1)
|
||||
var payload map[string]any
|
||||
err := json.NewDecoder(r.Body).Decode(&payload)
|
||||
require.NoError(t, err)
|
||||
sendBodiesMu.Lock()
|
||||
sendBodies = append(sendBodies, payload)
|
||||
sendBodiesMu.Unlock()
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"queued":false}`))
|
||||
case strings.HasSuffix(r.URL.Path, "/GetCascadeTrajectory"):
|
||||
call := getCalls.Add(1)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
text := "hello from ls"
|
||||
if call > 1 {
|
||||
text = "follow up from ls"
|
||||
}
|
||||
_, _ = w.Write([]byte(fmt.Sprintf(`{"status":"CASCADE_RUN_STATUS_IDLE","trajectory":{"steps":[{"type":"CORTEX_STEP_TYPE_PLANNER_RESPONSE","status":"CORTEX_STEP_STATUS_DONE","plannerResponse":{"response":"%s"}}]}}`, text)))
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
inst := &Instance{
|
||||
AccountID: "42",
|
||||
CSRF: "test-csrf",
|
||||
Address: strings.TrimPrefix(server.URL, "https://"),
|
||||
client: server.Client(),
|
||||
healthy: true,
|
||||
lastUsed: time.Now(),
|
||||
}
|
||||
inst.SetModelMappingReady(true)
|
||||
pool := &Pool{
|
||||
config: Config{ReplicasPerAccount: 1},
|
||||
instances: map[string][]*Instance{"42": []*Instance{inst}},
|
||||
}
|
||||
upstream := NewLSPoolUpstream(pool, &recordingUpstream{})
|
||||
|
||||
req1Body := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"session-a","contents":[{"role":"user","parts":[{"text":"hello"}]}]}}`)
|
||||
req1, err := http.NewRequest(http.MethodPost, "https://example.com/v1internal:streamGenerateContent?alt=sse", bytes.NewReader(req1Body))
|
||||
require.NoError(t, err)
|
||||
req1.Header.Set("Authorization", "Bearer downstream-a")
|
||||
|
||||
resp1, err := upstream.Do(req1, "", 42, 1)
|
||||
require.NoError(t, err)
|
||||
body1, err := io.ReadAll(resp1.Body)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, string(body1), `"text":"hello from ls"`)
|
||||
|
||||
req2Body := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"session-a","contents":[{"role":"user","parts":[{"text":"hello"}]},{"role":"model","parts":[{"text":"hello from ls"}]},{"role":"user","parts":[{"text":"follow up"}]}]}}`)
|
||||
req2, err := http.NewRequest(http.MethodPost, "https://example.com/v1internal:streamGenerateContent?alt=sse", bytes.NewReader(req2Body))
|
||||
require.NoError(t, err)
|
||||
req2.Header.Set("Authorization", "Bearer downstream-a")
|
||||
|
||||
resp2, err := upstream.Do(req2, "", 42, 1)
|
||||
require.NoError(t, err)
|
||||
body2, err := io.ReadAll(resp2.Body)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, string(body2), `"text":"follow up from ls"`)
|
||||
|
||||
require.Equal(t, int32(1), startCalls.Load(), "cascade should be reused for append-only transcript")
|
||||
require.Equal(t, int32(2), sendCalls.Load())
|
||||
|
||||
sendBodiesMu.Lock()
|
||||
require.Len(t, sendBodies, 2)
|
||||
firstSend := sendBodies[0]
|
||||
sendBodiesMu.Unlock()
|
||||
|
||||
require.Equal(t, "cid-1", firstSend["cascadeId"])
|
||||
require.Equal(t, false, firstSend["blocking"])
|
||||
metadata, ok := firstSend["metadata"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "antigravity", metadata["ideName"])
|
||||
require.Equal(t, "1.107.0", metadata["ideVersion"])
|
||||
require.NotContains(t, firstSend, "clientType")
|
||||
require.NotContains(t, firstSend, "messageOrigin")
|
||||
cascadeConfig, ok := firstSend["cascadeConfig"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
plannerConfig, ok := cascadeConfig["plannerConfig"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
requestedModel, ok := plannerConfig["requestedModel"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.NotEmpty(t, requestedModel["model"])
|
||||
require.Len(t, plannerConfig, 1)
|
||||
require.Len(t, cascadeConfig, 1)
|
||||
}
|
||||
|
||||
func TestJSParityRouteFallsBackOnSystemInstructionDrift(t *testing.T) {
|
||||
t.Setenv("ANTIGRAVITY_LS_STRATEGY", LSStrategyJSParity)
|
||||
|
||||
var startCalls atomic.Int32
|
||||
var sendCalls atomic.Int32
|
||||
|
||||
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "test-csrf", r.Header.Get("x-codeium-csrf-token"))
|
||||
|
||||
switch {
|
||||
case strings.HasSuffix(r.URL.Path, "/StartCascade"):
|
||||
startCalls.Add(1)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"cascadeId":"cid-1"}`))
|
||||
case strings.HasSuffix(r.URL.Path, "/SendUserCascadeMessage"):
|
||||
sendCalls.Add(1)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"queued":false}`))
|
||||
case strings.HasSuffix(r.URL.Path, "/GetCascadeTrajectory"):
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"status":"CASCADE_RUN_STATUS_IDLE","trajectory":{"steps":[{"type":"CORTEX_STEP_TYPE_PLANNER_RESPONSE","status":"CORTEX_STEP_STATUS_DONE","plannerResponse":{"response":"hello from ls"}}]}}`))
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
inst := &Instance{
|
||||
AccountID: "42",
|
||||
CSRF: "test-csrf",
|
||||
Address: strings.TrimPrefix(server.URL, "https://"),
|
||||
client: server.Client(),
|
||||
healthy: true,
|
||||
lastUsed: time.Now(),
|
||||
}
|
||||
inst.SetModelMappingReady(true)
|
||||
fallback := &recordingUpstream{}
|
||||
pool := &Pool{
|
||||
config: Config{ReplicasPerAccount: 1},
|
||||
instances: map[string][]*Instance{"42": []*Instance{inst}},
|
||||
}
|
||||
upstream := NewLSPoolUpstream(pool, fallback)
|
||||
|
||||
req1Body := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"session-a","systemInstruction":{"parts":[{"text":"You are helpful"}]},"contents":[{"role":"user","parts":[{"text":"hello"}]}]}}`)
|
||||
req1, err := http.NewRequest(http.MethodPost, "https://example.com/v1internal:streamGenerateContent?alt=sse", bytes.NewReader(req1Body))
|
||||
require.NoError(t, err)
|
||||
req1.Header.Set("Authorization", "Bearer downstream-a")
|
||||
|
||||
resp1, err := upstream.Do(req1, "", 42, 1)
|
||||
require.NoError(t, err)
|
||||
body1, err := io.ReadAll(resp1.Body)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, string(body1), `"text":"hello from ls"`)
|
||||
|
||||
req2Body := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"session-a","systemInstruction":{"parts":[{"text":"You are different"}]},"contents":[{"role":"user","parts":[{"text":"hello"}]},{"role":"model","parts":[{"text":"hello from ls"}]},{"role":"user","parts":[{"text":"follow up"}]}]}}`)
|
||||
req2, err := http.NewRequest(http.MethodPost, "https://example.com/v1internal:streamGenerateContent?alt=sse", bytes.NewReader(req2Body))
|
||||
require.NoError(t, err)
|
||||
req2.Header.Set("Authorization", "Bearer downstream-a")
|
||||
|
||||
resp2, err := upstream.Do(req2, "", 42, 1)
|
||||
require.NoError(t, err)
|
||||
body2, err := io.ReadAll(resp2.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "ok", string(body2))
|
||||
require.Equal(t, 1, fallback.doCalls)
|
||||
require.Equal(t, int32(1), startCalls.Load())
|
||||
require.Equal(t, int32(1), sendCalls.Load())
|
||||
}
|
||||
|
||||
func TestJSParityRouteErrorsWhenModelMappingPending(t *testing.T) {
|
||||
t.Setenv("ANTIGRAVITY_LS_STRATEGY", LSStrategyJSParity)
|
||||
|
||||
var startCalls atomic.Int32
|
||||
var sendCalls atomic.Int32
|
||||
|
||||
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case strings.HasSuffix(r.URL.Path, "/StartCascade"):
|
||||
startCalls.Add(1)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"cascadeId":"cid-1"}`))
|
||||
case strings.HasSuffix(r.URL.Path, "/SendUserCascadeMessage"):
|
||||
sendCalls.Add(1)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"queued":false}`))
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
inst := &Instance{
|
||||
AccountID: "42",
|
||||
CSRF: "test-csrf",
|
||||
Address: strings.TrimPrefix(server.URL, "https://"),
|
||||
client: server.Client(),
|
||||
healthy: true,
|
||||
lastUsed: time.Now(),
|
||||
}
|
||||
|
||||
fallback := &recordingUpstream{}
|
||||
pool := &Pool{
|
||||
config: Config{ReplicasPerAccount: 1},
|
||||
instances: map[string][]*Instance{"42": []*Instance{inst}},
|
||||
}
|
||||
upstream := NewLSPoolUpstream(pool, fallback)
|
||||
|
||||
reqBody := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"session-a","contents":[{"role":"user","parts":[{"text":"hello"}]}]}}`)
|
||||
req, err := http.NewRequest(http.MethodPost, "https://example.com/v1internal:streamGenerateContent?alt=sse", bytes.NewReader(reqBody))
|
||||
require.NoError(t, err)
|
||||
req.Header.Set("Authorization", "Bearer downstream-a")
|
||||
|
||||
resp, err := upstream.Do(req, "", 42, 1)
|
||||
require.Nil(t, resp)
|
||||
require.ErrorIs(t, err, errLSModelMapPending)
|
||||
require.Equal(t, int32(0), startCalls.Load())
|
||||
require.Equal(t, int32(0), sendCalls.Load())
|
||||
require.Equal(t, 0, fallback.doCalls)
|
||||
}
|
||||
|
||||
// recordingUpstreamWithCallback extends the base recordingUpstream with a callback
|
||||
type recordingUpstreamWithCallback struct {
|
||||
recordingUpstream
|
||||
onDo func(req *http.Request)
|
||||
}
|
||||
|
||||
func (r *recordingUpstreamWithCallback) Do(req *http.Request, proxyURL string, accountID int64, c int) (*http.Response, error) {
|
||||
if r.onDo != nil {
|
||||
r.onDo(req)
|
||||
}
|
||||
return r.recordingUpstream.Do(req, proxyURL, accountID, c)
|
||||
}
|
||||
908
backend/internal/pkg/lspool/mock_extension_server.go
Normal file
908
backend/internal/pkg/lspool/mock_extension_server.go
Normal file
@ -0,0 +1,908 @@
|
||||
// Package lspool provides a mock Extension Server that the LS binary connects
|
||||
// to at startup. The real IDE's extension.js runs a ConnectRPC HTTP/1.1 server
|
||||
// using connectNodeAdapter. We replicate that protocol here.
|
||||
//
|
||||
// Protocol details (from extension.js source):
|
||||
// - Transport: HTTP/1.1 on 127.0.0.1 (no TLS)
|
||||
// - Auth: x-codeium-csrf-token header on every request
|
||||
// - Unary request Content-Type: application/proto (binary protobuf, no envelope)
|
||||
// OR application/connect+proto (with 5-byte envelope)
|
||||
// - Unary response Content-Type: application/proto (raw binary protobuf, no envelope)
|
||||
// - Stream request Content-Type: application/connect+proto (with 5-byte envelope)
|
||||
// - Stream response Content-Type: application/connect+proto (envelope-framed messages)
|
||||
//
|
||||
// The LS sends requests with content-type "application/connect+proto" for BOTH
|
||||
// unary and streaming RPCs. ConnectRPC's content-type regex:
|
||||
//
|
||||
// /^application\/(connect\+)?(?:(json)(?:; ?charset=utf-?8)?|(proto))$/i
|
||||
//
|
||||
// If "connect+" prefix is present → stream mode; otherwise → unary mode.
|
||||
// However the LS Go client uses the Connect protocol client which always sends
|
||||
// "application/proto" for unary and "application/connect+proto" for streaming.
|
||||
//
|
||||
// We detect the RPC kind from the URL path and respond accordingly.
|
||||
package lspool
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
// ============================================================
|
||||
// Proto helpers — hand-encode minimal proto messages so we don't
|
||||
// need to import the full generated proto package.
|
||||
// ============================================================
|
||||
|
||||
// encodeProtoString writes a proto string field (wire type 2) to a byte slice.
|
||||
func encodeProtoString(fieldNum int, val string) []byte {
|
||||
tag := encodeVarint(uint64(fieldNum<<3 | 2))
|
||||
length := encodeVarint(uint64(len(val)))
|
||||
out := make([]byte, 0, len(tag)+len(length)+len(val))
|
||||
out = append(out, tag...)
|
||||
out = append(out, length...)
|
||||
out = append(out, []byte(val)...)
|
||||
return out
|
||||
}
|
||||
|
||||
// encodeProtoBytes writes a proto bytes/message field (wire type 2).
|
||||
func encodeProtoBytes(fieldNum int, val []byte) []byte {
|
||||
tag := encodeVarint(uint64(fieldNum<<3 | 2))
|
||||
length := encodeVarint(uint64(len(val)))
|
||||
out := make([]byte, 0, len(tag)+len(length)+len(val))
|
||||
out = append(out, tag...)
|
||||
out = append(out, length...)
|
||||
out = append(out, val...)
|
||||
return out
|
||||
}
|
||||
|
||||
// encodeProtoVarint writes a proto varint field (wire type 0).
|
||||
func encodeProtoVarint(fieldNum int, val uint64) []byte {
|
||||
tag := encodeVarint(uint64(fieldNum<<3 | 0))
|
||||
v := encodeVarint(val)
|
||||
out := make([]byte, 0, len(tag)+len(v))
|
||||
out = append(out, tag...)
|
||||
out = append(out, v...)
|
||||
return out
|
||||
}
|
||||
|
||||
// encodeProtoBool writes a proto bool field.
|
||||
func encodeProtoBool(fieldNum int, val bool) []byte {
|
||||
v := uint64(0)
|
||||
if val {
|
||||
v = 1
|
||||
}
|
||||
return encodeProtoVarint(fieldNum, v)
|
||||
}
|
||||
|
||||
func encodeVarint(v uint64) []byte {
|
||||
buf := make([]byte, binary.MaxVarintLen64)
|
||||
n := binary.PutUvarint(buf, v)
|
||||
return buf[:n]
|
||||
}
|
||||
|
||||
// decodeProtoString extracts a string field from raw proto bytes.
|
||||
func decodeProtoString(data []byte, targetField int) string {
|
||||
i := 0
|
||||
for i < len(data) {
|
||||
if i >= len(data) {
|
||||
break
|
||||
}
|
||||
tag, n := binary.Uvarint(data[i:])
|
||||
if n <= 0 {
|
||||
break
|
||||
}
|
||||
i += n
|
||||
fieldNum := int(tag >> 3)
|
||||
wireType := tag & 0x7
|
||||
|
||||
switch wireType {
|
||||
case 0: // varint
|
||||
_, n = binary.Uvarint(data[i:])
|
||||
if n <= 0 {
|
||||
return ""
|
||||
}
|
||||
i += n
|
||||
case 2: // length-delimited
|
||||
length, n := binary.Uvarint(data[i:])
|
||||
if n <= 0 {
|
||||
return ""
|
||||
}
|
||||
i += n
|
||||
if fieldNum == targetField {
|
||||
end := i + int(length)
|
||||
if end > len(data) {
|
||||
return ""
|
||||
}
|
||||
return string(data[i:end])
|
||||
}
|
||||
i += int(length)
|
||||
case 1: // 64-bit
|
||||
i += 8
|
||||
case 5: // 32-bit
|
||||
i += 4
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// ConnectRPC envelope helpers
|
||||
// ============================================================
|
||||
|
||||
// connectEnvelope wraps a proto payload in a ConnectRPC streaming envelope:
|
||||
// 1 byte flags + 4 byte big-endian length + payload
|
||||
func connectEnvelope(flags byte, payload []byte) []byte {
|
||||
frame := make([]byte, 5+len(payload))
|
||||
frame[0] = flags
|
||||
binary.BigEndian.PutUint32(frame[1:5], uint32(len(payload)))
|
||||
copy(frame[5:], payload)
|
||||
return frame
|
||||
}
|
||||
|
||||
// connectEndOfStream returns the end-of-stream trailer frame for ConnectRPC.
|
||||
// flags=0x02 signals end of stream. The payload is a JSON object with empty metadata.
|
||||
func connectEndOfStream() []byte {
|
||||
trailer := []byte("{}")
|
||||
return connectEnvelope(0x02, trailer)
|
||||
}
|
||||
|
||||
// unwrapConnectEnvelope strips the 5-byte envelope header from a ConnectRPC message.
|
||||
// Returns the raw proto payload. If the input is shorter than 5 bytes, returns as-is.
|
||||
func unwrapConnectEnvelope(body []byte) []byte {
|
||||
if len(body) < 5 {
|
||||
return body
|
||||
}
|
||||
// Check if it looks like an envelope: first byte should be 0x00 or 0x01
|
||||
if body[0] > 0x02 {
|
||||
return body // Not envelope-framed, return raw
|
||||
}
|
||||
plen := binary.BigEndian.Uint32(body[1:5])
|
||||
if int(plen)+5 > len(body) {
|
||||
return body // Length mismatch, return raw
|
||||
}
|
||||
return body[5 : 5+plen]
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// OAuthTokenInfo proto builder
|
||||
// ============================================================
|
||||
|
||||
// buildOAuthTokenInfoBinary creates binary-encoded OAuthTokenInfo proto.
|
||||
//
|
||||
// message OAuthTokenInfo {
|
||||
// string access_token = 1;
|
||||
// string token_type = 2;
|
||||
// string refresh_token = 3;
|
||||
// google.protobuf.Timestamp expiry = 4;
|
||||
// bool is_gcp_tos = 6;
|
||||
// }
|
||||
func buildOAuthTokenInfoBinary(accessToken, refreshToken string, expiresAt time.Time) []byte {
|
||||
var buf []byte
|
||||
buf = append(buf, encodeProtoString(1, accessToken)...)
|
||||
buf = append(buf, encodeProtoString(2, "Bearer")...)
|
||||
if refreshToken != "" {
|
||||
buf = append(buf, encodeProtoString(3, refreshToken)...)
|
||||
}
|
||||
// Use real expiry if provided, otherwise default to 1 hour from now
|
||||
expiry := expiresAt
|
||||
if expiry.IsZero() {
|
||||
expiry = time.Now().Add(1 * time.Hour)
|
||||
}
|
||||
ts := ×tamppb.Timestamp{
|
||||
Seconds: expiry.Unix(),
|
||||
}
|
||||
tsBytes, _ := proto.Marshal(ts)
|
||||
buf = append(buf, encodeProtoBytes(4, tsBytes)...)
|
||||
buf = append(buf, encodeProtoBool(6, true)...)
|
||||
return buf
|
||||
}
|
||||
|
||||
// buildUSSTopicWithOAuth creates a USS Topic proto with the OAuth token.
|
||||
//
|
||||
// message Topic { map<string, Row> data = 1; }
|
||||
// message Row { string value = 1; int64 e_tag = 2; }
|
||||
//
|
||||
// The key in the map is "oauthTokenInfoSentinelKey" and the Row.value is
|
||||
// base64(toBinary(OAuthTokenInfo)).
|
||||
func buildUSSTopicWithOAuth(accessToken, refreshToken string, expiresAt time.Time) []byte {
|
||||
tokenBin := buildOAuthTokenInfoBinary(accessToken, refreshToken, expiresAt)
|
||||
tokenB64 := base64.StdEncoding.EncodeToString(tokenBin)
|
||||
|
||||
// Row: value=tokenB64 (field 1), e_tag=1 (field 2)
|
||||
var row []byte
|
||||
row = append(row, encodeProtoString(1, tokenB64)...)
|
||||
row = append(row, encodeProtoVarint(2, 1)...)
|
||||
|
||||
// Map entry: key="oauthTokenInfoSentinelKey" (field 1), value=row (field 2)
|
||||
var entry []byte
|
||||
entry = append(entry, encodeProtoString(1, "oauthTokenInfoSentinelKey")...)
|
||||
entry = append(entry, encodeProtoBytes(2, row)...)
|
||||
|
||||
// Topic: data map entries use field 1
|
||||
var topic []byte
|
||||
topic = append(topic, encodeProtoBytes(1, entry)...)
|
||||
|
||||
return topic
|
||||
}
|
||||
|
||||
func buildPrimitiveBoolBinary(val bool) []byte {
|
||||
// Primitive.bool_value is field 13 in the proto definition
|
||||
return encodeProtoBool(13, val)
|
||||
}
|
||||
|
||||
func buildPrimitiveInt32Binary(val int32) []byte {
|
||||
// Primitive.int32_value is field 3 in the proto definition
|
||||
return encodeProtoVarint(3, uint64(uint32(val)))
|
||||
}
|
||||
|
||||
func buildUSSTopicRow(key string, value string) []byte {
|
||||
row := buildUSSRowBinary(value)
|
||||
|
||||
var entry []byte
|
||||
entry = append(entry, encodeProtoString(1, key)...)
|
||||
entry = append(entry, encodeProtoBytes(2, row)...)
|
||||
return entry
|
||||
}
|
||||
|
||||
func buildUSSRowBinary(value string) []byte {
|
||||
var row []byte
|
||||
row = append(row, encodeProtoString(1, value)...)
|
||||
row = append(row, encodeProtoVarint(2, 1)...)
|
||||
return row
|
||||
}
|
||||
|
||||
func buildUSSTopicWithModelCredits(info *ModelCreditsInfo) []byte {
|
||||
if info == nil {
|
||||
info = &ModelCreditsInfo{}
|
||||
}
|
||||
|
||||
minimum := defaultMinimumCreditAmountForUsage
|
||||
if info.MinimumCreditAmountForUsage != nil {
|
||||
minimum = *info.MinimumCreditAmountForUsage
|
||||
}
|
||||
|
||||
entries := make([][]byte, 0, 3)
|
||||
entries = append(entries, buildUSSTopicRow(
|
||||
useAICreditsSentinelKey,
|
||||
string(buildPrimitiveBoolBinary(info.UseAICredits)),
|
||||
))
|
||||
// JS protocol: useAICreditsSentinelKey carries the toggle state.
|
||||
// availableCreditsSentinelKey is only present when credits are enabled.
|
||||
if info.UseAICredits {
|
||||
credits := int32(9999)
|
||||
if info.AvailableCredits != nil {
|
||||
credits = *info.AvailableCredits
|
||||
}
|
||||
entries = append(entries, buildUSSTopicRow(availableCreditsSentinelKey, string(buildPrimitiveInt32Binary(credits))))
|
||||
}
|
||||
entries = append(entries, buildUSSTopicRow(minimumCreditAmountForUsageKey, string(buildPrimitiveInt32Binary(minimum))))
|
||||
|
||||
var topic []byte
|
||||
for _, entry := range entries {
|
||||
topic = append(topic, encodeProtoBytes(1, entry)...)
|
||||
}
|
||||
return topic
|
||||
}
|
||||
|
||||
// buildEmptyTopic returns an empty USS Topic proto (for non-oauth topics).
|
||||
func buildEmptyTopic() []byte {
|
||||
return []byte{} // Empty message = no map entries
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// UnifiedStateSyncUpdate builder
|
||||
// ============================================================
|
||||
|
||||
// buildInitialStateUpdate creates a UnifiedStateSyncUpdate with initial_state set.
|
||||
//
|
||||
// message UnifiedStateSyncUpdate {
|
||||
// oneof update_type {
|
||||
// Topic initial_state = 1;
|
||||
// AppliedUpdate applied_update = 2;
|
||||
// }
|
||||
// }
|
||||
func buildInitialStateUpdate(topicData []byte) []byte {
|
||||
return encodeProtoBytes(1, topicData)
|
||||
}
|
||||
|
||||
func buildAppliedUpdate(key string, row []byte) []byte {
|
||||
var applied []byte
|
||||
applied = append(applied, encodeProtoString(1, key)...)
|
||||
if len(row) > 0 {
|
||||
applied = append(applied, encodeProtoBytes(2, row)...)
|
||||
}
|
||||
return encodeProtoBytes(2, applied)
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// MockExtensionServer
|
||||
// ============================================================
|
||||
|
||||
// MockExtensionServer provides a ConnectRPC-compatible HTTP server that the
|
||||
// Language Server binary connects to. It implements just enough of the
|
||||
// ExtensionServerService to keep the LS operational.
|
||||
type MockExtensionServer struct {
|
||||
listener net.Listener
|
||||
server *http.Server
|
||||
port int
|
||||
csrf string
|
||||
mu sync.RWMutex
|
||||
tokens map[string]*TokenInfo // account_id -> token info
|
||||
credits map[string]*ModelCreditsInfo // account_id -> model credits info
|
||||
subscribers map[string]map[int]*stateSubscriber
|
||||
nextSubID int
|
||||
lastAccountID string
|
||||
logger *slog.Logger
|
||||
|
||||
// Trajectory callback — when LS pushes trajectory updates, we forward them
|
||||
onTrajectoryUpdate func(topic, key string, data []byte)
|
||||
}
|
||||
|
||||
// TokenInfo holds OAuth token details for an account.
|
||||
type TokenInfo struct {
|
||||
AccessToken string
|
||||
RefreshToken string
|
||||
ExpiresAt time.Time // zero value means unknown; defaults to now+1h
|
||||
}
|
||||
|
||||
// ModelCreditsInfo mirrors the JS uss-modelCredits topic state.
|
||||
type ModelCreditsInfo struct {
|
||||
UseAICredits bool
|
||||
AvailableCredits *int32
|
||||
MinimumCreditAmountForUsage *int32
|
||||
}
|
||||
|
||||
type stateSubscriber struct {
|
||||
id int
|
||||
accountID string
|
||||
topic string
|
||||
updates chan []byte
|
||||
}
|
||||
|
||||
const (
|
||||
useAICreditsSentinelKey = "useAICreditsSentinelKey"
|
||||
availableCreditsSentinelKey = "availableCreditsSentinelKey"
|
||||
minimumCreditAmountForUsageKey = "minimumCreditAmountForUsageKey"
|
||||
defaultMinimumCreditAmountForUsage = int32(50)
|
||||
)
|
||||
|
||||
// NewMockExtensionServer creates a mock extension server with proper ConnectRPC handling.
|
||||
func NewMockExtensionServer(csrf string) (*MockExtensionServer, error) {
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("listen: %w", err)
|
||||
}
|
||||
|
||||
m := &MockExtensionServer{
|
||||
listener: listener,
|
||||
port: listener.Addr().(*net.TCPAddr).Port,
|
||||
csrf: csrf,
|
||||
tokens: make(map[string]*TokenInfo),
|
||||
credits: make(map[string]*ModelCreditsInfo),
|
||||
subscribers: make(map[string]map[int]*stateSubscriber),
|
||||
logger: slog.Default().With("component", "mock-ext-server"),
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
extService := "/exa.extension_server_pb.ExtensionServerService/"
|
||||
|
||||
// Register all RPCs the LS calls on the Extension Server.
|
||||
// Unary RPCs — return application/proto
|
||||
mux.HandleFunc(extService+"LanguageServerStarted", m.handleUnary(m.onLanguageServerStarted))
|
||||
mux.HandleFunc(extService+"Heartbeat", m.handleUnary(m.onHeartbeat))
|
||||
mux.HandleFunc(extService+"GetSecretValue", m.handleUnary(m.onGetSecretValue))
|
||||
mux.HandleFunc(extService+"StoreSecretValue", m.handleUnary(m.onStoreSecretValue))
|
||||
mux.HandleFunc(extService+"IsAgentManagerEnabled", m.handleUnary(m.onIsAgentManagerEnabled))
|
||||
mux.HandleFunc(extService+"PushUnifiedStateSyncUpdate", m.handleUnary(m.onPushUnifiedStateSyncUpdate))
|
||||
mux.HandleFunc(extService+"RecordError", m.handleUnary(m.onRecordError))
|
||||
mux.HandleFunc(extService+"LogEvent", m.handleUnary(m.onLogEvent))
|
||||
mux.HandleFunc(extService+"UpdateCascadeTrajectorySummaries", m.handleUnary(m.onUpdateTrajectorySummaries))
|
||||
mux.HandleFunc(extService+"BroadcastConversationDeletion", m.handleUnary(m.onDefault))
|
||||
mux.HandleFunc(extService+"WriteCascadeEdit", m.handleUnary(m.onDefault))
|
||||
mux.HandleFunc(extService+"OpenDiffZones", m.handleUnary(m.onDefault))
|
||||
mux.HandleFunc(extService+"HandleAsyncPostMessage", m.handleUnary(m.onDefault))
|
||||
mux.HandleFunc(extService+"OpenFilePointer", m.handleUnary(m.onDefault))
|
||||
mux.HandleFunc(extService+"OpenVirtualFile", m.handleUnary(m.onDefault))
|
||||
mux.HandleFunc(extService+"SaveDocument", m.handleUnary(m.onDefault))
|
||||
mux.HandleFunc(extService+"RestartUserStatusUpdater", m.handleUnary(m.onDefault))
|
||||
mux.HandleFunc(extService+"FocusIDEWindow", m.handleUnary(m.onDefault))
|
||||
mux.HandleFunc(extService+"SmartFocusConversation", m.handleUnary(m.onDefault))
|
||||
mux.HandleFunc(extService+"RunExtensionCode", m.handleUnary(m.onDefault))
|
||||
mux.HandleFunc(extService+"UpdateDetailedViewWithCascadeInput", m.handleUnary(m.onDefault))
|
||||
mux.HandleFunc(extService+"FindAllReferences", m.handleUnary(m.onDefault))
|
||||
mux.HandleFunc(extService+"GetDefinition", m.handleUnary(m.onDefault))
|
||||
mux.HandleFunc(extService+"GetLintErrors", m.handleUnary(m.onDefault))
|
||||
|
||||
// Server-streaming RPCs — return application/connect+proto
|
||||
mux.HandleFunc(extService+"SubscribeToUnifiedStateSyncTopic", m.handleStream(m.onSubscribeStateSyncTopic))
|
||||
mux.HandleFunc(extService+"ExecuteCommand", m.handleStream(m.onExecuteCommand))
|
||||
|
||||
// Catch-all for any unregistered RPCs
|
||||
mux.HandleFunc("/", m.handleCatchAll)
|
||||
|
||||
m.server = &http.Server{Handler: mux}
|
||||
|
||||
go func() {
|
||||
if err := m.server.Serve(listener); err != http.ErrServerClosed {
|
||||
m.logger.Error("extension server error", "err", err)
|
||||
}
|
||||
}()
|
||||
|
||||
m.logger.Info("mock extension server started", "port", m.port, "csrf_len", len(csrf))
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Port returns the listening port.
|
||||
func (m *MockExtensionServer) Port() int {
|
||||
return m.port
|
||||
}
|
||||
|
||||
// SetToken sets the OAuth token for an account.
|
||||
func (m *MockExtensionServer) SetToken(accountID string, info *TokenInfo) {
|
||||
m.mu.Lock()
|
||||
m.tokens[accountID] = info
|
||||
m.lastAccountID = accountID
|
||||
subscribers := m.snapshotSubscribersLocked("uss-oauth", accountID)
|
||||
m.mu.Unlock()
|
||||
|
||||
if info == nil {
|
||||
return
|
||||
}
|
||||
tokenBin := buildOAuthTokenInfoBinary(info.AccessToken, info.RefreshToken, info.ExpiresAt)
|
||||
tokenB64 := base64.StdEncoding.EncodeToString(tokenBin)
|
||||
m.publishTopicUpdate(subscribers, buildAppliedUpdate("oauthTokenInfoSentinelKey", buildUSSRowBinary(tokenB64)))
|
||||
}
|
||||
|
||||
// SetModelCredits sets the uss-modelCredits state for an account.
|
||||
func (m *MockExtensionServer) SetModelCredits(accountID string, info *ModelCreditsInfo) {
|
||||
if info == nil {
|
||||
info = &ModelCreditsInfo{}
|
||||
}
|
||||
copyInfo := *info
|
||||
m.mu.Lock()
|
||||
m.credits[accountID] = ©Info
|
||||
m.lastAccountID = accountID
|
||||
subscribers := m.snapshotSubscribersLocked("uss-modelCredits", accountID)
|
||||
m.mu.Unlock()
|
||||
|
||||
m.publishTopicUpdate(subscribers, buildModelCreditsAppliedUpdates(©Info)...)
|
||||
}
|
||||
|
||||
// SetTrajectoryCallback registers a callback for when the LS pushes trajectory data.
|
||||
func (m *MockExtensionServer) SetTrajectoryCallback(fn func(topic, key string, data []byte)) {
|
||||
m.onTrajectoryUpdate = fn
|
||||
}
|
||||
|
||||
func (m *MockExtensionServer) currentTokenLocked() *TokenInfo {
|
||||
if m.lastAccountID != "" {
|
||||
if info := m.tokens[m.lastAccountID]; info != nil {
|
||||
return info
|
||||
}
|
||||
}
|
||||
for _, info := range m.tokens {
|
||||
return info
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockExtensionServer) currentModelCreditsLocked() *ModelCreditsInfo {
|
||||
if m.lastAccountID != "" {
|
||||
if info := m.credits[m.lastAccountID]; info != nil {
|
||||
return info
|
||||
}
|
||||
}
|
||||
for _, info := range m.credits {
|
||||
return info
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockExtensionServer) tokenForAccountLocked(accountID string) *TokenInfo {
|
||||
if accountID != "" {
|
||||
if info := m.tokens[accountID]; info != nil {
|
||||
return info
|
||||
}
|
||||
}
|
||||
return m.currentTokenLocked()
|
||||
}
|
||||
|
||||
func (m *MockExtensionServer) creditsForAccountLocked(accountID string) *ModelCreditsInfo {
|
||||
if accountID != "" {
|
||||
if info := m.credits[accountID]; info != nil {
|
||||
return info
|
||||
}
|
||||
}
|
||||
return m.currentModelCreditsLocked()
|
||||
}
|
||||
|
||||
func (m *MockExtensionServer) snapshotSubscribersLocked(topic, accountID string) []*stateSubscriber {
|
||||
topicSubs := m.subscribers[topic]
|
||||
if len(topicSubs) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]*stateSubscriber, 0, len(topicSubs))
|
||||
for _, sub := range topicSubs {
|
||||
if sub == nil {
|
||||
continue
|
||||
}
|
||||
if accountID != "" && sub.accountID != "" && sub.accountID != accountID {
|
||||
continue
|
||||
}
|
||||
out = append(out, sub)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (m *MockExtensionServer) publishTopicUpdate(subscribers []*stateSubscriber, updates ...[]byte) {
|
||||
for _, sub := range subscribers {
|
||||
if sub == nil {
|
||||
continue
|
||||
}
|
||||
for _, update := range updates {
|
||||
if len(update) == 0 {
|
||||
continue
|
||||
}
|
||||
payload := append([]byte(nil), update...)
|
||||
select {
|
||||
case sub.updates <- payload:
|
||||
default:
|
||||
m.logger.Warn("dropping USS update", "topic", sub.topic, "account", sub.accountID)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func buildModelCreditsAppliedUpdates(info *ModelCreditsInfo) [][]byte {
|
||||
if info == nil {
|
||||
info = &ModelCreditsInfo{}
|
||||
}
|
||||
minimum := defaultMinimumCreditAmountForUsage
|
||||
if info.MinimumCreditAmountForUsage != nil {
|
||||
minimum = *info.MinimumCreditAmountForUsage
|
||||
}
|
||||
|
||||
updates := make([][]byte, 0, 3)
|
||||
updates = append(updates, buildAppliedUpdate(
|
||||
useAICreditsSentinelKey,
|
||||
buildUSSRowBinary(string(buildPrimitiveBoolBinary(info.UseAICredits))),
|
||||
))
|
||||
|
||||
if info.UseAICredits {
|
||||
credits := int32(9999)
|
||||
if info.AvailableCredits != nil {
|
||||
credits = *info.AvailableCredits
|
||||
}
|
||||
updates = append(updates, buildAppliedUpdate(
|
||||
availableCreditsSentinelKey,
|
||||
buildUSSRowBinary(string(buildPrimitiveInt32Binary(credits))),
|
||||
))
|
||||
} else {
|
||||
updates = append(updates, buildAppliedUpdate(availableCreditsSentinelKey, nil))
|
||||
}
|
||||
updates = append(updates, buildAppliedUpdate(
|
||||
minimumCreditAmountForUsageKey,
|
||||
buildUSSRowBinary(string(buildPrimitiveInt32Binary(minimum))),
|
||||
))
|
||||
|
||||
return updates
|
||||
}
|
||||
|
||||
// Close shuts down the server.
|
||||
func (m *MockExtensionServer) Close() error {
|
||||
return m.server.Close()
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Middleware
|
||||
// ============================================================
|
||||
|
||||
type unaryHandler func(body []byte) []byte
|
||||
type streamHandler func(body []byte, w http.ResponseWriter, r *http.Request)
|
||||
|
||||
// handleUnary wraps a unary RPC handler with CSRF check and proper content-type.
|
||||
func (m *MockExtensionServer) handleUnary(handler unaryHandler) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// CSRF check
|
||||
if !m.checkCSRF(w, r) {
|
||||
return
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
m.logger.Error("read body", "err", err, "path", r.URL.Path)
|
||||
w.Header().Set("Content-Type", "application/proto")
|
||||
w.WriteHeader(200)
|
||||
return
|
||||
}
|
||||
|
||||
// The LS might send with envelope framing (application/connect+proto)
|
||||
// or without (application/proto). Detect and unwrap.
|
||||
ct := r.Header.Get("Content-Type")
|
||||
protoBody := body
|
||||
if strings.Contains(ct, "connect+proto") && len(body) >= 5 {
|
||||
protoBody = unwrapConnectEnvelope(body)
|
||||
}
|
||||
|
||||
m.logger.Debug("unary RPC", "path", r.URL.Path, "body_len", len(protoBody), "content_type", ct)
|
||||
|
||||
responseProto := handler(protoBody)
|
||||
|
||||
// Respond with proper unary ConnectRPC content-type.
|
||||
// If the request used "connect+proto", the response should be "application/proto"
|
||||
// for unary RPCs (ConnectRPC spec: unary uses application/proto, not connect+proto).
|
||||
w.Header().Set("Content-Type", "application/proto")
|
||||
w.WriteHeader(200)
|
||||
if len(responseProto) > 0 {
|
||||
w.Write(responseProto)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleStream wraps a server-streaming RPC handler with CSRF and content-type.
|
||||
func (m *MockExtensionServer) handleStream(handler streamHandler) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if !m.checkCSRF(w, r) {
|
||||
return
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
m.logger.Error("read body", "err", err, "path", r.URL.Path)
|
||||
return
|
||||
}
|
||||
|
||||
// Unwrap envelope from request
|
||||
ct := r.Header.Get("Content-Type")
|
||||
if strings.Contains(ct, "connect+proto") || strings.Contains(ct, "connect+json") {
|
||||
body = unwrapConnectEnvelope(body)
|
||||
}
|
||||
|
||||
m.logger.Debug("stream RPC", "path", r.URL.Path, "body_len", len(body))
|
||||
|
||||
// Set streaming response content-type
|
||||
w.Header().Set("Content-Type", "application/connect+proto")
|
||||
w.WriteHeader(200)
|
||||
|
||||
handler(body, w, r)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockExtensionServer) checkCSRF(w http.ResponseWriter, r *http.Request) bool {
|
||||
token := r.Header.Get("x-codeium-csrf-token")
|
||||
if m.csrf != "" && token != m.csrf {
|
||||
m.logger.Warn("CSRF mismatch", "path", r.URL.Path, "got", token[:min(8, len(token))])
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(403)
|
||||
w.Write([]byte("Invalid CSRF token"))
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func min(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Unary RPC Handlers — each receives raw proto request body,
|
||||
// returns raw proto response body.
|
||||
// ============================================================
|
||||
|
||||
func (m *MockExtensionServer) onLanguageServerStarted(body []byte) []byte {
|
||||
// LanguageServerStartedRequest has: https_port(1), http_port(2), lsp_port(3), csrf_token(4)
|
||||
// We just log the ports — they're informational.
|
||||
m.logger.Info("LanguageServerStarted",
|
||||
"body_len", len(body))
|
||||
// Return empty LanguageServerStartedResponse
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockExtensionServer) onHeartbeat(body []byte) []byte {
|
||||
// Return empty HeartbeatResponse
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockExtensionServer) onGetSecretValue(body []byte) []byte {
|
||||
// GetSecretValueRequest: key = field 1
|
||||
key := decodeProtoString(body, 1)
|
||||
m.logger.Debug("GetSecretValue", "key", key)
|
||||
|
||||
m.mu.RLock()
|
||||
var token string
|
||||
if info := m.currentTokenLocked(); info != nil {
|
||||
token = info.AccessToken
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
|
||||
// GetSecretValueResponse: value = field 1
|
||||
if token != "" {
|
||||
return encodeProtoString(1, token)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockExtensionServer) onStoreSecretValue(body []byte) []byte {
|
||||
key := decodeProtoString(body, 1)
|
||||
m.logger.Debug("StoreSecretValue", "key", key)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockExtensionServer) onIsAgentManagerEnabled(body []byte) []byte {
|
||||
// IsAgentManagerEnabledResponse: enabled = field 1 (bool)
|
||||
return encodeProtoBool(1, false)
|
||||
}
|
||||
|
||||
func (m *MockExtensionServer) onPushUnifiedStateSyncUpdate(body []byte) []byte {
|
||||
// PushUnifiedStateSyncUpdateRequest: update = field 1 (UpdateRequest message)
|
||||
// UpdateRequest: topic_name = field 1, applied_update = field 5, key = field 2
|
||||
m.logger.Debug("PushUnifiedStateSyncUpdate", "body_len", len(body))
|
||||
|
||||
// Extract topic name from the embedded UpdateRequest
|
||||
// The body is PushUnifiedStateSyncUpdateRequest, field 1 is the UpdateRequest
|
||||
// We need to dig into the nested message to get topic_name
|
||||
if m.onTrajectoryUpdate != nil {
|
||||
// For now, just notify that an update was pushed
|
||||
m.onTrajectoryUpdate("", "", body)
|
||||
}
|
||||
|
||||
// Return empty PushUnifiedStateSyncUpdateResponse
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockExtensionServer) onRecordError(body []byte) []byte {
|
||||
m.logger.Debug("RecordError", "body_len", len(body))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockExtensionServer) onLogEvent(body []byte) []byte {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockExtensionServer) onUpdateTrajectorySummaries(body []byte) []byte {
|
||||
m.logger.Debug("UpdateCascadeTrajectorySummaries", "body_len", len(body))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockExtensionServer) onDefault(body []byte) []byte {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Streaming RPC Handlers
|
||||
// ============================================================
|
||||
|
||||
func (m *MockExtensionServer) onSubscribeStateSyncTopic(body []byte, w http.ResponseWriter, r *http.Request) {
|
||||
// SubscribeToUnifiedStateSyncTopicRequest: topic = field 1
|
||||
topic := decodeProtoString(body, 1)
|
||||
m.logger.Info("SubscribeToUnifiedStateSyncTopic", "topic", topic)
|
||||
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
m.logger.Error("ResponseWriter does not support Flush")
|
||||
return
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
accountID := m.lastAccountID
|
||||
subID := m.nextSubID
|
||||
m.nextSubID++
|
||||
sub := &stateSubscriber{
|
||||
id: subID,
|
||||
accountID: accountID,
|
||||
topic: topic,
|
||||
updates: make(chan []byte, 16),
|
||||
}
|
||||
if m.subscribers[topic] == nil {
|
||||
m.subscribers[topic] = make(map[int]*stateSubscriber)
|
||||
}
|
||||
m.subscribers[topic][subID] = sub
|
||||
|
||||
// Build initial state based on topic
|
||||
var topicData []byte
|
||||
switch topic {
|
||||
case "uss-oauth":
|
||||
tokenInfo := m.tokenForAccountLocked(accountID)
|
||||
if tokenInfo != nil {
|
||||
topicData = buildUSSTopicWithOAuth(tokenInfo.AccessToken, tokenInfo.RefreshToken, tokenInfo.ExpiresAt)
|
||||
} else {
|
||||
topicData = buildEmptyTopic()
|
||||
}
|
||||
case "uss-modelCredits":
|
||||
creditsInfo := m.creditsForAccountLocked(accountID)
|
||||
if creditsInfo != nil {
|
||||
topicData = buildUSSTopicWithModelCredits(creditsInfo)
|
||||
} else {
|
||||
topicData = buildEmptyTopic()
|
||||
}
|
||||
default:
|
||||
// For all other topics (browserPreferences, enterprisePreferences, etc.),
|
||||
// return empty topic data.
|
||||
topicData = buildEmptyTopic()
|
||||
}
|
||||
m.mu.Unlock()
|
||||
defer func() {
|
||||
m.mu.Lock()
|
||||
if topicSubs := m.subscribers[topic]; topicSubs != nil {
|
||||
delete(topicSubs, subID)
|
||||
if len(topicSubs) == 0 {
|
||||
delete(m.subscribers, topic)
|
||||
}
|
||||
}
|
||||
m.mu.Unlock()
|
||||
}()
|
||||
|
||||
// Send initial state as envelope-framed message
|
||||
initialUpdate := buildInitialStateUpdate(topicData)
|
||||
frame := connectEnvelope(0x00, initialUpdate)
|
||||
w.Write(frame)
|
||||
flusher.Flush()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-r.Context().Done():
|
||||
m.logger.Debug("SubscribeToUnifiedStateSyncTopic stream closed", "topic", topic)
|
||||
return
|
||||
case update := <-sub.updates:
|
||||
if len(update) == 0 {
|
||||
continue
|
||||
}
|
||||
if _, err := w.Write(connectEnvelope(0x00, update)); err != nil {
|
||||
m.logger.Debug("SubscribeToUnifiedStateSyncTopic write failed", "topic", topic, "err", err)
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockExtensionServer) onExecuteCommand(body []byte, w http.ResponseWriter, r *http.Request) {
|
||||
m.logger.Debug("ExecuteCommand (mock)", "body_len", len(body))
|
||||
// Send end-of-stream immediately — we don't execute commands
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
w.Write(connectEndOfStream())
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Catch-all handler
|
||||
// ============================================================
|
||||
|
||||
func (m *MockExtensionServer) handleCatchAll(w http.ResponseWriter, r *http.Request) {
|
||||
if !m.checkCSRF(w, r) {
|
||||
return
|
||||
}
|
||||
m.logger.Debug("unhandled RPC (returning empty proto)", "path", r.URL.Path, "method", r.Method)
|
||||
|
||||
// Drain request body
|
||||
io.ReadAll(r.Body)
|
||||
|
||||
// Determine if this is likely a unary or streaming request based on content-type.
|
||||
ct := r.Header.Get("Content-Type")
|
||||
if strings.Contains(ct, "connect+") {
|
||||
// Could be streaming — respond with unary proto to be safe
|
||||
// (unary Connect requests can also use connect+ prefix in some client impls)
|
||||
w.Header().Set("Content-Type", "application/proto")
|
||||
} else {
|
||||
w.Header().Set("Content-Type", "application/proto")
|
||||
}
|
||||
w.WriteHeader(200)
|
||||
}
|
||||
1132
backend/internal/pkg/lspool/pool.go
Normal file
1132
backend/internal/pkg/lspool/pool.go
Normal file
File diff suppressed because it is too large
Load Diff
346
backend/internal/pkg/lspool/pool_test.go
Normal file
346
backend/internal/pkg/lspool/pool_test.go
Normal file
@ -0,0 +1,346 @@
|
||||
package lspool
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBuildLSEnvKeepsExistingSSLValues(t *testing.T) {
|
||||
env := buildLSEnv([]string{
|
||||
"SSL_CERT_FILE=/custom/ca.pem",
|
||||
"SSL_CERT_DIR=/custom/certs",
|
||||
}, "/opt/antigravity", "")
|
||||
require.Contains(t, env, "ANTIGRAVITY_EDITOR_APP_ROOT=/opt/antigravity")
|
||||
require.Contains(t, env, "SSL_CERT_FILE=/custom/ca.pem")
|
||||
require.Contains(t, env, "SSL_CERT_DIR=/custom/certs")
|
||||
}
|
||||
|
||||
func TestBuildLSEnvClearsInheritedProxyWhenUnset(t *testing.T) {
|
||||
env := buildLSEnv([]string{
|
||||
"HTTPS_PROXY=http://old-proxy:8080",
|
||||
"HTTP_PROXY=http://old-proxy:8080",
|
||||
"ALL_PROXY=socks5://old-proxy:1080",
|
||||
"https_proxy=http://old-proxy:8080",
|
||||
"http_proxy=http://old-proxy:8080",
|
||||
"all_proxy=socks5://old-proxy:1080",
|
||||
}, "/opt/antigravity", "")
|
||||
|
||||
require.Contains(t, env, "HTTPS_PROXY=")
|
||||
require.Contains(t, env, "HTTP_PROXY=")
|
||||
require.Contains(t, env, "ALL_PROXY=")
|
||||
require.Contains(t, env, "https_proxy=")
|
||||
require.Contains(t, env, "http_proxy=")
|
||||
require.Contains(t, env, "all_proxy=")
|
||||
}
|
||||
|
||||
func TestShortAccountID(t *testing.T) {
|
||||
require.Equal(t, "9", shortAccountID("9"))
|
||||
require.Equal(t, "12345678", shortAccountID("12345678"))
|
||||
require.Equal(t, "12345678", shortAccountID("123456789"))
|
||||
}
|
||||
|
||||
func TestFrameConnectMessage(t *testing.T) {
|
||||
framed := frameConnectMessage([]byte(`{"x":1}`))
|
||||
require.Len(t, framed, 5+len(`{"x":1}`))
|
||||
require.Equal(t, byte(0), framed[0])
|
||||
require.Equal(t, uint32(len(`{"x":1}`)), binary.BigEndian.Uint32(framed[1:5]))
|
||||
require.Equal(t, `{"x":1}`, string(framed[5:]))
|
||||
}
|
||||
|
||||
func TestConnectEnvelope(t *testing.T) {
|
||||
payload := []byte("hello")
|
||||
env := connectEnvelope(0x00, payload)
|
||||
require.Len(t, env, 5+len(payload))
|
||||
require.Equal(t, byte(0x00), env[0])
|
||||
require.Equal(t, uint32(5), binary.BigEndian.Uint32(env[1:5]))
|
||||
require.Equal(t, "hello", string(env[5:]))
|
||||
}
|
||||
|
||||
func TestUnwrapConnectEnvelope(t *testing.T) {
|
||||
payload := []byte("test data")
|
||||
env := connectEnvelope(0x00, payload)
|
||||
unwrapped := unwrapConnectEnvelope(env)
|
||||
require.Equal(t, payload, unwrapped)
|
||||
short := []byte{1, 2}
|
||||
require.Equal(t, short, unwrapConnectEnvelope(short))
|
||||
}
|
||||
|
||||
func TestExtractPromptAndModel(t *testing.T) {
|
||||
body := `{"model":"gemini-2.5-pro","project":"p","request":{"contents":[{"role":"user","parts":[{"text":"hello world"}]}]}}`
|
||||
prompt, model := extractPromptAndModel([]byte(body))
|
||||
require.Equal(t, "hello world", prompt)
|
||||
require.Equal(t, "gemini-2.5-pro", model)
|
||||
|
||||
body2 := `{"contents":[{"role":"user","parts":[{"text":"test prompt"}]}]}`
|
||||
prompt2, _ := extractPromptAndModel([]byte(body2))
|
||||
require.Equal(t, "test prompt", prompt2)
|
||||
}
|
||||
|
||||
func TestResolveModelEnum(t *testing.T) {
|
||||
// Without dynamic mapping loaded, should return fallback (312 = gemini-2.5-flash)
|
||||
require.True(t, resolveModelEnum("gemini-2.5-flash") > 0)
|
||||
require.True(t, resolveModelEnum("models/gemini-2.5-flash") > 0)
|
||||
require.True(t, resolveModelEnum("claude-sonnet-4-6") > 0)
|
||||
require.True(t, resolveModelEnum("unknown-model") > 0)
|
||||
}
|
||||
|
||||
func TestBuildCascadeConfigIncludesRequestedModel(t *testing.T) {
|
||||
cfg := buildCascadeConfig("models/gemini-2.5-flash")
|
||||
require.NotNil(t, cfg)
|
||||
|
||||
plannerConfig, ok := cfg["plannerConfig"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
requestedModel, ok := plannerConfig["requestedModel"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.NotEmpty(t, requestedModel["model"])
|
||||
require.Len(t, plannerConfig, 1)
|
||||
}
|
||||
|
||||
func TestBuildCascadeConfigClaudeIncludesRequestedModel(t *testing.T) {
|
||||
cfg := buildCascadeConfig("claude-sonnet-4-6")
|
||||
require.NotNil(t, cfg)
|
||||
|
||||
plannerConfig, ok := cfg["plannerConfig"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
requestedModel, ok := plannerConfig["requestedModel"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.NotEmpty(t, requestedModel["model"])
|
||||
require.Len(t, plannerConfig, 1)
|
||||
}
|
||||
|
||||
func TestDoNonStreamGeneratePassesThrough(t *testing.T) {
|
||||
fallback := &recordingUpstream{}
|
||||
upstream := NewLSPoolUpstream(&Pool{}, fallback)
|
||||
req, _ := http.NewRequest("POST", "https://example.com/v1beta/models/gemini:generateContent", bytes.NewReader([]byte(`{}`)))
|
||||
resp, err := upstream.Do(req, "", 1, 1)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, 1, fallback.doCalls)
|
||||
}
|
||||
|
||||
func TestExtractPlannerResponseText(t *testing.T) {
|
||||
resp := `{"status":"CASCADE_RUN_STATUS_IDLE","trajectory":{"steps":[
|
||||
{"type":"CORTEX_STEP_TYPE_USER_INPUT","status":"CORTEX_STEP_STATUS_DONE"},
|
||||
{"type":"CORTEX_STEP_TYPE_PLANNER_RESPONSE","status":"CORTEX_STEP_STATUS_DONE",
|
||||
"plannerResponse":{"response":"Hello world"}}
|
||||
]}}`
|
||||
text, generating, status := extractPlannerResponseText([]byte(resp))
|
||||
require.Equal(t, "Hello world", text)
|
||||
require.False(t, generating)
|
||||
require.Equal(t, "CASCADE_RUN_STATUS_IDLE", status)
|
||||
}
|
||||
|
||||
func TestExtractPlannerResponseState_ErrorDetails(t *testing.T) {
|
||||
resp := `{
|
||||
"status":"CASCADE_RUN_STATUS_IDLE",
|
||||
"trajectory":{
|
||||
"steps":[
|
||||
{"type":"CORTEX_STEP_TYPE_USER_INPUT","status":"CORTEX_STEP_STATUS_DONE"}
|
||||
],
|
||||
"executorMetadata":{
|
||||
"terminationReason":"ERROR",
|
||||
"errorDetails":{
|
||||
"errorCode":429,
|
||||
"shortError":"Model quota reached",
|
||||
"details":"You have exhausted your capacity on this model. Your quota will reset after 1h59m40s."
|
||||
}
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
state := extractPlannerResponseState([]byte(resp))
|
||||
require.Equal(t, "CASCADE_RUN_STATUS_IDLE", state.Status)
|
||||
require.False(t, state.Generating)
|
||||
require.Empty(t, state.Text)
|
||||
require.Contains(t, state.ErrorMessage, "Model quota reached")
|
||||
require.Contains(t, state.ErrorMessage, "quota will reset after")
|
||||
}
|
||||
|
||||
func TestBuildGeminiSSEChunk(t *testing.T) {
|
||||
sse := buildGeminiSSEChunk("hello")
|
||||
require.Contains(t, sse, "data: ")
|
||||
require.Contains(t, sse, `"text":"hello"`)
|
||||
require.Contains(t, sse, `"role":"model"`)
|
||||
require.True(t, strings.HasSuffix(sse, "\n\n"))
|
||||
}
|
||||
|
||||
func TestRequestHasTools(t *testing.T) {
|
||||
// Wrapped format with tools
|
||||
require.True(t, requestHasTools([]byte(`{"model":"m","project":"p","request":{"contents":[],"tools":[{"functionDeclarations":[{"name":"get_weather"}]}]}}`)))
|
||||
|
||||
// Direct format with tools
|
||||
require.True(t, requestHasTools([]byte(`{"contents":[],"tools":[{"functionDeclarations":[{"name":"f"}]}]}`)))
|
||||
|
||||
// No tools
|
||||
require.False(t, requestHasTools([]byte(`{"model":"m","project":"p","request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`)))
|
||||
|
||||
// Empty tools array
|
||||
require.False(t, requestHasTools([]byte(`{"contents":[],"tools":[]}`)))
|
||||
}
|
||||
|
||||
func TestCurrentLSStrategy(t *testing.T) {
|
||||
t.Setenv("ANTIGRAVITY_LS_STRATEGY", "js-parity")
|
||||
require.Equal(t, LSStrategyJSParity, CurrentLSStrategy())
|
||||
|
||||
t.Setenv("ANTIGRAVITY_LS_STRATEGY", "unknown")
|
||||
require.Equal(t, LSStrategyDirect, CurrentLSStrategy())
|
||||
}
|
||||
|
||||
func TestParseLSReplicaCountDefaultAndEnv(t *testing.T) {
|
||||
t.Setenv("ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT", "")
|
||||
require.Equal(t, 5, parseLSReplicaCount())
|
||||
|
||||
t.Setenv("ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT", "3")
|
||||
require.Equal(t, 3, parseLSReplicaCount())
|
||||
|
||||
t.Setenv("ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT", "0")
|
||||
require.Equal(t, 5, parseLSReplicaCount())
|
||||
}
|
||||
|
||||
func TestPoolGetUsesStickyReplicaSlot(t *testing.T) {
|
||||
pool := &Pool{
|
||||
config: Config{ReplicasPerAccount: 5},
|
||||
instances: map[string][]*Instance{
|
||||
"acc-1": {
|
||||
{AccountID: "acc-1", Replica: 0, healthy: true},
|
||||
{AccountID: "acc-1", Replica: 1, healthy: true},
|
||||
{AccountID: "acc-1", Replica: 2, healthy: true},
|
||||
{AccountID: "acc-1", Replica: 3, healthy: true},
|
||||
{AccountID: "acc-1", Replica: 4, healthy: true},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
routingKey := "acc-1:user-a:session-1"
|
||||
slot := replicaSlotIndex(routingKey, pool.replicaCount())
|
||||
inst := pool.Get("acc-1", routingKey)
|
||||
require.NotNil(t, inst)
|
||||
require.Equal(t, slot, inst.Replica)
|
||||
}
|
||||
|
||||
func TestPoolGetWithoutRoutingKeyPrefersLeastBusyReplica(t *testing.T) {
|
||||
busy := &Instance{AccountID: "acc-1", Replica: 0, healthy: true}
|
||||
atomic.StoreInt64(&busy.inflight, 4)
|
||||
idle := &Instance{AccountID: "acc-1", Replica: 1, healthy: true}
|
||||
atomic.StoreInt64(&idle.inflight, 1)
|
||||
|
||||
pool := &Pool{
|
||||
config: Config{ReplicasPerAccount: 5},
|
||||
instances: map[string][]*Instance{
|
||||
"acc-1": {busy, idle},
|
||||
},
|
||||
}
|
||||
|
||||
inst := pool.Get("acc-1", "")
|
||||
require.NotNil(t, inst)
|
||||
require.Equal(t, 1, inst.Replica)
|
||||
}
|
||||
|
||||
func TestWaitForInstanceReadyProbesImmediately(t *testing.T) {
|
||||
startedAt := time.Now()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
attempts, err := waitForInstanceReady(ctx, 200*time.Millisecond, func(context.Context) error {
|
||||
return nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, attempts)
|
||||
require.Less(t, time.Since(startedAt), 100*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestWaitForInstanceReadyRetriesUntilSuccess(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
calls := 0
|
||||
attempts, err := waitForInstanceReady(ctx, 10*time.Millisecond, func(context.Context) error {
|
||||
calls++
|
||||
if calls < 3 {
|
||||
return errors.New("not ready")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 3, attempts)
|
||||
require.Equal(t, 3, calls)
|
||||
}
|
||||
|
||||
func TestDecideJSParityRoute(t *testing.T) {
|
||||
body := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"s1","contents":[{"role":"user","parts":[{"text":"hello"}]}]}}`)
|
||||
parsed, err := parseGeminiRequest(body)
|
||||
require.NoError(t, err)
|
||||
decision := decideJSParityRoute(parsed, body)
|
||||
require.True(t, decision.UseLS)
|
||||
|
||||
imageBody := []byte(`{"model":"gemini-2.5-flash-image","request":{"sessionId":"s1","contents":[{"role":"user","parts":[{"text":"draw"}]}],"generationConfig":{"responseModalities":["TEXT","IMAGE"]}}}`)
|
||||
parsedImage, err := parseGeminiRequest(imageBody)
|
||||
require.NoError(t, err)
|
||||
decisionImage := decideJSParityRoute(parsedImage, imageBody)
|
||||
require.False(t, decisionImage.UseLS)
|
||||
require.Contains(t, strings.ToLower(decisionImage.Reason), "image")
|
||||
|
||||
noSessionBody := []byte(`{"model":"gemini-2.5-flash","request":{"contents":[{"role":"user","parts":[{"text":"hello"}]}]}}`)
|
||||
parsedNoSession, err := parseGeminiRequest(noSessionBody)
|
||||
require.NoError(t, err)
|
||||
require.False(t, decideJSParityRoute(parsedNoSession, noSessionBody).UseLS)
|
||||
}
|
||||
|
||||
func TestUserNamespacePrefersExplicitHeader(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set(userNamespaceHeader, "tenant-a")
|
||||
req.Header.Set("Authorization", "Bearer oauth-token")
|
||||
|
||||
nsWithExplicit := userNamespace(req)
|
||||
require.NotEqual(t, "anon", nsWithExplicit)
|
||||
|
||||
req.Header.Del(userNamespaceHeader)
|
||||
nsWithAuth := userNamespace(req)
|
||||
require.NotEqual(t, "anon", nsWithAuth)
|
||||
require.NotEqual(t, nsWithExplicit, nsWithAuth)
|
||||
}
|
||||
|
||||
func TestConversationPrefixEqual(t *testing.T) {
|
||||
prefix := []geminiConversationTurn{
|
||||
{Role: "user", Parts: []geminiConversationPart{{Kind: "text", Text: "hello"}}},
|
||||
{Role: "model", Parts: []geminiConversationPart{{Kind: "text", Text: "world"}}},
|
||||
}
|
||||
full := append(cloneConversationTurns(prefix), geminiConversationTurn{
|
||||
Role: "user",
|
||||
Parts: []geminiConversationPart{{Kind: "text", Text: "follow up"}},
|
||||
})
|
||||
require.True(t, conversationPrefixEqual(full, prefix))
|
||||
require.False(t, conversationPrefixEqual(prefix, full))
|
||||
}
|
||||
|
||||
func TestSystemTextCompatible(t *testing.T) {
|
||||
require.True(t, systemTextCompatible("You are helpful", ""))
|
||||
require.True(t, systemTextCompatible("You are helpful", "You are helpful"))
|
||||
require.False(t, systemTextCompatible("", "You are helpful"))
|
||||
require.False(t, systemTextCompatible("You are helpful", "You are different"))
|
||||
}
|
||||
|
||||
type recordingUpstream struct {
|
||||
doCalls int
|
||||
}
|
||||
|
||||
func (r *recordingUpstream) Do(req *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
|
||||
r.doCalls++
|
||||
return &http.Response{StatusCode: 200, Body: io.NopCloser(bytes.NewBufferString("ok")), Header: make(http.Header), Request: req}, nil
|
||||
}
|
||||
|
||||
func (r *recordingUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, c int, _ *tlsfingerprint.Profile) (*http.Response, error) {
|
||||
return r.Do(req, proxyURL, accountID, c)
|
||||
}
|
||||
268
backend/internal/pkg/lspool/proxy_bridge.go
Normal file
268
backend/internal/pkg/lspool/proxy_bridge.go
Normal file
@ -0,0 +1,268 @@
|
||||
package lspool
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
type lsProxyBridge struct {
|
||||
listener net.Listener
|
||||
server *http.Server
|
||||
url string
|
||||
upstream string
|
||||
}
|
||||
|
||||
type lsProxyBridgeManager struct {
|
||||
mu sync.Mutex
|
||||
bridges map[string]*lsProxyBridge
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
var globalLSProxyBridgeManager = &lsProxyBridgeManager{
|
||||
bridges: make(map[string]*lsProxyBridge),
|
||||
logger: slog.Default().With("component", "lspool-proxy-bridge"),
|
||||
}
|
||||
|
||||
var (
|
||||
lsProxyBridgeDialTimeout = 10 * time.Second
|
||||
lsProxyBridgeProbeTargets = []string{
|
||||
"cloudcode-pa.googleapis.com:443",
|
||||
"oauthaccountmanager.googleapis.com:443",
|
||||
}
|
||||
)
|
||||
|
||||
func prepareLSProxyURL(raw string) (string, error) {
|
||||
normalized, parsed, err := proxyurl.Parse(raw)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if parsed == nil {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
switch strings.ToLower(parsed.Scheme) {
|
||||
case "http", "https":
|
||||
return normalized, nil
|
||||
case "socks5", "socks5h":
|
||||
return globalLSProxyBridgeManager.ensure(normalized, parsed)
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported LS proxy scheme: %s", parsed.Scheme)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *lsProxyBridgeManager) ensure(key string, upstream *url.URL) (string, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if bridge := m.bridges[key]; bridge != nil {
|
||||
return bridge.url, nil
|
||||
}
|
||||
|
||||
bridge, err := newLSProxyBridge(upstream, m.logger)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
m.bridges[key] = bridge
|
||||
return bridge.url, nil
|
||||
}
|
||||
|
||||
func (m *lsProxyBridgeManager) closeAll() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
for key, bridge := range m.bridges {
|
||||
if bridge != nil {
|
||||
_ = bridge.server.Close()
|
||||
_ = bridge.listener.Close()
|
||||
}
|
||||
delete(m.bridges, key)
|
||||
}
|
||||
}
|
||||
|
||||
func closeAllLSProxyBridgesForTest() {
|
||||
globalLSProxyBridgeManager.closeAll()
|
||||
}
|
||||
|
||||
func newLSProxyBridge(upstream *url.URL, logger *slog.Logger) (*lsProxyBridge, error) {
|
||||
dialer, err := proxy.FromURL(upstream, proxy.Direct)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create SOCKS dialer: %w", err)
|
||||
}
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("listen LS proxy bridge: %w", err)
|
||||
}
|
||||
|
||||
bridge := &lsProxyBridge{
|
||||
listener: listener,
|
||||
url: "http://" + listener.Addr().String(),
|
||||
upstream: upstream.Redacted(),
|
||||
}
|
||||
|
||||
server := &http.Server{
|
||||
Handler: http.HandlerFunc(bridge.connectHandler(dialer, logger)),
|
||||
ReadHeaderTimeout: 10 * time.Second,
|
||||
IdleTimeout: 2 * time.Minute,
|
||||
}
|
||||
bridge.server = server
|
||||
|
||||
go func() {
|
||||
if err := server.Serve(listener); err != nil && err != http.ErrServerClosed {
|
||||
logger.Error("LS proxy bridge serve failed", "upstream", bridge.upstream, "err", err)
|
||||
}
|
||||
}()
|
||||
|
||||
logger.Info("LS proxy bridge started", "upstream", bridge.upstream, "listen", bridge.url)
|
||||
go bridge.probeConnectivity(dialer, logger)
|
||||
return bridge, nil
|
||||
}
|
||||
|
||||
func (b *lsProxyBridge) connectHandler(dialer proxy.Dialer, logger *slog.Logger) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodConnect {
|
||||
http.Error(w, "CONNECT only", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
targetAddr := strings.TrimSpace(r.Host)
|
||||
if targetAddr == "" {
|
||||
targetAddr = strings.TrimSpace(r.URL.Host)
|
||||
}
|
||||
if targetAddr == "" {
|
||||
http.Error(w, "missing target host", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if _, _, err := net.SplitHostPort(targetAddr); err != nil {
|
||||
targetAddr = net.JoinHostPort(targetAddr, "443")
|
||||
}
|
||||
|
||||
startedAt := time.Now()
|
||||
logger.Info("LS proxy bridge CONNECT", "upstream", b.upstream, "target", targetAddr)
|
||||
|
||||
dialCtx, cancel := context.WithTimeout(r.Context(), lsProxyBridgeDialTimeout)
|
||||
defer cancel()
|
||||
|
||||
targetConn, err := dialViaProxy(dialCtx, dialer, targetAddr)
|
||||
if err != nil {
|
||||
logger.Warn("LS proxy bridge dial failed",
|
||||
"upstream", b.upstream,
|
||||
"target", targetAddr,
|
||||
"elapsed", time.Since(startedAt).Truncate(time.Millisecond),
|
||||
"err", err)
|
||||
http.Error(w, "proxy dial failed", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
logger.Info("LS proxy bridge CONNECT established",
|
||||
"upstream", b.upstream,
|
||||
"target", targetAddr,
|
||||
"elapsed", time.Since(startedAt).Truncate(time.Millisecond))
|
||||
|
||||
hijacker, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
_ = targetConn.Close()
|
||||
http.Error(w, "hijack unsupported", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
clientConn, rw, err := hijacker.Hijack()
|
||||
if err != nil {
|
||||
_ = targetConn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := clientConn.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n")); err != nil {
|
||||
_ = targetConn.Close()
|
||||
_ = clientConn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
if rw != nil && rw.Reader.Buffered() > 0 {
|
||||
if _, err := io.CopyN(targetConn, rw, int64(rw.Reader.Buffered())); err != nil {
|
||||
_ = targetConn.Close()
|
||||
_ = clientConn.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
tunnelConns(clientConn, targetConn)
|
||||
}
|
||||
}
|
||||
|
||||
func dialViaProxy(ctx context.Context, dialer proxy.Dialer, targetAddr string) (net.Conn, error) {
|
||||
if contextDialer, ok := dialer.(proxy.ContextDialer); ok {
|
||||
return contextDialer.DialContext(ctx, "tcp", targetAddr)
|
||||
}
|
||||
|
||||
type dialResult struct {
|
||||
conn net.Conn
|
||||
err error
|
||||
}
|
||||
resultCh := make(chan dialResult, 1)
|
||||
go func() {
|
||||
conn, err := dialer.Dial("tcp", targetAddr)
|
||||
resultCh <- dialResult{conn: conn, err: err}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case result := <-resultCh:
|
||||
return result.conn, result.err
|
||||
}
|
||||
}
|
||||
|
||||
func (b *lsProxyBridge) probeConnectivity(dialer proxy.Dialer, logger *slog.Logger) {
|
||||
for _, targetAddr := range lsProxyBridgeProbeTargets {
|
||||
startedAt := time.Now()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), lsProxyBridgeDialTimeout)
|
||||
conn, err := dialViaProxy(ctx, dialer, targetAddr)
|
||||
cancel()
|
||||
if err != nil {
|
||||
logger.Warn("LS proxy bridge probe failed",
|
||||
"upstream", b.upstream,
|
||||
"target", targetAddr,
|
||||
"elapsed", time.Since(startedAt).Truncate(time.Millisecond),
|
||||
"err", err)
|
||||
continue
|
||||
}
|
||||
_ = conn.Close()
|
||||
logger.Info("LS proxy bridge probe succeeded",
|
||||
"upstream", b.upstream,
|
||||
"target", targetAddr,
|
||||
"elapsed", time.Since(startedAt).Truncate(time.Millisecond))
|
||||
}
|
||||
}
|
||||
|
||||
func tunnelConns(clientConn net.Conn, targetConn net.Conn) {
|
||||
var once sync.Once
|
||||
closeBoth := func() {
|
||||
_ = clientConn.Close()
|
||||
_ = targetConn.Close()
|
||||
}
|
||||
|
||||
go func() {
|
||||
_, _ = io.Copy(targetConn, clientConn)
|
||||
once.Do(closeBoth)
|
||||
}()
|
||||
go func() {
|
||||
_, _ = io.Copy(clientConn, targetConn)
|
||||
once.Do(closeBoth)
|
||||
}()
|
||||
}
|
||||
|
||||
func readConnectResponse(br *bufio.Reader) (*http.Response, error) {
|
||||
return http.ReadResponse(br, &http.Request{Method: http.MethodConnect})
|
||||
}
|
||||
193
backend/internal/pkg/lspool/proxy_bridge_test.go
Normal file
193
backend/internal/pkg/lspool/proxy_bridge_test.go
Normal file
@ -0,0 +1,193 @@
|
||||
package lspool
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPrepareLSProxyURLPassesThroughHTTPProxy(t *testing.T) {
|
||||
t.Cleanup(closeAllLSProxyBridgesForTest)
|
||||
|
||||
got, err := prepareLSProxyURL("http://proxy.example.com:8080")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "http://proxy.example.com:8080", got)
|
||||
}
|
||||
|
||||
func TestPrepareLSProxyURLBridgesSOCKS5ForLS(t *testing.T) {
|
||||
t.Cleanup(closeAllLSProxyBridgesForTest)
|
||||
|
||||
targetAddr, closeTarget := startBridgeEchoServer(t)
|
||||
defer closeTarget()
|
||||
|
||||
socksURL, closeSOCKS := startBridgeSOCKS5Server(t)
|
||||
defer closeSOCKS()
|
||||
|
||||
bridgeURL, err := prepareLSProxyURL(socksURL)
|
||||
require.NoError(t, err)
|
||||
require.True(t, strings.HasPrefix(bridgeURL, "http://127.0.0.1:"))
|
||||
|
||||
// Same SOCKS upstream should reuse the same local bridge.
|
||||
reusedURL, err := prepareLSProxyURL(socksURL)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, bridgeURL, reusedURL)
|
||||
|
||||
bridgeAddr := strings.TrimPrefix(bridgeURL, "http://")
|
||||
conn, err := net.Dial("tcp", bridgeAddr)
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
_, err = fmt.Fprintf(conn, "CONNECT %s HTTP/1.1\r\nHost: %s\r\n\r\n", targetAddr, targetAddr)
|
||||
require.NoError(t, err)
|
||||
|
||||
reader := bufio.NewReader(conn)
|
||||
resp, err := readConnectResponse(reader)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 200, resp.StatusCode)
|
||||
|
||||
_, err = conn.Write([]byte("ping"))
|
||||
require.NoError(t, err)
|
||||
|
||||
reply := make([]byte, 4)
|
||||
_, err = io.ReadFull(reader, reply)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "pong", string(reply))
|
||||
}
|
||||
|
||||
func startBridgeEchoServer(t *testing.T) (string, func()) {
|
||||
t.Helper()
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
for {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go func(c net.Conn) {
|
||||
defer c.Close()
|
||||
buf := make([]byte, 4)
|
||||
if _, err := io.ReadFull(c, buf); err != nil {
|
||||
return
|
||||
}
|
||||
if string(buf) == "ping" {
|
||||
_, _ = c.Write([]byte("pong"))
|
||||
}
|
||||
}(conn)
|
||||
}
|
||||
}()
|
||||
|
||||
return ln.Addr().String(), func() {
|
||||
_ = ln.Close()
|
||||
<-done
|
||||
}
|
||||
}
|
||||
|
||||
func startBridgeSOCKS5Server(t *testing.T) (string, func()) {
|
||||
t.Helper()
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
for {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go handleBridgeSOCKS5Conn(conn)
|
||||
}
|
||||
}()
|
||||
|
||||
return "socks5://" + ln.Addr().String(), func() {
|
||||
_ = ln.Close()
|
||||
<-done
|
||||
}
|
||||
}
|
||||
|
||||
func handleBridgeSOCKS5Conn(conn net.Conn) {
|
||||
header := make([]byte, 2)
|
||||
if _, err := io.ReadFull(conn, header); err != nil {
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
methods := make([]byte, int(header[1]))
|
||||
if _, err := io.ReadFull(conn, methods); err != nil {
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
_, _ = conn.Write([]byte{0x05, 0x00})
|
||||
|
||||
reqHeader := make([]byte, 4)
|
||||
if _, err := io.ReadFull(conn, reqHeader); err != nil {
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
if reqHeader[0] != 0x05 || reqHeader[1] != 0x01 {
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
targetHost, ok := readSOCKS5Addr(conn, reqHeader[3])
|
||||
if !ok {
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
portBuf := make([]byte, 2)
|
||||
if _, err := io.ReadFull(conn, portBuf); err != nil {
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
targetAddr := fmt.Sprintf("%s:%d", targetHost, binary.BigEndian.Uint16(portBuf))
|
||||
|
||||
targetConn, err := net.Dial("tcp", targetAddr)
|
||||
if err != nil {
|
||||
_, _ = conn.Write([]byte{0x05, 0x01, 0x00, 0x01, 0, 0, 0, 0, 0, 0})
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
_, _ = conn.Write([]byte{0x05, 0x00, 0x00, 0x01, 0, 0, 0, 0, 0, 0})
|
||||
tunnelConns(conn, targetConn)
|
||||
}
|
||||
|
||||
func readSOCKS5Addr(conn net.Conn, atyp byte) (string, bool) {
|
||||
switch atyp {
|
||||
case 0x01:
|
||||
buf := make([]byte, 4)
|
||||
if _, err := io.ReadFull(conn, buf); err != nil {
|
||||
return "", false
|
||||
}
|
||||
return net.IP(buf).String(), true
|
||||
case 0x03:
|
||||
lenBuf := make([]byte, 1)
|
||||
if _, err := io.ReadFull(conn, lenBuf); err != nil {
|
||||
return "", false
|
||||
}
|
||||
buf := make([]byte, int(lenBuf[0]))
|
||||
if _, err := io.ReadFull(conn, buf); err != nil {
|
||||
return "", false
|
||||
}
|
||||
return string(buf), true
|
||||
case 0x04:
|
||||
buf := make([]byte, 16)
|
||||
if _, err := io.ReadFull(conn, buf); err != nil {
|
||||
return "", false
|
||||
}
|
||||
return net.IP(buf).String(), true
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
138
backend/internal/pkg/lspool/proxy_exec.go
Normal file
138
backend/internal/pkg/lspool/proxy_exec.go
Normal file
@ -0,0 +1,138 @@
|
||||
package lspool
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
|
||||
)
|
||||
|
||||
type lsLaunchPlan struct {
|
||||
cmd *exec.Cmd
|
||||
effectiveProxyURL string
|
||||
proxyMode string
|
||||
cleanup func()
|
||||
}
|
||||
|
||||
func prepareLSLaunchPlan(binPath string, args []string, rawProxyURL string) (*lsLaunchPlan, error) {
|
||||
normalized, parsed, err := proxyurl.Parse(rawProxyURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
plan := &lsLaunchPlan{
|
||||
cmd: exec.Command(binPath, args...),
|
||||
proxyMode: "direct",
|
||||
}
|
||||
|
||||
if parsed == nil {
|
||||
return plan, nil
|
||||
}
|
||||
|
||||
switch strings.ToLower(parsed.Scheme) {
|
||||
case "http", "https":
|
||||
plan.effectiveProxyURL = normalized
|
||||
plan.proxyMode = "env-http-proxy"
|
||||
return plan, nil
|
||||
|
||||
case "socks5", "socks5h":
|
||||
if proxychainsPath, err := exec.LookPath("proxychains4"); err == nil {
|
||||
cfgPath, err := writeProxychainsConfig(parsed)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
plan.cmd = exec.Command(proxychainsPath, append([]string{"-f", cfgPath, binPath}, args...)...)
|
||||
plan.proxyMode = "proxychains4"
|
||||
plan.cleanup = func() {
|
||||
_ = os.Remove(cfgPath)
|
||||
}
|
||||
return plan, nil
|
||||
}
|
||||
|
||||
effectiveProxyURL, err := prepareLSProxyURL(normalized)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
plan.effectiveProxyURL = effectiveProxyURL
|
||||
plan.proxyMode = "http-connect-bridge"
|
||||
return plan, nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported LS proxy scheme: %s", parsed.Scheme)
|
||||
}
|
||||
}
|
||||
|
||||
func writeProxychainsConfig(proxyURL *url.URL) (string, error) {
|
||||
content, err := buildProxychainsConfig(proxyURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
file, err := os.CreateTemp("", "sub2api-proxychains-*.conf")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create proxychains config: %w", err)
|
||||
}
|
||||
|
||||
if _, err := file.WriteString(content); err != nil {
|
||||
_ = file.Close()
|
||||
_ = os.Remove(file.Name())
|
||||
return "", fmt.Errorf("write proxychains config: %w", err)
|
||||
}
|
||||
if err := file.Close(); err != nil {
|
||||
_ = os.Remove(file.Name())
|
||||
return "", fmt.Errorf("close proxychains config: %w", err)
|
||||
}
|
||||
|
||||
return file.Name(), nil
|
||||
}
|
||||
|
||||
func buildProxychainsConfig(proxyURL *url.URL) (string, error) {
|
||||
if proxyURL == nil {
|
||||
return "", fmt.Errorf("proxy url is nil")
|
||||
}
|
||||
if scheme := strings.ToLower(proxyURL.Scheme); scheme != "socks5" && scheme != "socks5h" {
|
||||
return "", fmt.Errorf("proxychains only supports socks5/socks5h, got %s", proxyURL.Scheme)
|
||||
}
|
||||
|
||||
host := strings.TrimSpace(proxyURL.Hostname())
|
||||
port := strings.TrimSpace(proxyURL.Port())
|
||||
if host == "" {
|
||||
return "", fmt.Errorf("proxy host is empty")
|
||||
}
|
||||
if port == "" {
|
||||
port = "1080"
|
||||
}
|
||||
|
||||
username := proxyURL.User.Username()
|
||||
password, _ := proxyURL.User.Password()
|
||||
if strings.ContainsAny(username, " \t\r\n") || strings.ContainsAny(password, " \t\r\n") {
|
||||
return "", fmt.Errorf("proxychains credentials cannot contain whitespace")
|
||||
}
|
||||
|
||||
var builder strings.Builder
|
||||
builder.WriteString("strict_chain\n")
|
||||
builder.WriteString("proxy_dns\n")
|
||||
builder.WriteString("remote_dns_subnet 224\n")
|
||||
builder.WriteString("tcp_connect_time_out 8000\n")
|
||||
builder.WriteString("tcp_read_time_out 15000\n")
|
||||
builder.WriteString("localnet 127.0.0.0/255.0.0.0\n")
|
||||
builder.WriteString("localnet ::1/128\n")
|
||||
builder.WriteString("[ProxyList]\n")
|
||||
builder.WriteString("socks5 ")
|
||||
builder.WriteString(host)
|
||||
builder.WriteString(" ")
|
||||
builder.WriteString(port)
|
||||
if username != "" {
|
||||
builder.WriteString(" ")
|
||||
builder.WriteString(username)
|
||||
if password != "" {
|
||||
builder.WriteString(" ")
|
||||
builder.WriteString(password)
|
||||
}
|
||||
}
|
||||
builder.WriteString("\n")
|
||||
return builder.String(), nil
|
||||
}
|
||||
31
backend/internal/pkg/lspool/proxy_exec_test.go
Normal file
31
backend/internal/pkg/lspool/proxy_exec_test.go
Normal file
@ -0,0 +1,31 @@
|
||||
package lspool
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBuildProxychainsConfigIncludesAuthAndLocalBypass(t *testing.T) {
|
||||
proxyURL, err := url.Parse("socks5h://gostuser:fastapipwd@216.167.85.31:1080")
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg, err := buildProxychainsConfig(proxyURL)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, cfg, "proxy_dns\n")
|
||||
require.Contains(t, cfg, "localnet 127.0.0.0/255.0.0.0\n")
|
||||
require.Contains(t, cfg, "localnet ::1/128\n")
|
||||
require.Contains(t, cfg, "[ProxyList]\n")
|
||||
require.Contains(t, cfg, "socks5 216.167.85.31 1080 gostuser fastapipwd\n")
|
||||
}
|
||||
|
||||
func TestBuildProxychainsConfigRejectsWhitespaceCredentials(t *testing.T) {
|
||||
proxyURL, err := url.Parse("socks5h://user:bad%20pass@127.0.0.1:1080")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = buildProxychainsConfig(proxyURL)
|
||||
require.Error(t, err)
|
||||
require.True(t, strings.Contains(err.Error(), "whitespace"))
|
||||
}
|
||||
99
backend/internal/pkg/lspool/remote_instance.go
Normal file
99
backend/internal/pkg/lspool/remote_instance.go
Normal file
@ -0,0 +1,99 @@
|
||||
package lspool
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
func (i *Instance) callWorkerUnary(ctx context.Context, service, method, mode string, body []byte) ([]byte, error) {
|
||||
endpoint, err := i.workerEndpoint("/rpc/unary", service, method, mode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("X-Worker-Token", i.workerToken)
|
||||
if mode == "json" {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
} else {
|
||||
req.Header.Set("Content-Type", "application/octet-stream")
|
||||
}
|
||||
|
||||
resp, err := i.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("worker rpc %s/%s: %w", service, method, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("worker rpc read response: %w", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return respBody, fmt.Errorf("worker rpc %s/%s HTTP %d: %s", service, method, resp.StatusCode, truncate(string(respBody), 200))
|
||||
}
|
||||
return respBody, nil
|
||||
}
|
||||
|
||||
func (i *Instance) callWorkerStream(ctx context.Context, service, method, mode string, body []byte) (*http.Response, error) {
|
||||
endpoint, err := i.workerEndpoint("/rpc/stream", service, method, mode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("X-Worker-Token", i.workerToken)
|
||||
if mode == "json" {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
} else {
|
||||
req.Header.Set("Content-Type", "application/octet-stream")
|
||||
}
|
||||
|
||||
resp, err := i.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("worker stream rpc %s/%s: %w", service, method, err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
defer resp.Body.Close()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("worker stream rpc %s/%s HTTP %d: %s", service, method, resp.StatusCode, truncate(string(body), 200))
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (i *Instance) workerEndpoint(path, service, method, mode string) (string, error) {
|
||||
base := url.URL{
|
||||
Scheme: "http",
|
||||
Host: i.Address,
|
||||
Path: path,
|
||||
}
|
||||
values := url.Values{}
|
||||
values.Set("service", service)
|
||||
values.Set("method", method)
|
||||
values.Set("mode", mode)
|
||||
if i.routingKey != "" {
|
||||
values.Set("routing_key", i.routingKey)
|
||||
}
|
||||
base.RawQuery = values.Encode()
|
||||
return base.String(), nil
|
||||
}
|
||||
|
||||
func marshalWorkerJSONBody(input any) ([]byte, error) {
|
||||
if input == nil {
|
||||
return []byte("{}"), nil
|
||||
}
|
||||
body, err := json.Marshal(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return body, nil
|
||||
}
|
||||
1620
backend/internal/pkg/lspool/upstream_adapter.go
Normal file
1620
backend/internal/pkg/lspool/upstream_adapter.go
Normal file
File diff suppressed because it is too large
Load Diff
649
backend/internal/pkg/lspool/worker_manager.go
Normal file
649
backend/internal/pkg/lspool/worker_manager.go
Normal file
@ -0,0 +1,649 @@
|
||||
package lspool
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
|
||||
"github.com/docker/docker/api/types/container"
|
||||
"github.com/docker/docker/api/types/filters"
|
||||
"github.com/docker/docker/api/types/network"
|
||||
"github.com/docker/docker/client"
|
||||
ocispec "github.com/opencontainers/image-spec/specs-go/v1"
|
||||
)
|
||||
|
||||
const (
|
||||
lsWorkerManagedByLabel = "managed-by"
|
||||
lsWorkerManagedByValue = "sub2api"
|
||||
lsWorkerAccountLabel = "account_id"
|
||||
lsWorkerProxyHashLabel = "proxy_hash"
|
||||
lsWorkerImageTagLabel = "image_tag"
|
||||
lsWorkerControlPort = 18081
|
||||
)
|
||||
|
||||
type workerManagerConfig struct {
|
||||
Image string
|
||||
Network string
|
||||
DockerSocket string
|
||||
IdleTTL time.Duration
|
||||
MaxActive int
|
||||
StartupTimeout time.Duration
|
||||
RequestTimeout time.Duration
|
||||
}
|
||||
|
||||
type dockerClient interface {
|
||||
ContainerList(ctx context.Context, options container.ListOptions) ([]container.Summary, error)
|
||||
ContainerCreate(ctx context.Context, config *container.Config, hostConfig *container.HostConfig, networkingConfig *network.NetworkingConfig, platform *ocispec.Platform, containerName string) (container.CreateResponse, error)
|
||||
ContainerStart(ctx context.Context, containerID string, options container.StartOptions) error
|
||||
ContainerInspect(ctx context.Context, containerID string) (container.InspectResponse, error)
|
||||
ContainerStop(ctx context.Context, containerID string, options container.StopOptions) error
|
||||
ContainerRemove(ctx context.Context, containerID string, options container.RemoveOptions) error
|
||||
Close() error
|
||||
}
|
||||
|
||||
type workerManager struct {
|
||||
cfg workerManagerConfig
|
||||
docker dockerClient
|
||||
http *http.Client
|
||||
|
||||
mu sync.Mutex
|
||||
workers map[string]*workerHandle
|
||||
state map[string]*workerAccountState
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
type workerHandle struct {
|
||||
Key string
|
||||
AccountID string
|
||||
ProxyURL string
|
||||
ProxyHash string
|
||||
ContainerID string
|
||||
Container string
|
||||
Address string
|
||||
AuthToken string
|
||||
LastUsed time.Time
|
||||
LastStateSHA string
|
||||
}
|
||||
|
||||
type workerAccountState struct {
|
||||
HasToken bool `json:"has_token"`
|
||||
AccessToken string `json:"access_token,omitempty"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
ExpiresAt *time.Time `json:"expires_at,omitempty"`
|
||||
HasModelCredits bool `json:"has_model_credits"`
|
||||
UseAICredits bool `json:"use_ai_credits"`
|
||||
AvailableCredits *int32 `json:"available_credits,omitempty"`
|
||||
MinimumCreditAmount *int32 `json:"minimum_credit_amount,omitempty"`
|
||||
}
|
||||
|
||||
func NewWorkerManagerFromConfig(cfg *config.Config) (Backend, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("config is nil")
|
||||
}
|
||||
|
||||
managerCfg := workerManagerConfig{
|
||||
Image: strings.TrimSpace(cfg.Gateway.AntigravityLSWorker.Image),
|
||||
Network: strings.TrimSpace(cfg.Gateway.AntigravityLSWorker.Network),
|
||||
DockerSocket: strings.TrimSpace(cfg.Gateway.AntigravityLSWorker.DockerSocket),
|
||||
IdleTTL: cfg.Gateway.AntigravityLSWorker.IdleTTL,
|
||||
MaxActive: cfg.Gateway.AntigravityLSWorker.MaxActive,
|
||||
StartupTimeout: cfg.Gateway.AntigravityLSWorker.StartupTimeout,
|
||||
RequestTimeout: cfg.Gateway.AntigravityLSWorker.RequestTimeout,
|
||||
}
|
||||
|
||||
if managerCfg.Image == "" {
|
||||
managerCfg.Image = "weishaw/sub2api-lsworker:latest"
|
||||
}
|
||||
if managerCfg.Network == "" {
|
||||
managerCfg.Network = "sub2api-network"
|
||||
}
|
||||
if managerCfg.DockerSocket == "" {
|
||||
managerCfg.DockerSocket = "unix:///var/run/docker.sock"
|
||||
}
|
||||
if managerCfg.IdleTTL <= 0 {
|
||||
managerCfg.IdleTTL = 15 * time.Minute
|
||||
}
|
||||
if managerCfg.MaxActive < 1 {
|
||||
managerCfg.MaxActive = 50
|
||||
}
|
||||
if managerCfg.StartupTimeout <= 0 {
|
||||
managerCfg.StartupTimeout = 45 * time.Second
|
||||
}
|
||||
if managerCfg.RequestTimeout <= 0 {
|
||||
managerCfg.RequestTimeout = 60 * time.Second
|
||||
}
|
||||
|
||||
opts := []client.Opt{client.WithAPIVersionNegotiation()}
|
||||
if managerCfg.DockerSocket != "" {
|
||||
opts = append(opts, client.WithHost(managerCfg.DockerSocket))
|
||||
} else {
|
||||
opts = append(opts, client.FromEnv)
|
||||
}
|
||||
|
||||
dockerClient, err := client.NewClientWithOpts(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create docker client: %w", err)
|
||||
}
|
||||
|
||||
return newWorkerManager(managerCfg, dockerClient)
|
||||
}
|
||||
|
||||
func newWorkerManager(cfg workerManagerConfig, docker dockerClient) (*workerManager, error) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
mgr := &workerManager{
|
||||
cfg: cfg,
|
||||
docker: docker,
|
||||
http: &http.Client{
|
||||
Timeout: cfg.RequestTimeout,
|
||||
Transport: &http.Transport{
|
||||
Proxy: nil,
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 5 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).DialContext,
|
||||
MaxIdleConnsPerHost: 8,
|
||||
},
|
||||
},
|
||||
workers: make(map[string]*workerHandle),
|
||||
state: make(map[string]*workerAccountState),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
logger: slog.Default().With("component", "lspool-worker-manager"),
|
||||
}
|
||||
if err := mgr.reconcileManagedContainers(ctx); err != nil {
|
||||
cancel()
|
||||
_ = docker.Close()
|
||||
return nil, err
|
||||
}
|
||||
go mgr.cleanupLoop()
|
||||
return mgr, nil
|
||||
}
|
||||
|
||||
func (m *workerManager) Close() {
|
||||
m.cancel()
|
||||
|
||||
m.mu.Lock()
|
||||
workers := make([]*workerHandle, 0, len(m.workers))
|
||||
for _, handle := range m.workers {
|
||||
workers = append(workers, handle)
|
||||
}
|
||||
m.workers = make(map[string]*workerHandle)
|
||||
m.mu.Unlock()
|
||||
|
||||
for _, handle := range workers {
|
||||
m.removeWorkerContainer(context.Background(), handle)
|
||||
}
|
||||
if m.docker != nil {
|
||||
_ = m.docker.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *workerManager) Stats() map[string]any {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return map[string]any{
|
||||
"accounts": len(m.state),
|
||||
"total": len(m.workers),
|
||||
"active": len(m.workers),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *workerManager) SetAccountToken(accountID, accessToken, refreshToken string, expiresAt time.Time) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
state := m.ensureStateLocked(accountID)
|
||||
state.HasToken = true
|
||||
state.AccessToken = accessToken
|
||||
state.RefreshToken = refreshToken
|
||||
if expiresAt.IsZero() {
|
||||
state.ExpiresAt = nil
|
||||
} else {
|
||||
ts := expiresAt.UTC()
|
||||
state.ExpiresAt = &ts
|
||||
}
|
||||
}
|
||||
|
||||
func (m *workerManager) SetAccountModelCredits(accountID string, useAICredits bool, availableCredits, minimumCreditAmountForUsage *int32) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
state := m.ensureStateLocked(accountID)
|
||||
state.HasModelCredits = true
|
||||
state.UseAICredits = useAICredits
|
||||
state.AvailableCredits = cloneInt32Ptr(availableCredits)
|
||||
state.MinimumCreditAmount = cloneInt32Ptr(minimumCreditAmountForUsage)
|
||||
}
|
||||
|
||||
func (m *workerManager) GetOrCreate(accountID, routingKey string, proxyURL ...string) (*Instance, error) {
|
||||
rawProxy := ""
|
||||
if len(proxyURL) > 0 {
|
||||
rawProxy = proxyURL[0]
|
||||
}
|
||||
normalizedProxy, parsedProxy, err := resolveWorkerProxy(rawProxy)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if parsedProxy == nil {
|
||||
return nil, fmt.Errorf("ls worker requires a socks5/socks5h proxy for account %s", accountID)
|
||||
}
|
||||
|
||||
replica := replicaSlotIndex(routingKey, parseLSReplicaCount())
|
||||
proxyHash := proxyHash(normalizedProxy)
|
||||
workerKey := buildWorkerKey(accountID, proxyHash)
|
||||
|
||||
m.mu.Lock()
|
||||
state := cloneWorkerAccountState(m.state[accountID])
|
||||
if state == nil || !state.HasToken || strings.TrimSpace(state.AccessToken) == "" {
|
||||
m.mu.Unlock()
|
||||
return nil, fmt.Errorf("ls worker missing access token for account %s", accountID)
|
||||
}
|
||||
|
||||
handle := m.workers[workerKey]
|
||||
if handle == nil {
|
||||
if len(m.workers) >= m.cfg.MaxActive {
|
||||
m.mu.Unlock()
|
||||
return nil, fmt.Errorf("ls worker limit reached (%d active)", m.cfg.MaxActive)
|
||||
}
|
||||
handle, err = m.createWorkerLocked(accountID, normalizedProxy, proxyHash, parsedProxy)
|
||||
if err != nil {
|
||||
m.mu.Unlock()
|
||||
return nil, err
|
||||
}
|
||||
m.workers[workerKey] = handle
|
||||
}
|
||||
handle.LastUsed = time.Now()
|
||||
m.mu.Unlock()
|
||||
|
||||
if err := m.waitForWorkerHealthy(handle); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := m.syncWorkerState(handle, state); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := m.waitForWorkerReady(handle, routingKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
inst := &Instance{
|
||||
AccountID: accountID,
|
||||
Replica: replica,
|
||||
Address: handle.Address,
|
||||
client: m.http,
|
||||
healthy: true,
|
||||
lastUsed: time.Now(),
|
||||
modelMapReady: 1,
|
||||
remote: true,
|
||||
workerToken: handle.AuthToken,
|
||||
routingKey: routingKey,
|
||||
}
|
||||
return inst, nil
|
||||
}
|
||||
|
||||
func (m *workerManager) cleanupLoop() {
|
||||
ticker := time.NewTicker(time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
m.collectIdleWorkers()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *workerManager) collectIdleWorkers() {
|
||||
now := time.Now()
|
||||
var expired []*workerHandle
|
||||
|
||||
m.mu.Lock()
|
||||
for key, handle := range m.workers {
|
||||
if handle == nil {
|
||||
delete(m.workers, key)
|
||||
continue
|
||||
}
|
||||
if now.Sub(handle.LastUsed) <= m.cfg.IdleTTL {
|
||||
continue
|
||||
}
|
||||
expired = append(expired, handle)
|
||||
delete(m.workers, key)
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
for _, handle := range expired {
|
||||
m.removeWorkerContainer(context.Background(), handle)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *workerManager) reconcileManagedContainers(ctx context.Context) error {
|
||||
args := filters.NewArgs()
|
||||
args.Add("label", fmt.Sprintf("%s=%s", lsWorkerManagedByLabel, lsWorkerManagedByValue))
|
||||
|
||||
containers, err := m.docker.ContainerList(ctx, container.ListOptions{
|
||||
All: true,
|
||||
Filters: args,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("list managed ls workers: %w", err)
|
||||
}
|
||||
|
||||
for _, summary := range containers {
|
||||
handle := &workerHandle{
|
||||
ContainerID: summary.ID,
|
||||
Container: strings.TrimPrefix(firstContainerName(summary.Names), "/"),
|
||||
}
|
||||
if err := m.removeWorkerContainer(ctx, handle); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *workerManager) createWorkerLocked(accountID, proxyURL, proxyHash string, parsedProxy *url.URL) (*workerHandle, error) {
|
||||
containerName := fmt.Sprintf("sub2api-ls-%s-%s", accountID, proxyHash[:8])
|
||||
authToken := generateUUID()
|
||||
|
||||
proxyHost := parsedProxy.Hostname()
|
||||
proxyPort := parsedProxy.Port()
|
||||
if proxyPort == "" {
|
||||
proxyPort = "1080"
|
||||
}
|
||||
proxyUser := parsedProxy.User.Username()
|
||||
proxyPass, _ := parsedProxy.User.Password()
|
||||
|
||||
labels := map[string]string{
|
||||
lsWorkerManagedByLabel: lsWorkerManagedByValue,
|
||||
lsWorkerAccountLabel: accountID,
|
||||
lsWorkerProxyHashLabel: proxyHash,
|
||||
lsWorkerImageTagLabel: m.cfg.Image,
|
||||
}
|
||||
|
||||
env := []string{
|
||||
"ANTIGRAVITY_APP_ROOT=/app/ls",
|
||||
fmt.Sprintf("LSWORKER_ACCOUNT_ID=%s", accountID),
|
||||
fmt.Sprintf("LSWORKER_AUTH_TOKEN=%s", authToken),
|
||||
fmt.Sprintf("LSWORKER_LISTEN_ADDR=0.0.0.0:%d", lsWorkerControlPort),
|
||||
fmt.Sprintf("LSWORKER_NETWORK_READY_FILE=%s", "/run/lsworker/network-ready"),
|
||||
fmt.Sprintf("LSWORKER_PROXY_URL=%s", proxyURL),
|
||||
fmt.Sprintf("LSWORKER_PROXY_HOST=%s", proxyHost),
|
||||
fmt.Sprintf("LSWORKER_PROXY_PORT=%s", proxyPort),
|
||||
fmt.Sprintf("LSWORKER_PROXY_USER=%s", proxyUser),
|
||||
fmt.Sprintf("LSWORKER_PROXY_PASS=%s", proxyPass),
|
||||
fmt.Sprintf("LSWORKER_CONTROL_PORT=%d", lsWorkerControlPort),
|
||||
fmt.Sprintf("ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT=%d", parseLSReplicaCount()),
|
||||
}
|
||||
if tz := strings.TrimSpace(os.Getenv("TZ")); tz != "" {
|
||||
env = append(env, "TZ="+tz)
|
||||
}
|
||||
|
||||
createResp, err := m.docker.ContainerCreate(m.ctx, &container.Config{
|
||||
Image: m.cfg.Image,
|
||||
Labels: labels,
|
||||
Env: env,
|
||||
}, &container.HostConfig{
|
||||
CapAdd: []string{"NET_ADMIN"},
|
||||
}, &network.NetworkingConfig{
|
||||
EndpointsConfig: map[string]*network.EndpointSettings{
|
||||
m.cfg.Network: {},
|
||||
},
|
||||
}, nil, containerName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create ls worker container: %w", err)
|
||||
}
|
||||
|
||||
if err := m.docker.ContainerStart(m.ctx, createResp.ID, container.StartOptions{}); err != nil {
|
||||
_ = m.docker.ContainerRemove(m.ctx, createResp.ID, container.RemoveOptions{Force: true})
|
||||
return nil, fmt.Errorf("start ls worker container: %w", err)
|
||||
}
|
||||
|
||||
inspect, err := m.docker.ContainerInspect(m.ctx, createResp.ID)
|
||||
if err != nil {
|
||||
_ = m.docker.ContainerRemove(m.ctx, createResp.ID, container.RemoveOptions{Force: true})
|
||||
return nil, fmt.Errorf("inspect ls worker container: %w", err)
|
||||
}
|
||||
|
||||
address, err := workerAddressFromInspect(inspect, m.cfg.Network)
|
||||
if err != nil {
|
||||
_ = m.docker.ContainerRemove(m.ctx, createResp.ID, container.RemoveOptions{Force: true})
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m.logger.Info("created ls worker",
|
||||
"account", shortAccountID(accountID),
|
||||
"container", containerName,
|
||||
"address", address,
|
||||
"proxy_hash", proxyHash[:8])
|
||||
|
||||
return &workerHandle{
|
||||
Key: buildWorkerKey(accountID, proxyHash),
|
||||
AccountID: accountID,
|
||||
ProxyURL: proxyURL,
|
||||
ProxyHash: proxyHash,
|
||||
ContainerID: createResp.ID,
|
||||
Container: containerName,
|
||||
Address: address,
|
||||
AuthToken: authToken,
|
||||
LastUsed: time.Now(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func workerAddressFromInspect(inspect container.InspectResponse, networkName string) (string, error) {
|
||||
if inspect.NetworkSettings == nil {
|
||||
return "", fmt.Errorf("ls worker inspect missing network settings")
|
||||
}
|
||||
if endpoint, ok := inspect.NetworkSettings.Networks[networkName]; ok && endpoint != nil && strings.TrimSpace(endpoint.IPAddress) != "" {
|
||||
return net.JoinHostPort(strings.TrimSpace(endpoint.IPAddress), strconv.Itoa(lsWorkerControlPort)), nil
|
||||
}
|
||||
for _, endpoint := range inspect.NetworkSettings.Networks {
|
||||
if endpoint != nil && strings.TrimSpace(endpoint.IPAddress) != "" {
|
||||
return net.JoinHostPort(strings.TrimSpace(endpoint.IPAddress), strconv.Itoa(lsWorkerControlPort)), nil
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("ls worker missing IP address on network %s", networkName)
|
||||
}
|
||||
|
||||
func firstContainerName(names []string) string {
|
||||
if len(names) == 0 {
|
||||
return ""
|
||||
}
|
||||
return names[0]
|
||||
}
|
||||
|
||||
func (m *workerManager) waitForWorkerHealthy(handle *workerHandle) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), m.cfg.StartupTimeout)
|
||||
defer cancel()
|
||||
|
||||
for {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, workerURL(handle, "/healthz", nil), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("X-Worker-Token", handle.AuthToken)
|
||||
resp, err := m.http.Do(req)
|
||||
if err == nil {
|
||||
_ = resp.Body.Close()
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("worker %s failed health check: %w", handle.Container, ctx.Err())
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *workerManager) waitForWorkerReady(handle *workerHandle, routingKey string) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), m.cfg.StartupTimeout)
|
||||
defer cancel()
|
||||
|
||||
values := url.Values{}
|
||||
if strings.TrimSpace(routingKey) != "" {
|
||||
values.Set("routing_key", routingKey)
|
||||
}
|
||||
|
||||
for {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, workerURL(handle, "/readyz", values), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("X-Worker-Token", handle.AuthToken)
|
||||
resp, err := m.http.Do(req)
|
||||
if err == nil {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
_ = resp.Body.Close()
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
return nil
|
||||
}
|
||||
if len(body) > 0 {
|
||||
m.logger.Warn("ls worker not ready yet", "container", handle.Container, "status", resp.StatusCode, "body", truncate(string(body), 200))
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("worker %s not ready for routing key %q: %w", handle.Container, routingKey, ctx.Err())
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *workerManager) syncWorkerState(handle *workerHandle, state *workerAccountState) error {
|
||||
if state == nil {
|
||||
return fmt.Errorf("ls worker state is nil")
|
||||
}
|
||||
body, err := json.Marshal(state)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal worker state: %w", err)
|
||||
}
|
||||
|
||||
sum := fmt.Sprintf("%x", sha256.Sum256(body))
|
||||
if handle.LastStateSHA == sum {
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), m.cfg.RequestTimeout)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, workerURL(handle, "/account/state", nil), bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("X-Worker-Token", handle.AuthToken)
|
||||
|
||||
resp, err := m.http.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sync worker state: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("sync worker state HTTP %d: %s", resp.StatusCode, truncate(string(respBody), 200))
|
||||
}
|
||||
handle.LastStateSHA = sum
|
||||
return nil
|
||||
}
|
||||
|
||||
func workerURL(handle *workerHandle, path string, values url.Values) string {
|
||||
base := url.URL{
|
||||
Scheme: "http",
|
||||
Host: handle.Address,
|
||||
Path: path,
|
||||
}
|
||||
if values != nil {
|
||||
base.RawQuery = values.Encode()
|
||||
}
|
||||
return base.String()
|
||||
}
|
||||
|
||||
func (m *workerManager) removeWorkerContainer(ctx context.Context, handle *workerHandle) error {
|
||||
if handle == nil || strings.TrimSpace(handle.ContainerID) == "" {
|
||||
return nil
|
||||
}
|
||||
timeout := 5
|
||||
_ = m.docker.ContainerStop(ctx, handle.ContainerID, container.StopOptions{Timeout: &timeout})
|
||||
if err := m.docker.ContainerRemove(ctx, handle.ContainerID, container.RemoveOptions{Force: true}); err != nil {
|
||||
return fmt.Errorf("remove ls worker container %s: %w", handle.ContainerID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *workerManager) ensureStateLocked(accountID string) *workerAccountState {
|
||||
state := m.state[accountID]
|
||||
if state == nil {
|
||||
state = &workerAccountState{}
|
||||
m.state[accountID] = state
|
||||
}
|
||||
return state
|
||||
}
|
||||
|
||||
func resolveWorkerProxy(proxyURL string) (string, *url.URL, error) {
|
||||
resolved := resolveLSProxy(proxyURL)
|
||||
normalized, parsed, err := proxyurl.Parse(resolved)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
if parsed == nil {
|
||||
return "", nil, nil
|
||||
}
|
||||
switch strings.ToLower(parsed.Scheme) {
|
||||
case "socks5", "socks5h":
|
||||
return normalized, parsed, nil
|
||||
default:
|
||||
return "", nil, fmt.Errorf("ls worker only supports socks5/socks5h proxies, got %s", parsed.Scheme)
|
||||
}
|
||||
}
|
||||
|
||||
func proxyHash(proxyURL string) string {
|
||||
if strings.TrimSpace(proxyURL) == "" {
|
||||
return "direct"
|
||||
}
|
||||
sum := sha256.Sum256([]byte(strings.TrimSpace(proxyURL)))
|
||||
return fmt.Sprintf("%x", sum[:])
|
||||
}
|
||||
|
||||
func buildWorkerKey(accountID, proxyHash string) string {
|
||||
return accountID + ":" + proxyHash
|
||||
}
|
||||
|
||||
func cloneInt32Ptr(v *int32) *int32 {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
cp := *v
|
||||
return &cp
|
||||
}
|
||||
|
||||
func cloneWorkerAccountState(state *workerAccountState) *workerAccountState {
|
||||
if state == nil {
|
||||
return nil
|
||||
}
|
||||
cp := *state
|
||||
cp.AvailableCredits = cloneInt32Ptr(state.AvailableCredits)
|
||||
cp.MinimumCreditAmount = cloneInt32Ptr(state.MinimumCreditAmount)
|
||||
if state.ExpiresAt != nil {
|
||||
ts := *state.ExpiresAt
|
||||
cp.ExpiresAt = &ts
|
||||
}
|
||||
return &cp
|
||||
}
|
||||
266
backend/internal/pkg/lspool/worker_manager_test.go
Normal file
266
backend/internal/pkg/lspool/worker_manager_test.go
Normal file
@ -0,0 +1,266 @@
|
||||
package lspool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/docker/docker/api/types/container"
|
||||
"github.com/docker/docker/api/types/filters"
|
||||
"github.com/docker/docker/api/types/network"
|
||||
ocispec "github.com/opencontainers/image-spec/specs-go/v1"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type fakeDockerClient struct {
|
||||
mu sync.Mutex
|
||||
|
||||
listResp []container.Summary
|
||||
listCalls int
|
||||
createCalls int
|
||||
startCalls int
|
||||
stopCalls int
|
||||
removeCalls int
|
||||
inspectCalls int
|
||||
removedIDs []string
|
||||
createdConfigs []*container.Config
|
||||
inspectResp container.InspectResponse
|
||||
}
|
||||
|
||||
func (f *fakeDockerClient) ContainerList(ctx context.Context, options container.ListOptions) ([]container.Summary, error) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.listCalls++
|
||||
return append([]container.Summary(nil), f.listResp...), nil
|
||||
}
|
||||
|
||||
func (f *fakeDockerClient) ContainerCreate(ctx context.Context, cfg *container.Config, hostConfig *container.HostConfig, networkingConfig *network.NetworkingConfig, platform *ocispec.Platform, containerName string) (container.CreateResponse, error) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.createCalls++
|
||||
f.createdConfigs = append(f.createdConfigs, cfg)
|
||||
return container.CreateResponse{ID: "worker-created"}, nil
|
||||
}
|
||||
|
||||
func (f *fakeDockerClient) ContainerStart(ctx context.Context, containerID string, options container.StartOptions) error {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.startCalls++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeDockerClient) ContainerInspect(ctx context.Context, containerID string) (container.InspectResponse, error) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.inspectCalls++
|
||||
return f.inspectResp, nil
|
||||
}
|
||||
|
||||
func (f *fakeDockerClient) ContainerStop(ctx context.Context, containerID string, options container.StopOptions) error {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.stopCalls++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeDockerClient) ContainerRemove(ctx context.Context, containerID string, options container.RemoveOptions) error {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.removeCalls++
|
||||
f.removedIDs = append(f.removedIDs, containerID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeDockerClient) Close() error { return nil }
|
||||
|
||||
func TestResolveWorkerProxyRejectsHTTP(t *testing.T) {
|
||||
_, _, err := resolveWorkerProxy("http://127.0.0.1:7890")
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "only supports socks5/socks5h")
|
||||
}
|
||||
|
||||
func TestProxyHashUsesNormalizedProxy(t *testing.T) {
|
||||
normalized, _, err := resolveWorkerProxy("socks5://user:pass@127.0.0.1:1080")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "socks5h://user:pass@127.0.0.1:1080", normalized)
|
||||
|
||||
hash1 := proxyHash(normalized)
|
||||
hash2 := proxyHash("socks5h://user:pass@127.0.0.1:1080")
|
||||
require.Equal(t, hash1, hash2)
|
||||
}
|
||||
|
||||
func TestWorkerManagerRequiresToken(t *testing.T) {
|
||||
fakeDocker := &fakeDockerClient{}
|
||||
manager, err := newWorkerManager(workerManagerConfig{
|
||||
Image: "worker:latest",
|
||||
Network: "sub2api-network",
|
||||
DockerSocket: "unix:///var/run/docker.sock",
|
||||
IdleTTL: time.Minute,
|
||||
MaxActive: 2,
|
||||
StartupTimeout: time.Second,
|
||||
RequestTimeout: time.Second,
|
||||
}, fakeDocker)
|
||||
require.NoError(t, err)
|
||||
defer manager.Close()
|
||||
|
||||
_, err = manager.GetOrCreate("9", "rk-1", "socks5h://user:pass@127.0.0.1:1080")
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "missing access token")
|
||||
}
|
||||
|
||||
func TestWorkerManagerReusesExistingHandleAndDedupesStateSync(t *testing.T) {
|
||||
var mu sync.Mutex
|
||||
var healthCalls int
|
||||
var readyCalls int
|
||||
var stateCalls int
|
||||
var stateBodies [][]byte
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/healthz":
|
||||
mu.Lock()
|
||||
healthCalls++
|
||||
mu.Unlock()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
case "/readyz":
|
||||
mu.Lock()
|
||||
readyCalls++
|
||||
mu.Unlock()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ready"))
|
||||
case "/account/state":
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
mu.Lock()
|
||||
stateCalls++
|
||||
stateBodies = append(stateBodies, body)
|
||||
mu.Unlock()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
fakeDocker := &fakeDockerClient{}
|
||||
manager, err := newWorkerManager(workerManagerConfig{
|
||||
Image: "worker:latest",
|
||||
Network: "sub2api-network",
|
||||
DockerSocket: "unix:///var/run/docker.sock",
|
||||
IdleTTL: time.Minute,
|
||||
MaxActive: 4,
|
||||
StartupTimeout: time.Second,
|
||||
RequestTimeout: time.Second,
|
||||
}, fakeDocker)
|
||||
require.NoError(t, err)
|
||||
defer manager.Close()
|
||||
|
||||
accountID := "9"
|
||||
proxyURL := "socks5h://user:pass@127.0.0.1:1080"
|
||||
hash := proxyHash(proxyURL)
|
||||
key := buildWorkerKey(accountID, hash)
|
||||
|
||||
manager.SetAccountToken(accountID, "ya29.test", "refresh", time.Now().Add(time.Hour))
|
||||
manager.mu.Lock()
|
||||
manager.workers[key] = &workerHandle{
|
||||
Key: key,
|
||||
AccountID: accountID,
|
||||
ProxyURL: proxyURL,
|
||||
ProxyHash: hash,
|
||||
ContainerID: "existing-worker",
|
||||
Container: "sub2api-ls-9-test",
|
||||
Address: strings.TrimPrefix(server.URL, "http://"),
|
||||
AuthToken: "worker-token",
|
||||
LastUsed: time.Now(),
|
||||
}
|
||||
manager.mu.Unlock()
|
||||
|
||||
inst1, err := manager.GetOrCreate(accountID, "rk-1", proxyURL)
|
||||
require.NoError(t, err)
|
||||
require.True(t, inst1.remote)
|
||||
require.Equal(t, replicaSlotIndex("rk-1", parseLSReplicaCount()), inst1.Replica)
|
||||
|
||||
inst2, err := manager.GetOrCreate(accountID, "rk-1", proxyURL)
|
||||
require.NoError(t, err)
|
||||
require.True(t, inst2.remote)
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
require.GreaterOrEqual(t, healthCalls, 2)
|
||||
require.GreaterOrEqual(t, readyCalls, 2)
|
||||
require.Equal(t, 1, stateCalls, "state sync should be skipped when the payload hash is unchanged")
|
||||
require.Len(t, stateBodies, 1)
|
||||
|
||||
var synced workerAccountState
|
||||
require.NoError(t, json.Unmarshal(stateBodies[0], &synced))
|
||||
require.True(t, synced.HasToken)
|
||||
require.Equal(t, "ya29.test", synced.AccessToken)
|
||||
}
|
||||
|
||||
func TestWorkerManagerMaxActiveStopsNewWorkerCreation(t *testing.T) {
|
||||
fakeDocker := &fakeDockerClient{}
|
||||
manager, err := newWorkerManager(workerManagerConfig{
|
||||
Image: "worker:latest",
|
||||
Network: "sub2api-network",
|
||||
DockerSocket: "unix:///var/run/docker.sock",
|
||||
IdleTTL: time.Minute,
|
||||
MaxActive: 1,
|
||||
StartupTimeout: time.Second,
|
||||
RequestTimeout: time.Second,
|
||||
}, fakeDocker)
|
||||
require.NoError(t, err)
|
||||
defer manager.Close()
|
||||
|
||||
manager.SetAccountToken("9", "ya29.test", "refresh", time.Now().Add(time.Hour))
|
||||
manager.mu.Lock()
|
||||
manager.workers["existing"] = &workerHandle{ContainerID: "existing", Container: "existing", LastUsed: time.Now()}
|
||||
manager.mu.Unlock()
|
||||
|
||||
_, err = manager.GetOrCreate("9", "rk-new", "socks5h://user:pass@127.0.0.1:1080")
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "limit reached")
|
||||
require.Equal(t, 0, fakeDocker.createCalls)
|
||||
}
|
||||
|
||||
func TestWorkerManagerReconcileRemovesManagedContainers(t *testing.T) {
|
||||
fakeDocker := &fakeDockerClient{
|
||||
listResp: []container.Summary{
|
||||
{
|
||||
ID: "old-worker-1",
|
||||
Names: []string{"/sub2api-ls-9-deadbeef"},
|
||||
},
|
||||
{
|
||||
ID: "old-worker-2",
|
||||
Names: []string{"/sub2api-ls-10-beadfeed"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := newWorkerManager(workerManagerConfig{
|
||||
Image: "worker:latest",
|
||||
Network: "sub2api-network",
|
||||
DockerSocket: "unix:///var/run/docker.sock",
|
||||
IdleTTL: time.Minute,
|
||||
MaxActive: 4,
|
||||
StartupTimeout: time.Second,
|
||||
RequestTimeout: time.Second,
|
||||
}, fakeDocker)
|
||||
require.NoError(t, err)
|
||||
defer manager.Close()
|
||||
|
||||
require.Equal(t, 1, fakeDocker.listCalls)
|
||||
require.ElementsMatch(t, []string{"old-worker-1", "old-worker-2"}, fakeDocker.removedIDs)
|
||||
}
|
||||
|
||||
func TestFakeDockerClientImplementsFilterAwareList(t *testing.T) {
|
||||
fakeDocker := &fakeDockerClient{}
|
||||
_, err := fakeDocker.ContainerList(context.Background(), container.ListOptions{Filters: filters.NewArgs()})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
368
backend/internal/pkg/lspool/worker_server.go
Normal file
368
backend/internal/pkg/lspool/worker_server.go
Normal file
@ -0,0 +1,368 @@
|
||||
package lspool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type WorkerServerConfig struct {
|
||||
AccountID string
|
||||
AuthToken string
|
||||
ListenAddr string
|
||||
AppRoot string
|
||||
NetworkReadyFile string
|
||||
MaxIdleTime time.Duration
|
||||
HealthInterval time.Duration
|
||||
}
|
||||
|
||||
type WorkerServer struct {
|
||||
cfg WorkerServerConfig
|
||||
pool *Pool
|
||||
logger *slog.Logger
|
||||
|
||||
mu sync.RWMutex
|
||||
state workerAccountState
|
||||
}
|
||||
|
||||
func NewWorkerServer(cfg WorkerServerConfig) (*WorkerServer, error) {
|
||||
if strings.TrimSpace(cfg.AccountID) == "" {
|
||||
return nil, fmt.Errorf("worker account id is required")
|
||||
}
|
||||
if strings.TrimSpace(cfg.AuthToken) == "" {
|
||||
return nil, fmt.Errorf("worker auth token is required")
|
||||
}
|
||||
if strings.TrimSpace(cfg.ListenAddr) == "" {
|
||||
cfg.ListenAddr = fmt.Sprintf("0.0.0.0:%d", lsWorkerControlPort)
|
||||
}
|
||||
if strings.TrimSpace(cfg.AppRoot) == "" {
|
||||
cfg.AppRoot = "/app/ls"
|
||||
}
|
||||
if cfg.MaxIdleTime <= 0 {
|
||||
cfg.MaxIdleTime = 15 * time.Minute
|
||||
}
|
||||
if cfg.HealthInterval <= 0 {
|
||||
cfg.HealthInterval = 30 * time.Second
|
||||
}
|
||||
|
||||
poolCfg := DefaultConfig()
|
||||
poolCfg.AppRoot = cfg.AppRoot
|
||||
poolCfg.MaxIdleTime = cfg.MaxIdleTime
|
||||
poolCfg.HealthCheckInterval = cfg.HealthInterval
|
||||
|
||||
return &WorkerServer{
|
||||
cfg: cfg,
|
||||
pool: NewPool(poolCfg),
|
||||
logger: slog.Default().With("component", "lsworker"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func NewWorkerServerFromEnv() (*WorkerServer, error) {
|
||||
maxIdleTime := 15 * time.Minute
|
||||
if raw := strings.TrimSpace(os.Getenv("LSWORKER_POOL_MAX_IDLE_TIME")); raw != "" {
|
||||
if parsed, err := time.ParseDuration(raw); err == nil {
|
||||
maxIdleTime = parsed
|
||||
}
|
||||
}
|
||||
healthInterval := 30 * time.Second
|
||||
if raw := strings.TrimSpace(os.Getenv("LSWORKER_POOL_HEALTH_INTERVAL")); raw != "" {
|
||||
if parsed, err := time.ParseDuration(raw); err == nil {
|
||||
healthInterval = parsed
|
||||
}
|
||||
}
|
||||
|
||||
return NewWorkerServer(WorkerServerConfig{
|
||||
AccountID: strings.TrimSpace(os.Getenv("LSWORKER_ACCOUNT_ID")),
|
||||
AuthToken: strings.TrimSpace(os.Getenv("LSWORKER_AUTH_TOKEN")),
|
||||
ListenAddr: strings.TrimSpace(os.Getenv("LSWORKER_LISTEN_ADDR")),
|
||||
AppRoot: strings.TrimSpace(os.Getenv("ANTIGRAVITY_APP_ROOT")),
|
||||
NetworkReadyFile: strings.TrimSpace(os.Getenv("LSWORKER_NETWORK_READY_FILE")),
|
||||
MaxIdleTime: maxIdleTime,
|
||||
HealthInterval: healthInterval,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *WorkerServer) Close() {
|
||||
if s.pool != nil {
|
||||
s.pool.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *WorkerServer) Handler() http.Handler {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/healthz", s.handleHealthz)
|
||||
mux.HandleFunc("/readyz", s.handleReadyz)
|
||||
mux.HandleFunc("/account/state", s.handleAccountState)
|
||||
mux.HandleFunc("/rpc/unary", s.handleRPCUnary)
|
||||
mux.HandleFunc("/rpc/stream", s.handleRPCStream)
|
||||
return mux
|
||||
}
|
||||
|
||||
func (s *WorkerServer) handleHealthz(w http.ResponseWriter, r *http.Request) {
|
||||
if !s.authorize(w, r) {
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
}
|
||||
|
||||
func (s *WorkerServer) handleReadyz(w http.ResponseWriter, r *http.Request) {
|
||||
if !s.authorize(w, r) {
|
||||
return
|
||||
}
|
||||
routingKey := strings.TrimSpace(r.URL.Query().Get("routing_key"))
|
||||
inst, err := s.ensureReady(r.Context(), routingKey)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(fmt.Sprintf("ready replica=%d", inst.Replica)))
|
||||
}
|
||||
|
||||
func (s *WorkerServer) handleAccountState(w http.ResponseWriter, r *http.Request) {
|
||||
if !s.authorize(w, r) {
|
||||
return
|
||||
}
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
defer r.Body.Close()
|
||||
var payload workerAccountState
|
||||
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
|
||||
http.Error(w, "invalid account state payload", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.state = *cloneWorkerAccountState(&payload)
|
||||
s.mu.Unlock()
|
||||
s.applyState()
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
}
|
||||
|
||||
func (s *WorkerServer) handleRPCUnary(w http.ResponseWriter, r *http.Request) {
|
||||
if !s.authorize(w, r) {
|
||||
return
|
||||
}
|
||||
service, method, mode, routingKey, ok := parseRPCRequest(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
inst, err := s.ensureReady(r.Context(), routingKey)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, "read request body failed", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if len(body) == 0 {
|
||||
body = []byte("{}")
|
||||
}
|
||||
|
||||
var respBody []byte
|
||||
switch mode {
|
||||
case "json":
|
||||
var input any
|
||||
if err := json.Unmarshal(body, &input); err != nil {
|
||||
http.Error(w, "invalid json rpc body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
respBody, err = inst.CallUnaryJSON(r.Context(), service, method, input)
|
||||
case "proto":
|
||||
respBody, err = inst.CallRPC(r.Context(), service, method, body)
|
||||
default:
|
||||
http.Error(w, "unsupported rpc mode", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(respBody)
|
||||
}
|
||||
|
||||
func (s *WorkerServer) handleRPCStream(w http.ResponseWriter, r *http.Request) {
|
||||
if !s.authorize(w, r) {
|
||||
return
|
||||
}
|
||||
service, method, mode, routingKey, ok := parseRPCRequest(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
inst, err := s.ensureReady(r.Context(), routingKey)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, "read request body failed", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var resp *http.Response
|
||||
switch mode {
|
||||
case "json":
|
||||
var input any
|
||||
if len(body) == 0 {
|
||||
body = []byte("{}")
|
||||
}
|
||||
if err := json.Unmarshal(body, &input); err != nil {
|
||||
http.Error(w, "invalid json rpc body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
resp, err = inst.StreamRPCJSON(r.Context(), service, method, input)
|
||||
case "proto":
|
||||
resp, err = inst.StreamRPC(r.Context(), service, method, body)
|
||||
default:
|
||||
http.Error(w, "unsupported rpc mode", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
for key, values := range resp.Header {
|
||||
for _, value := range values {
|
||||
w.Header().Add(key, value)
|
||||
}
|
||||
}
|
||||
w.WriteHeader(resp.StatusCode)
|
||||
_, _ = io.Copy(w, resp.Body)
|
||||
}
|
||||
|
||||
func (s *WorkerServer) authorize(w http.ResponseWriter, r *http.Request) bool {
|
||||
if subtleHeaderEqual(r.Header.Get("X-Worker-Token"), s.cfg.AuthToken) {
|
||||
return true
|
||||
}
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
return false
|
||||
}
|
||||
|
||||
func subtleHeaderEqual(left, right string) bool {
|
||||
if left == "" || right == "" {
|
||||
return false
|
||||
}
|
||||
return left == right
|
||||
}
|
||||
|
||||
func parseRPCRequest(w http.ResponseWriter, r *http.Request) (service, method, mode, routingKey string, ok bool) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return "", "", "", "", false
|
||||
}
|
||||
query := r.URL.Query()
|
||||
service = strings.TrimSpace(query.Get("service"))
|
||||
method = strings.TrimSpace(query.Get("method"))
|
||||
mode = strings.ToLower(strings.TrimSpace(query.Get("mode")))
|
||||
routingKey = strings.TrimSpace(query.Get("routing_key"))
|
||||
if service == "" || method == "" {
|
||||
http.Error(w, "missing rpc target", http.StatusBadRequest)
|
||||
return "", "", "", "", false
|
||||
}
|
||||
if mode == "" {
|
||||
mode = "proto"
|
||||
}
|
||||
return service, method, mode, routingKey, true
|
||||
}
|
||||
|
||||
func (s *WorkerServer) ensureReady(ctx context.Context, routingKey string) (*Instance, error) {
|
||||
if path := strings.TrimSpace(s.cfg.NetworkReadyFile); path != "" {
|
||||
if _, err := os.Stat(path); err != nil {
|
||||
return nil, fmt.Errorf("worker network not ready: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
s.applyState()
|
||||
s.mu.RLock()
|
||||
state := cloneWorkerAccountState(&s.state)
|
||||
s.mu.RUnlock()
|
||||
if state == nil || !state.HasToken || strings.TrimSpace(state.AccessToken) == "" {
|
||||
return nil, fmt.Errorf("worker access token not configured")
|
||||
}
|
||||
|
||||
inst, err := s.pool.GetOrCreate(s.cfg.AccountID, routingKey, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if inst.HasModelMappingReady() {
|
||||
return inst, nil
|
||||
}
|
||||
|
||||
modelCtx, cancel := context.WithTimeout(ctx, lsModelConfigTimeout)
|
||||
defer cancel()
|
||||
_ = modelCtx
|
||||
if !RefreshModelMapping(inst) {
|
||||
return nil, fmt.Errorf("worker model mapping not ready for replica %d", inst.Replica)
|
||||
}
|
||||
return inst, nil
|
||||
}
|
||||
|
||||
func (s *WorkerServer) applyState() {
|
||||
s.mu.RLock()
|
||||
state := cloneWorkerAccountState(&s.state)
|
||||
s.mu.RUnlock()
|
||||
if state == nil {
|
||||
return
|
||||
}
|
||||
if state.HasToken {
|
||||
expiresAt := time.Time{}
|
||||
if state.ExpiresAt != nil {
|
||||
expiresAt = state.ExpiresAt.UTC()
|
||||
}
|
||||
s.pool.SetAccountToken(s.cfg.AccountID, state.AccessToken, state.RefreshToken, expiresAt)
|
||||
}
|
||||
if state.HasModelCredits {
|
||||
s.pool.SetAccountModelCredits(s.cfg.AccountID, state.UseAICredits, state.AvailableCredits, state.MinimumCreditAmount)
|
||||
}
|
||||
}
|
||||
|
||||
func workerHTTPServer(listenAddr string, handler http.Handler) *http.Server {
|
||||
return &http.Server{
|
||||
Addr: listenAddr,
|
||||
Handler: handler,
|
||||
ReadHeaderTimeout: 10 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
func workerExitCode(err error) int {
|
||||
if err == nil {
|
||||
return 0
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
func parseWorkerControlPort() int {
|
||||
raw := strings.TrimSpace(os.Getenv("LSWORKER_CONTROL_PORT"))
|
||||
if raw == "" {
|
||||
return lsWorkerControlPort
|
||||
}
|
||||
port, err := strconv.Atoi(raw)
|
||||
if err != nil || port < 1 {
|
||||
return lsWorkerControlPort
|
||||
}
|
||||
return port
|
||||
}
|
||||
@ -13,6 +13,8 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/lspool"
|
||||
"time"
|
||||
|
||||
"github.com/andybalholm/brotli"
|
||||
@ -99,16 +101,34 @@ type httpUpstreamService struct {
|
||||
// NewHTTPUpstream 创建通用 HTTP 上游服务
|
||||
// 使用配置中的连接池参数构建 Transport
|
||||
//
|
||||
// 当环境变量 ANTIGRAVITY_LS_MODE=true 时,自动包装 LS 池拦截层:
|
||||
// - 仅对已知兼容的 LS 请求形态启用转发
|
||||
// - 对普通 streamGenerateContent 请求保留原有直连路径,避免误送到不兼容的 LS RPC
|
||||
//
|
||||
// 参数:
|
||||
// - cfg: 全局配置,包含连接池参数和隔离策略
|
||||
//
|
||||
// 返回:
|
||||
// - service.HTTPUpstream 接口实现
|
||||
func NewHTTPUpstream(cfg *config.Config) service.HTTPUpstream {
|
||||
return &httpUpstreamService{
|
||||
base := &httpUpstreamService{
|
||||
cfg: cfg,
|
||||
clients: make(map[string]*upstreamClientEntry),
|
||||
}
|
||||
|
||||
// LS 池模式: 包装一层拦截, streamGenerateContent 走 LS
|
||||
if lspool.IsLSModeEnabled() {
|
||||
pool := lspool.GlobalPool(cfg)
|
||||
if pool != nil {
|
||||
slog.Info("LS pool mode enabled — streamGenerateContent will route through Language Server",
|
||||
"component", "http_upstream")
|
||||
return lspool.NewLSPoolUpstream(pool, base)
|
||||
}
|
||||
slog.Warn("LS pool mode enabled but pool is nil — falling back to direct mode",
|
||||
"component", "http_upstream")
|
||||
}
|
||||
|
||||
return base
|
||||
}
|
||||
|
||||
// Do 执行 HTTP 请求
|
||||
|
||||
@ -104,6 +104,10 @@ func classifyAntigravity429(body []byte) antigravity429Category {
|
||||
return antigravity429QuotaExhausted
|
||||
}
|
||||
}
|
||||
if strings.Contains(lowerBody, "exhausted your capacity on this model") &&
|
||||
strings.Contains(lowerBody, "quota will reset after") {
|
||||
return antigravity429QuotaExhausted
|
||||
}
|
||||
if info := parseAntigravitySmartRetryInfo(body); info != nil && !info.IsModelCapacityExhausted {
|
||||
return antigravity429RateLimited
|
||||
}
|
||||
|
||||
@ -21,6 +21,16 @@ func TestClassifyAntigravity429(t *testing.T) {
|
||||
require.Equal(t, antigravity429QuotaExhausted, classifyAntigravity429(body))
|
||||
})
|
||||
|
||||
t.Run("模型配额耗尽文案也视为可切 AI Credits", func(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"error": {
|
||||
"status": "RESOURCE_EXHAUSTED",
|
||||
"message": "You have exhausted your capacity on this model. Your quota will reset after 1h59m40s."
|
||||
}
|
||||
}`)
|
||||
require.Equal(t, antigravity429QuotaExhausted, classifyAntigravity429(body))
|
||||
})
|
||||
|
||||
t.Run("结构化限流", func(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"error": {
|
||||
@ -146,6 +156,68 @@ func TestHandleSmartRetry_QuotaExhausted_UsesCreditsAndStoresIndependentState(t
|
||||
require.Empty(t, repo.modelRateLimitCalls, "overages 成功后不应写入普通 model_rate_limits")
|
||||
}
|
||||
|
||||
func TestHandleSmartRetry_ModelQuotaMessage_UsesCredits(t *testing.T) {
|
||||
successResp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(strings.NewReader(`{"ok":true}`)),
|
||||
}
|
||||
upstream := &mockSmartRetryUpstream{
|
||||
responses: []*http.Response{successResp},
|
||||
errors: []error{nil},
|
||||
}
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
account := &Account{
|
||||
ID: 151,
|
||||
Name: "acc-151",
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Extra: map[string]any{
|
||||
"allow_overages": true,
|
||||
},
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"claude-opus-4-6": "claude-opus-4-6",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
respBody := []byte(`{
|
||||
"error": {
|
||||
"status": "RESOURCE_EXHAUSTED",
|
||||
"message": "You have exhausted your capacity on this model. Your quota will reset after 1h59m40s."
|
||||
}
|
||||
}`)
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||
}
|
||||
params := antigravityRetryLoopParams{
|
||||
ctx: context.Background(),
|
||||
prefix: "[test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"model":"claude-opus-4-6","request":{}}`),
|
||||
httpUpstream: upstream,
|
||||
accountRepo: repo,
|
||||
requestedModel: "claude-opus-4-6",
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{}
|
||||
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, []string{"https://ag-1.test"})
|
||||
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
||||
require.NotNil(t, result.resp)
|
||||
require.Len(t, upstream.requestBodies, 1)
|
||||
require.Contains(t, string(upstream.requestBodies[0]), "enabledCreditTypes")
|
||||
}
|
||||
|
||||
func TestHandleSmartRetry_RateLimited_DoesNotUseCredits(t *testing.T) {
|
||||
successResp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
|
||||
@ -112,6 +112,144 @@ func IsAntigravityAccountSwitchError(err error) (*AntigravityAccountSwitchError,
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// injectLSPoolHeaders adds internal headers carrying OAuth credentials for the
|
||||
// LS pool layer. These headers are consumed and stripped by LSPoolUpstream
|
||||
// before the request reaches the Language Server. When LS mode is disabled,
|
||||
// these headers are harmless — the direct upstream never sees them because
|
||||
// they are stripped inside LSPoolUpstream.Do(). In direct mode the request
|
||||
// goes straight through httpUpstreamService.Do() which doesn't inspect them.
|
||||
func injectLSPoolHeaders(req *http.Request, account *Account) {
|
||||
if req == nil || account == nil {
|
||||
return
|
||||
}
|
||||
if rt, ok := account.Credentials["refresh_token"].(string); ok && rt != "" {
|
||||
req.Header.Set("X-Antigravity-Refresh-Token", rt)
|
||||
}
|
||||
if ea, ok := account.Credentials["expires_at"].(string); ok && ea != "" {
|
||||
req.Header.Set("X-Antigravity-Token-Expiry", ea)
|
||||
}
|
||||
req.Header.Set("X-Antigravity-Use-AI-Credits", strconv.FormatBool(account.IsOveragesEnabled()))
|
||||
|
||||
availableCredits, minimumCreditAmount := resolveLSPoolModelCreditsState(account)
|
||||
if availableCredits != nil {
|
||||
req.Header.Set("X-Antigravity-Available-Credits", strconv.FormatInt(int64(*availableCredits), 10))
|
||||
}
|
||||
if minimumCreditAmount != nil {
|
||||
req.Header.Set("X-Antigravity-Minimum-Credit-Amount", strconv.FormatInt(int64(*minimumCreditAmount), 10))
|
||||
}
|
||||
}
|
||||
|
||||
func resolveLSPoolModelCreditsState(account *Account) (*int32, *int32) {
|
||||
if account == nil || account.Extra == nil {
|
||||
minimum := int32(50)
|
||||
return nil, &minimum
|
||||
}
|
||||
|
||||
var availableCredits *int32
|
||||
var minimumCreditAmount *int32
|
||||
|
||||
collect := func(entry map[string]any) {
|
||||
if entry == nil {
|
||||
return
|
||||
}
|
||||
if !isGoogleOneAICreditsEntry(entry) {
|
||||
return
|
||||
}
|
||||
if availableCredits == nil {
|
||||
if parsed, ok := parseAICreditsInt32(firstPresent(entry, "Amount", "amount", "creditAmount")); ok {
|
||||
availableCredits = &parsed
|
||||
}
|
||||
}
|
||||
if minimumCreditAmount == nil {
|
||||
if parsed, ok := parseAICreditsInt32(firstPresent(entry, "MinimumBalance", "minimum_balance", "minimumCreditAmountForUsage")); ok {
|
||||
minimumCreditAmount = &parsed
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if rawCredits, ok := account.Extra["ai_credits"].([]any); ok {
|
||||
for _, item := range rawCredits {
|
||||
if entry, ok := item.(map[string]any); ok {
|
||||
collect(entry)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if loadCodeAssist, ok := account.Extra["load_code_assist"].(map[string]any); ok {
|
||||
if paidTier, ok := loadCodeAssist["paidTier"].(map[string]any); ok {
|
||||
if credits, ok := paidTier["availableCredits"].([]any); ok {
|
||||
for _, item := range credits {
|
||||
if entry, ok := item.(map[string]any); ok {
|
||||
collect(entry)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if minimumCreditAmount == nil {
|
||||
defaultMinimum := int32(50)
|
||||
minimumCreditAmount = &defaultMinimum
|
||||
}
|
||||
return availableCredits, minimumCreditAmount
|
||||
}
|
||||
|
||||
func isGoogleOneAICreditsEntry(entry map[string]any) bool {
|
||||
creditType, _ := firstPresent(entry, "CreditType", "credit_type", "creditType").(string)
|
||||
creditType = strings.TrimSpace(strings.ToUpper(creditType))
|
||||
return creditType == "" || creditType == "GOOGLE_ONE_AI"
|
||||
}
|
||||
|
||||
func firstPresent(entry map[string]any, keys ...string) any {
|
||||
for _, key := range keys {
|
||||
if value, ok := entry[key]; ok {
|
||||
return value
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseAICreditsInt32(raw any) (int32, bool) {
|
||||
switch v := raw.(type) {
|
||||
case int:
|
||||
return int32(v), true
|
||||
case int32:
|
||||
return v, true
|
||||
case int64:
|
||||
return int32(v), true
|
||||
case float32:
|
||||
return int32(v), true
|
||||
case float64:
|
||||
return int32(v), true
|
||||
case json.Number:
|
||||
parsed, err := v.Int64()
|
||||
if err != nil {
|
||||
floatVal, floatErr := strconv.ParseFloat(v.String(), 64)
|
||||
if floatErr != nil {
|
||||
return 0, false
|
||||
}
|
||||
return int32(floatVal), true
|
||||
}
|
||||
return int32(parsed), true
|
||||
case string:
|
||||
trimmed := strings.TrimSpace(v)
|
||||
if trimmed == "" {
|
||||
return 0, false
|
||||
}
|
||||
parsed, err := strconv.ParseInt(trimmed, 10, 32)
|
||||
if err == nil {
|
||||
return int32(parsed), true
|
||||
}
|
||||
floatVal, floatErr := strconv.ParseFloat(trimmed, 64)
|
||||
if floatErr != nil {
|
||||
return 0, false
|
||||
}
|
||||
return int32(floatVal), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
// PromptTooLongError 表示上游明确返回 prompt too long
|
||||
type PromptTooLongError struct {
|
||||
StatusCode int
|
||||
@ -305,6 +443,7 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
|
||||
}
|
||||
}
|
||||
|
||||
injectLSPoolHeaders(retryReq, p.account)
|
||||
retryResp, retryErr := p.httpUpstream.Do(retryReq, p.proxyURL, p.account.ID, p.account.Concurrency)
|
||||
if retryErr == nil && retryResp != nil && retryResp.StatusCode != http.StatusTooManyRequests && retryResp.StatusCode != http.StatusServiceUnavailable {
|
||||
log.Printf("%s status=%d smart_retry_success attempt=%d/%d", p.prefix, retryResp.StatusCode, attempt, maxAttempts)
|
||||
@ -489,6 +628,7 @@ func (s *AntigravityGatewayService) handleSingleAccountRetryInPlace(
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s single_account_503_retry: request_build_failed error=%v", p.prefix, err)
|
||||
break
|
||||
}
|
||||
injectLSPoolHeaders(retryReq, p.account)
|
||||
|
||||
retryResp, retryErr := p.httpUpstream.Do(retryReq, p.proxyURL, p.account.ID, p.account.Concurrency)
|
||||
if retryErr == nil && retryResp != nil && retryResp.StatusCode != http.StatusTooManyRequests && retryResp.StatusCode != http.StatusServiceUnavailable {
|
||||
@ -627,6 +767,7 @@ urlFallbackLoop:
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
injectLSPoolHeaders(upstreamReq, p.account)
|
||||
|
||||
// Capture upstream request body for ops retry of this attempt.
|
||||
if p.c != nil && len(p.body) > 0 {
|
||||
@ -1289,9 +1430,19 @@ func injectIdentityPatchToGeminiRequest(body []byte) ([]byte, error) {
|
||||
}
|
||||
|
||||
// wrapV1InternalRequest 包装请求为 v1internal 格式
|
||||
func (s *AntigravityGatewayService) wrapV1InternalRequest(projectID, model string, originalBody []byte) ([]byte, error) {
|
||||
func (s *AntigravityGatewayService) wrapV1InternalRequest(projectID, model string, originalBody []byte, preferredSessionID ...string) ([]byte, error) {
|
||||
sessionID := ""
|
||||
if len(preferredSessionID) > 0 {
|
||||
sessionID = preferredSessionID[0]
|
||||
}
|
||||
|
||||
bodyWithSessionID, err := antigravity.EnsureGeminiRequestSessionID(originalBody, sessionID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("补全 sessionId 失败: %w", err)
|
||||
}
|
||||
|
||||
var request any
|
||||
if err := json.Unmarshal(originalBody, &request); err != nil {
|
||||
if err := json.Unmarshal(bodyWithSessionID, &request); err != nil {
|
||||
return nil, fmt.Errorf("解析请求体失败: %w", err)
|
||||
}
|
||||
|
||||
@ -2156,7 +2307,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
}
|
||||
|
||||
// 包装请求
|
||||
wrappedBody, err := s.wrapV1InternalRequest(projectID, mappedModel, injectedBody)
|
||||
wrappedBody, err := s.wrapV1InternalRequest(projectID, mappedModel, injectedBody, sessionID)
|
||||
if err != nil {
|
||||
return nil, s.writeGoogleError(c, http.StatusInternalServerError, "Failed to build upstream request")
|
||||
}
|
||||
@ -2220,10 +2371,11 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
if fallbackModel != "" && fallbackModel != mappedModel {
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "[Antigravity] Model not found (%s), retrying with fallback model %s (account: %s)", mappedModel, fallbackModel, account.Name)
|
||||
|
||||
fallbackWrapped, err := s.wrapV1InternalRequest(projectID, fallbackModel, injectedBody)
|
||||
fallbackWrapped, err := s.wrapV1InternalRequest(projectID, fallbackModel, injectedBody, sessionID)
|
||||
if err == nil {
|
||||
fallbackReq, err := antigravity.NewAPIRequest(ctx, upstreamAction, accessToken, fallbackWrapped)
|
||||
if err == nil {
|
||||
injectLSPoolHeaders(fallbackReq, account)
|
||||
fallbackResp, err := s.httpUpstream.Do(fallbackReq, proxyURL, account.ID, account.Concurrency)
|
||||
if err == nil && fallbackResp.StatusCode < 400 {
|
||||
_ = resp.Body.Close()
|
||||
@ -2263,7 +2415,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "Antigravity Gemini account %d: detected signature-related 400, retrying with cleaned thought signatures", account.ID)
|
||||
|
||||
cleanedInjectedBody := CleanGeminiNativeThoughtSignatures(injectedBody)
|
||||
retryWrappedBody, wrapErr := s.wrapV1InternalRequest(projectID, mappedModel, cleanedInjectedBody)
|
||||
retryWrappedBody, wrapErr := s.wrapV1InternalRequest(projectID, mappedModel, cleanedInjectedBody, sessionID)
|
||||
if wrapErr == nil {
|
||||
retryResult, retryErr := s.antigravityRetryLoop(antigravityRetryLoopParams{
|
||||
ctx: ctx,
|
||||
|
||||
@ -600,6 +600,63 @@ func TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel(t *testing
|
||||
require.Equal(t, mappedModel, result.UpstreamModel)
|
||||
}
|
||||
|
||||
func TestAntigravityGatewayService_ForwardGemini_InjectsSessionIDIntoWrappedRequest(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
writer := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(writer)
|
||||
|
||||
body, err := json.Marshal(map[string]any{
|
||||
"contents": []map[string]any{
|
||||
{"role": "user", "parts": []map[string]any{{"text": "hello"}}},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body))
|
||||
req.Header.Set("session_id", "session-header-1")
|
||||
c.Request = req
|
||||
|
||||
upstreamBody := []byte("data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"}]},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":8,\"candidatesTokenCount\":3}}}\n\n")
|
||||
upstream := &queuedHTTPUpstreamStub{
|
||||
responses: []*http.Response{
|
||||
{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"X-Request-Id": []string{"req-session-1"}},
|
||||
Body: io.NopCloser(bytes.NewReader(upstreamBody)),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{
|
||||
settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}),
|
||||
tokenProvider: &AntigravityTokenProvider{},
|
||||
httpUpstream: upstream,
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 16,
|
||||
Name: "acc-gemini-session",
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "token",
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.ForwardGemini(context.Background(), c, account, "gemini-2.5-flash", "generateContent", false, body, false)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Len(t, upstream.requestBodies, 1)
|
||||
|
||||
var wrapped map[string]any
|
||||
require.NoError(t, json.Unmarshal(upstream.requestBodies[0], &wrapped))
|
||||
requestNode, ok := wrapped["request"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "session-header-1", requestNode["sessionId"])
|
||||
}
|
||||
|
||||
func TestAntigravityGatewayService_ForwardGemini_RetriesCorruptedThoughtSignature(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
writer := httptest.NewRecorder()
|
||||
|
||||
@ -2500,6 +2500,57 @@
|
||||
"supports_vision": true,
|
||||
"supports_web_search": true
|
||||
},
|
||||
"gemini-3-flash": {
|
||||
"cache_read_input_token_cost": 5e-08,
|
||||
"cache_read_input_token_cost_priority": 9e-08,
|
||||
"input_cost_per_audio_token": 1e-06,
|
||||
"input_cost_per_audio_token_priority": 1.8e-06,
|
||||
"input_cost_per_token": 5e-07,
|
||||
"input_cost_per_token_priority": 9e-07,
|
||||
"litellm_provider": "vertex_ai-language-models",
|
||||
"max_audio_length_hours": 8.4,
|
||||
"max_audio_per_prompt": 1,
|
||||
"max_images_per_prompt": 3000,
|
||||
"max_input_tokens": 1048576,
|
||||
"max_output_tokens": 65535,
|
||||
"max_pdf_size_mb": 30,
|
||||
"max_tokens": 65535,
|
||||
"max_video_length": 1,
|
||||
"max_videos_per_prompt": 10,
|
||||
"mode": "chat",
|
||||
"output_cost_per_reasoning_token": 3e-06,
|
||||
"output_cost_per_token": 3e-06,
|
||||
"output_cost_per_token_priority": 5.4e-06,
|
||||
"source": "https://ai.google.dev/pricing/gemini-3",
|
||||
"supported_endpoints": [
|
||||
"/v1/chat/completions",
|
||||
"/v1/completions",
|
||||
"/v1/batch"
|
||||
],
|
||||
"supported_modalities": [
|
||||
"text",
|
||||
"image",
|
||||
"audio",
|
||||
"video"
|
||||
],
|
||||
"supported_output_modalities": [
|
||||
"text"
|
||||
],
|
||||
"supports_audio_output": false,
|
||||
"supports_function_calling": true,
|
||||
"supports_native_streaming": true,
|
||||
"supports_parallel_function_calling": true,
|
||||
"supports_pdf_input": true,
|
||||
"supports_prompt_caching": true,
|
||||
"supports_reasoning": true,
|
||||
"supports_response_schema": true,
|
||||
"supports_service_tier": true,
|
||||
"supports_system_messages": true,
|
||||
"supports_tool_choice": true,
|
||||
"supports_url_context": true,
|
||||
"supports_vision": true,
|
||||
"supports_web_search": true
|
||||
},
|
||||
"gemini-3-flash-preview": {
|
||||
"cache_read_input_token_cost": 5e-08,
|
||||
"cache_read_input_token_cost_priority": 9e-08,
|
||||
|
||||
@ -20,6 +20,9 @@ SERVER_PORT=8080
|
||||
# Server mode: release or debug
|
||||
SERVER_MODE=release
|
||||
|
||||
# Main application image override
|
||||
SUB2API_IMAGE=zfc931912343/sub2api:latest
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging Configuration
|
||||
# 日志配置
|
||||
@ -389,3 +392,32 @@ OPS_ENABLED=true
|
||||
# Leave empty for direct connection (recommended for overseas servers)
|
||||
# 留空表示直连(适用于海外服务器)
|
||||
UPDATE_PROXY_URL=
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Language Server Pool Mode (Enhanced Security)
|
||||
# -----------------------------------------------------------------------------
|
||||
# Enable to route requests through real AntiGravity LS binary
|
||||
# Makes upstream traffic indistinguishable from real IDE
|
||||
# ANTIGRAVITY_LS_MODE=true
|
||||
# LS replicas per account. Default is 5.
|
||||
# Increase for higher concurrency, but each replica is an extra LS process.
|
||||
# ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT=5
|
||||
# Optional global fallback proxy for accounts without dedicated LS proxy.
|
||||
# Must be socks5/socks5h in worker mode.
|
||||
ANTIGRAVITY_LS_PROXY=
|
||||
# LS routing strategy (default js-parity)
|
||||
ANTIGRAVITY_LS_STRATEGY=js-parity
|
||||
# Dynamic LS worker container image. Build/pull this image before enabling LS mode.
|
||||
GATEWAY_ANTIGRAVITY_LS_WORKER_IMAGE=zfc931912343/sub2api-lsworker:latest
|
||||
# Docker network name shared by sub2api and dynamic ls-worker containers.
|
||||
GATEWAY_ANTIGRAVITY_LS_WORKER_NETWORK=sub2api-network
|
||||
# Docker socket used by sub2api to create dynamic ls-worker containers.
|
||||
GATEWAY_ANTIGRAVITY_LS_WORKER_DOCKER_SOCKET=unix:///var/run/docker.sock
|
||||
# Idle TTL before worker container is reaped.
|
||||
GATEWAY_ANTIGRAVITY_LS_WORKER_IDLE_TTL=15m
|
||||
# Maximum number of active worker containers on this node.
|
||||
GATEWAY_ANTIGRAVITY_LS_WORKER_MAX_ACTIVE=50
|
||||
# Maximum time allowed for worker cold start and readiness.
|
||||
GATEWAY_ANTIGRAVITY_LS_WORKER_STARTUP_TIMEOUT=45s
|
||||
# Per-request timeout when sub2api talks to worker control API.
|
||||
GATEWAY_ANTIGRAVITY_LS_WORKER_REQUEST_TIMEOUT=60s
|
||||
|
||||
@ -65,6 +65,20 @@ docker compose -f docker-compose.local.yml logs sub2api | grep "admin password"
|
||||
# http://localhost:8080
|
||||
```
|
||||
|
||||
### LS Worker Image
|
||||
|
||||
When `ANTIGRAVITY_LS_MODE=true`, Sub2API creates dynamic `ls-worker`
|
||||
containers through the Docker socket. Build or pull the worker image before
|
||||
enabling LS mode:
|
||||
|
||||
```bash
|
||||
cd /path/to/sub2api
|
||||
docker build -f deploy/lsworker.Dockerfile -t weishaw/sub2api-lsworker:latest .
|
||||
```
|
||||
|
||||
The `sub2api` container must also be able to access `/var/run/docker.sock`,
|
||||
and the shared Docker network name must remain fixed at `sub2api-network`.
|
||||
|
||||
### Method 2: Manual Deployment
|
||||
|
||||
If you prefer manual control:
|
||||
|
||||
@ -283,6 +283,30 @@ gateway:
|
||||
queue: 0.7
|
||||
error_rate: 0.8
|
||||
ttft: 0.5
|
||||
# Antigravity LS worker container configuration
|
||||
# Antigravity LS worker 容器控制平面配置
|
||||
antigravity_ls_worker:
|
||||
# Worker image used by sub2api to create dynamic LS containers
|
||||
# sub2api 用于创建动态 LS worker 的镜像
|
||||
image: "weishaw/sub2api-lsworker:latest"
|
||||
# Docker network name shared by sub2api and workers
|
||||
# sub2api 与 worker 共享的 Docker network 名称
|
||||
network: "sub2api-network"
|
||||
# Docker socket path or host used by sub2api control plane
|
||||
# sub2api 控制面访问的 Docker socket / host
|
||||
docker_socket: "unix:///var/run/docker.sock"
|
||||
# Idle TTL before a worker container is recycled
|
||||
# worker 容器空闲回收时间
|
||||
idle_ttl: 15m
|
||||
# Max active worker containers per node
|
||||
# 单节点最大 worker 容器数量
|
||||
max_active: 50
|
||||
# Worker cold-start timeout
|
||||
# worker 冷启动超时
|
||||
startup_timeout: 45s
|
||||
# Timeout for control-plane calls from sub2api to worker
|
||||
# sub2api 调用 worker 控制接口的超时
|
||||
request_timeout: 60s
|
||||
# HTTP upstream connection pool settings (HTTP/2 + multi-proxy scenario defaults)
|
||||
# HTTP 上游连接池配置(HTTP/2 + 多代理场景默认值)
|
||||
# Max idle connections across all hosts
|
||||
|
||||
@ -36,6 +36,7 @@ services:
|
||||
volumes:
|
||||
# Local directory mapping for easy migration
|
||||
- ./data:/app/data
|
||||
- /var/run/docker.sock:/var/run/docker.sock
|
||||
# Optional: Mount custom config.yaml (uncomment and create the file first)
|
||||
# Copy config.example.yaml to config.yaml, modify it, then uncomment:
|
||||
# - ./config.yaml:/app/data/config.yaml
|
||||
@ -128,6 +129,22 @@ services:
|
||||
- GEMINI_CLI_OAUTH_CLIENT_SECRET=${GEMINI_CLI_OAUTH_CLIENT_SECRET:-}
|
||||
- ANTIGRAVITY_OAUTH_CLIENT_SECRET=${ANTIGRAVITY_OAUTH_CLIENT_SECRET:-}
|
||||
|
||||
# =======================================================================
|
||||
# Language Server Worker Mode
|
||||
# =======================================================================
|
||||
- ANTIGRAVITY_LS_MODE=${ANTIGRAVITY_LS_MODE:-false}
|
||||
- ANTIGRAVITY_APP_ROOT=/app/ls
|
||||
- ANTIGRAVITY_LS_PROXY=${ANTIGRAVITY_LS_PROXY:-}
|
||||
- ANTIGRAVITY_LS_STRATEGY=${ANTIGRAVITY_LS_STRATEGY:-js-parity}
|
||||
- ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT=${ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT:-5}
|
||||
- GATEWAY_ANTIGRAVITY_LS_WORKER_IMAGE=${GATEWAY_ANTIGRAVITY_LS_WORKER_IMAGE:-weishaw/sub2api-lsworker:latest}
|
||||
- GATEWAY_ANTIGRAVITY_LS_WORKER_NETWORK=${GATEWAY_ANTIGRAVITY_LS_WORKER_NETWORK:-sub2api-network}
|
||||
- GATEWAY_ANTIGRAVITY_LS_WORKER_DOCKER_SOCKET=${GATEWAY_ANTIGRAVITY_LS_WORKER_DOCKER_SOCKET:-unix:///var/run/docker.sock}
|
||||
- GATEWAY_ANTIGRAVITY_LS_WORKER_IDLE_TTL=${GATEWAY_ANTIGRAVITY_LS_WORKER_IDLE_TTL:-15m}
|
||||
- GATEWAY_ANTIGRAVITY_LS_WORKER_MAX_ACTIVE=${GATEWAY_ANTIGRAVITY_LS_WORKER_MAX_ACTIVE:-50}
|
||||
- GATEWAY_ANTIGRAVITY_LS_WORKER_STARTUP_TIMEOUT=${GATEWAY_ANTIGRAVITY_LS_WORKER_STARTUP_TIMEOUT:-45s}
|
||||
- GATEWAY_ANTIGRAVITY_LS_WORKER_REQUEST_TIMEOUT=${GATEWAY_ANTIGRAVITY_LS_WORKER_REQUEST_TIMEOUT:-60s}
|
||||
|
||||
# =======================================================================
|
||||
# Security Configuration (URL Allowlist)
|
||||
# =======================================================================
|
||||
@ -230,4 +247,5 @@ services:
|
||||
# =============================================================================
|
||||
networks:
|
||||
sub2api-network:
|
||||
name: sub2api-network
|
||||
driver: bridge
|
||||
|
||||
@ -16,7 +16,8 @@ services:
|
||||
# Sub2API Application
|
||||
# ===========================================================================
|
||||
sub2api:
|
||||
image: weishaw/sub2api:latest
|
||||
# Override with SUB2API_IMAGE to use a private registry or pinned tag.
|
||||
image: ${SUB2API_IMAGE:-weishaw/sub2api:latest}
|
||||
container_name: sub2api
|
||||
restart: unless-stopped
|
||||
ulimits:
|
||||
@ -28,6 +29,7 @@ services:
|
||||
volumes:
|
||||
# Data persistence (config.yaml will be auto-generated here)
|
||||
- sub2api_data:/app/data
|
||||
- /var/run/docker.sock:/var/run/docker.sock
|
||||
# Optional: Mount custom config.yaml (uncomment and create the file first)
|
||||
# Copy config.example.yaml to config.yaml, modify it, then uncomment:
|
||||
# - ./config.yaml:/app/data/config.yaml
|
||||
@ -120,6 +122,26 @@ services:
|
||||
- GEMINI_CLI_OAUTH_CLIENT_SECRET=${GEMINI_CLI_OAUTH_CLIENT_SECRET:-}
|
||||
- ANTIGRAVITY_OAUTH_CLIENT_SECRET=${ANTIGRAVITY_OAUTH_CLIENT_SECRET:-}
|
||||
|
||||
# =======================================================================
|
||||
# Language Server Pool Mode (Enhanced Security)
|
||||
# =======================================================================
|
||||
# Enable to route requests through real LS binary (Google's own code)
|
||||
# This makes upstream traffic indistinguishable from real IDE
|
||||
- ANTIGRAVITY_LS_MODE=${ANTIGRAVITY_LS_MODE:-false}
|
||||
- ANTIGRAVITY_APP_ROOT=/app/ls
|
||||
# SOCKS5/HTTP proxy fallback used when account has no dedicated LS proxy
|
||||
- ANTIGRAVITY_LS_PROXY=${ANTIGRAVITY_LS_PROXY:-}
|
||||
- ANTIGRAVITY_LS_STRATEGY=${ANTIGRAVITY_LS_STRATEGY:-js-parity}
|
||||
- ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT=${ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT:-5}
|
||||
# Keep the worker image aligned with the main image release when overriding.
|
||||
- GATEWAY_ANTIGRAVITY_LS_WORKER_IMAGE=${GATEWAY_ANTIGRAVITY_LS_WORKER_IMAGE:-weishaw/sub2api-lsworker:latest}
|
||||
- GATEWAY_ANTIGRAVITY_LS_WORKER_NETWORK=${GATEWAY_ANTIGRAVITY_LS_WORKER_NETWORK:-sub2api-network}
|
||||
- GATEWAY_ANTIGRAVITY_LS_WORKER_DOCKER_SOCKET=${GATEWAY_ANTIGRAVITY_LS_WORKER_DOCKER_SOCKET:-unix:///var/run/docker.sock}
|
||||
- GATEWAY_ANTIGRAVITY_LS_WORKER_IDLE_TTL=${GATEWAY_ANTIGRAVITY_LS_WORKER_IDLE_TTL:-15m}
|
||||
- GATEWAY_ANTIGRAVITY_LS_WORKER_MAX_ACTIVE=${GATEWAY_ANTIGRAVITY_LS_WORKER_MAX_ACTIVE:-50}
|
||||
- GATEWAY_ANTIGRAVITY_LS_WORKER_STARTUP_TIMEOUT=${GATEWAY_ANTIGRAVITY_LS_WORKER_STARTUP_TIMEOUT:-45s}
|
||||
- GATEWAY_ANTIGRAVITY_LS_WORKER_REQUEST_TIMEOUT=${GATEWAY_ANTIGRAVITY_LS_WORKER_REQUEST_TIMEOUT:-60s}
|
||||
|
||||
# =======================================================================
|
||||
# Security Configuration (URL Allowlist)
|
||||
# =======================================================================
|
||||
@ -234,4 +256,5 @@ volumes:
|
||||
# =============================================================================
|
||||
networks:
|
||||
sub2api-network:
|
||||
name: sub2api-network
|
||||
driver: bridge
|
||||
|
||||
@ -8,9 +8,27 @@ if [ "$(id -u)" = "0" ]; then
|
||||
mkdir -p /app/data
|
||||
# Use || true to avoid failure on read-only mounted files (e.g. config.yaml:ro)
|
||||
chown -R sub2api:sub2api /app/data 2>/dev/null || true
|
||||
if [ -S /var/run/docker.sock ]; then
|
||||
DOCKER_GID="$(stat -c '%g' /var/run/docker.sock 2>/dev/null || true)"
|
||||
if [ -n "${DOCKER_GID}" ]; then
|
||||
DOCKER_GROUP="$(getent group "${DOCKER_GID}" | cut -d: -f1 || true)"
|
||||
if [ -z "${DOCKER_GROUP}" ]; then
|
||||
DOCKER_GROUP="dockersock"
|
||||
groupadd -for -g "${DOCKER_GID}" "${DOCKER_GROUP}" 2>/dev/null || true
|
||||
fi
|
||||
usermod -aG "${DOCKER_GROUP}" sub2api 2>/dev/null || true
|
||||
fi
|
||||
fi
|
||||
# Re-invoke this script as sub2api so the flag-detection below
|
||||
# also runs under the correct user.
|
||||
exec su-exec sub2api "$0" "$@"
|
||||
# Use gosu if available (Debian), fall back to su-exec (Alpine)
|
||||
if command -v gosu >/dev/null 2>&1; then
|
||||
exec gosu sub2api "$0" "$@"
|
||||
elif command -v su-exec >/dev/null 2>&1; then
|
||||
exec su-exec sub2api "$0" "$@"
|
||||
else
|
||||
exec su -s /bin/sh sub2api -c "exec $0 $*"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Compatibility: if the first arg looks like a flag (e.g. --help),
|
||||
|
||||
21
deploy/ls-bin/cert.pem
Normal file
21
deploy/ls-bin/cert.pem
Normal file
@ -0,0 +1,21 @@
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIDXTCCAkWgAwIBAgIUVoRddTlTFh3+shRe6g4kSLo2n0MwDQYJKoZIhvcNAQEL
|
||||
BQAwSTESMBAGA1UEAwwJbG9jYWxob3N0MRYwFAYDVQQKDA1FTkFCTEVTIEhUVFAy
|
||||
MRswGQYDVQQLDBJidW5kbGVkIG9uIHB1cnBvc2UwHhcNMjUwOTA0MjA1NTA0WhcN
|
||||
MjYwOTA0MjA1NTA0WjBJMRIwEAYDVQQDDAlsb2NhbGhvc3QxFjAUBgNVBAoMDUVO
|
||||
QUJMRVMgSFRUUDIxGzAZBgNVBAsMEmJ1bmRsZWQgb24gcHVycG9zZTCCASIwDQYJ
|
||||
KoZIhvcNAQEBBQADggEPADCCAQoCggEBAJVpU6IyIMgwB6CJHkOeEAgYtzvyH6fM
|
||||
lkZSbemTrD9RCWZ4Fati1/6vbbMyWsM2XNJQMhJo0JTEoLDddN1iV/xGJCO/3dgw
|
||||
4+wLqqEeck4R1pHygCkb40TycmyygSWsidkEUH0xp51nCapIdPr/WL6O+Gbpl6DA
|
||||
onerUmWIO39VG2SpV7x3iXZOSbIGMsOiNZBmGwBZcL8ZejBIDjwvNjnX/d2tejH5
|
||||
/Mo4KVEXl5jsqaNbDIkhSs5BXtCMhoi1dqt75M8FyuNZd50AGFSa9Lj6pHTpwepD
|
||||
k2x4h+czPcvscF7TQG31TK1VYFPUThDim+by0+LQKkpy/UGVWnbC4dsCAwEAAaM9
|
||||
MDswGgYDVR0RBBMwEYIJbG9jYWxob3N0hwR/AAABMB0GA1UdDgQWBBSonSKmHCVt
|
||||
yBoVH1xEb3vtCng80DANBgkqhkiG9w0BAQsFAAOCAQEAinBO/uYe8ExHeiskt2P/
|
||||
Oxkd5sHSY9deLVuyX/TFnUEfktMfYKM2Juy+MfH4vfrcEhYkYJJcm25UGrtiT0Jh
|
||||
bUooDkR53549Xzg/70HU/ls1eNIe0zYqmS12H5W4Q1LAWTVpePscB4dgOrps6xIk
|
||||
Q4nlF7dst93E3swAe81rgCEd7VZEZy5VQcE9K+CIZXaAUJwUAsAtJbrP+5JMe9pt
|
||||
q52Zq5ZVkBS+4xeaMrasN0iTgsS4Lxo2a0GFDIJ84V66oeX7a5SXfSNn7rMVIDai
|
||||
KNZ2Cf2xNXUwq25Z6tjpQCqwYn3SE8b/Yi6fFZmy5D8kmY7dMh8ghVOc7rD+Vsk6
|
||||
/Q==
|
||||
-----END CERTIFICATE-----
|
||||
BIN
deploy/ls-bin/language_server_linux_arm
Executable file
BIN
deploy/ls-bin/language_server_linux_arm
Executable file
Binary file not shown.
BIN
deploy/ls-bin/language_server_linux_x64
Executable file
BIN
deploy/ls-bin/language_server_linux_x64
Executable file
Binary file not shown.
70
deploy/lsworker-entrypoint.sh
Normal file
70
deploy/lsworker-entrypoint.sh
Normal file
@ -0,0 +1,70 @@
|
||||
#!/bin/sh
|
||||
set -eu
|
||||
|
||||
PROXY_HOST="${LSWORKER_PROXY_HOST:-}"
|
||||
PROXY_PORT="${LSWORKER_PROXY_PORT:-1080}"
|
||||
PROXY_USER="${LSWORKER_PROXY_USER:-}"
|
||||
PROXY_PASS="${LSWORKER_PROXY_PASS:-}"
|
||||
CONTROL_PORT="${LSWORKER_CONTROL_PORT:-18081}"
|
||||
REDSOCKS_PORT="${LSWORKER_REDSOCKS_PORT:-12345}"
|
||||
NETWORK_READY_FILE="${LSWORKER_NETWORK_READY_FILE:-/run/lsworker/network-ready}"
|
||||
|
||||
mkdir -p "$(dirname "${NETWORK_READY_FILE}")"
|
||||
|
||||
if [ -z "${PROXY_HOST}" ]; then
|
||||
echo "LSWORKER_PROXY_HOST is required" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
PROXY_IP="$(getent ahostsv4 "${PROXY_HOST}" | awk 'NR==1 {print $1}')"
|
||||
if [ -z "${PROXY_IP}" ]; then
|
||||
echo "failed to resolve proxy host: ${PROXY_HOST}" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
cat >/tmp/redsocks.conf <<EOF
|
||||
base {
|
||||
log_debug = off;
|
||||
log_info = on;
|
||||
daemon = off;
|
||||
redirector = iptables;
|
||||
}
|
||||
|
||||
redsocks {
|
||||
local_ip = 0.0.0.0;
|
||||
local_port = ${REDSOCKS_PORT};
|
||||
ip = ${PROXY_IP};
|
||||
port = ${PROXY_PORT};
|
||||
type = socks5;
|
||||
EOF
|
||||
|
||||
if [ -n "${PROXY_USER}" ]; then
|
||||
printf ' login = "%s";\n' "${PROXY_USER}" >>/tmp/redsocks.conf
|
||||
fi
|
||||
if [ -n "${PROXY_PASS}" ]; then
|
||||
printf ' password = "%s";\n' "${PROXY_PASS}" >>/tmp/redsocks.conf
|
||||
fi
|
||||
|
||||
cat >>/tmp/redsocks.conf <<EOF
|
||||
}
|
||||
EOF
|
||||
|
||||
redsocks -c /tmp/redsocks.conf >/tmp/redsocks.log 2>&1 &
|
||||
REDSOCKS_PID="$!"
|
||||
trap 'kill "${REDSOCKS_PID}" >/dev/null 2>&1 || true' EXIT
|
||||
|
||||
sleep 1
|
||||
|
||||
iptables -t nat -N REDSOCKS 2>/dev/null || true
|
||||
iptables -t nat -F REDSOCKS
|
||||
iptables -t nat -A REDSOCKS -d 127.0.0.0/8 -j RETURN
|
||||
iptables -t nat -A REDSOCKS -d 127.0.0.11/32 -j RETURN
|
||||
iptables -t nat -A REDSOCKS -d "${PROXY_IP}/32" -j RETURN
|
||||
iptables -t nat -A REDSOCKS -p tcp --dport "${CONTROL_PORT}" -j RETURN
|
||||
iptables -t nat -A REDSOCKS -p tcp -j REDIRECT --to-ports "${REDSOCKS_PORT}"
|
||||
iptables -t nat -D OUTPUT -p tcp -j REDSOCKS 2>/dev/null || true
|
||||
iptables -t nat -A OUTPUT -p tcp -j REDSOCKS
|
||||
|
||||
touch "${NETWORK_READY_FILE}"
|
||||
|
||||
exec gosu sub2api /app/lsworker
|
||||
52
deploy/lsworker.Dockerfile
Normal file
52
deploy/lsworker.Dockerfile
Normal file
@ -0,0 +1,52 @@
|
||||
ARG GOLANG_IMAGE=golang:1.26.1-alpine
|
||||
ARG DEBIAN_IMAGE=debian:bookworm-slim
|
||||
|
||||
FROM ${GOLANG_IMAGE} AS builder
|
||||
|
||||
WORKDIR /app/backend
|
||||
RUN apk add --no-cache git ca-certificates tzdata
|
||||
|
||||
COPY backend/go.mod backend/go.sum ./
|
||||
RUN go mod download
|
||||
|
||||
COPY backend/ ./
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build -trimpath -ldflags="-s -w" -o /app/lsworker ./cmd/lsworker
|
||||
|
||||
FROM ${DEBIAN_IMAGE}
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
ca-certificates \
|
||||
curl \
|
||||
gosu \
|
||||
iproute2 \
|
||||
iptables \
|
||||
redsocks \
|
||||
tzdata \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN groupadd -g 1000 sub2api && \
|
||||
useradd -u 1000 -g sub2api -m -s /bin/sh sub2api
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY --from=builder /app/lsworker /app/lsworker
|
||||
COPY deploy/ls-bin/language_server_linux_* /tmp/ls-bin/
|
||||
COPY deploy/ls-bin/cert.pem /app/ls/extensions/antigravity/dist/languageServer/
|
||||
|
||||
ARG TARGETARCH
|
||||
RUN mkdir -p /app/ls/extensions/antigravity/bin /run/lsworker && \
|
||||
if [ "${TARGETARCH:-amd64}" = "arm64" ]; then \
|
||||
cp /tmp/ls-bin/language_server_linux_arm /app/ls/extensions/antigravity/bin/language_server_linux_arm; \
|
||||
else \
|
||||
cp /tmp/ls-bin/language_server_linux_x64 /app/ls/extensions/antigravity/bin/language_server_linux_x64; \
|
||||
fi && \
|
||||
chmod +x /app/lsworker /app/ls/extensions/antigravity/bin/language_server_linux_* && \
|
||||
chown -R sub2api:sub2api /app /run/lsworker && \
|
||||
rm -rf /tmp/ls-bin
|
||||
|
||||
COPY deploy/lsworker-entrypoint.sh /app/lsworker-entrypoint.sh
|
||||
RUN chmod +x /app/lsworker-entrypoint.sh
|
||||
|
||||
EXPOSE 18081
|
||||
|
||||
ENTRYPOINT ["/app/lsworker-entrypoint.sh"]
|
||||
@ -510,6 +510,7 @@ const handleEvent = (event: {
|
||||
addLine(streamingContent.value, 'text-green-300')
|
||||
streamingContent.value = ''
|
||||
}
|
||||
addLine(`Error: ${errorMessage.value}`, 'text-red-400')
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
@ -510,6 +510,7 @@ const handleEvent = (event: {
|
||||
addLine(streamingContent.value, 'text-green-300')
|
||||
streamingContent.value = ''
|
||||
}
|
||||
addLine(`Error: ${errorMessage.value}`, 'text-red-400')
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
@ -144,4 +144,28 @@ describe('AccountTestModal', () => {
|
||||
expect(preview.exists()).toBe(true)
|
||||
expect(preview.attributes('src')).toBe('data:image/png;base64,QUJD')
|
||||
})
|
||||
|
||||
it('收到 error 事件时会把错误内容显示在终端输出里', async () => {
|
||||
;(global.fetch as any).mockResolvedValueOnce(
|
||||
createStreamResponse([
|
||||
'data: {"type":"test_start","model":"claude-opus-4-6"}\n',
|
||||
'data: {"type":"error","error":"API returned 429: You have exhausted your capacity on this model."}\n'
|
||||
])
|
||||
)
|
||||
|
||||
const wrapper = mountModal()
|
||||
await wrapper.setProps({ show: true })
|
||||
await flushPromises()
|
||||
|
||||
const buttons = wrapper.findAll('button')
|
||||
const startButton = buttons.find((button) => button.text().includes('admin.accounts.startTest'))
|
||||
expect(startButton).toBeTruthy()
|
||||
|
||||
await startButton!.trigger('click')
|
||||
await flushPromises()
|
||||
await flushPromises()
|
||||
|
||||
expect(wrapper.text()).toContain('API returned 429: You have exhausted your capacity on this model.')
|
||||
expect(wrapper.text()).toContain('Error: API returned 429: You have exhausted your capacity on this model.')
|
||||
})
|
||||
})
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user