chore: remove LS pool implementation
Removing all LS (Language Server Pool) related code: - backend/cmd/lsworker/ - backend/internal/pkg/lspool/ - backend/internal/service/lspool_bootstrap_service.* - deploy/ls-bin/ - deploy/lsworker.Dockerfile - deploy/lsworker-entrypoint.sh Keeping: - Claude custom fingerprint (immutable) - Antigravity OAuth and telemetry improvements - TLS fingerprint SOCKS5 Docker DNS fix - Gemini OAuth security improvements Co-Authored-By: Claude Haiku 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
3ba3a17652
commit
a3f2d4577e
@ -7,7 +7,7 @@
|
||||
# =============================================================================
|
||||
|
||||
ARG NODE_IMAGE=node:24-alpine
|
||||
ARG GOLANG_IMAGE=golang:1.26.1-alpine
|
||||
ARG GOLANG_IMAGE=golang:1.25-alpine
|
||||
ARG ALPINE_IMAGE=alpine:3.21
|
||||
ARG DEBIAN_IMAGE=debian:bookworm-slim
|
||||
ARG POSTGRES_IMAGE=postgres:18-alpine
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
FROM golang:1.25.7-alpine
|
||||
FROM golang:1.25-alpine
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
|
||||
@ -1,49 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/lspool"
|
||||
)
|
||||
|
||||
func main() {
|
||||
server, err := lspool.NewWorkerServerFromEnv()
|
||||
if err != nil {
|
||||
slog.Error("failed to initialize lsworker", "err", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer server.Close()
|
||||
|
||||
httpServer := &http.Server{
|
||||
Addr: envOrDefault("LSWORKER_LISTEN_ADDR", "0.0.0.0:18081"),
|
||||
Handler: server.Handler(),
|
||||
ReadHeaderTimeout: 10 * 1e9,
|
||||
}
|
||||
|
||||
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
|
||||
defer stop()
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
_ = httpServer.Shutdown(context.Background())
|
||||
}()
|
||||
|
||||
slog.Info("lsworker listening", "addr", httpServer.Addr)
|
||||
if err := httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
slog.Error("lsworker exited with error", "err", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func envOrDefault(key, fallback string) string {
|
||||
if value := os.Getenv(key); value != "" {
|
||||
return value
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
@ -53,9 +53,8 @@ const (
|
||||
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.107.0
|
||||
var defaultUserAgentVersion = "1.107.0"
|
||||
|
||||
|
||||
// defaultClientSecret 必须通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置
|
||||
var defaultClientSecret string
|
||||
// defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 覆盖
|
||||
var defaultClientSecret = "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"
|
||||
|
||||
func init() {
|
||||
// 从环境变量读取版本号,未设置则使用默认值
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
// Package claude provides constants and helpers for Claude API integration.
|
||||
package claude
|
||||
|
||||
import "strings"
|
||||
|
||||
// Claude Code 客户端相关常量
|
||||
|
||||
// DefaultCLIVersion 是当前模拟的 Claude CLI 版本
|
||||
@ -30,32 +32,64 @@ const (
|
||||
// 这些 token 是客户端特有的,不应透传给上游 API。
|
||||
var DroppedBetas = []string{}
|
||||
|
||||
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
|
||||
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming + "," + BetaContext1M + "," + BetaRedactThinking + "," + BetaContextManagement + "," + BetaPromptCachingScope + "," + BetaEffort
|
||||
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header(OAuth 账号,不含 context-1m)
|
||||
// 使用 GetOAuthBetaHeader(modelID) 获取含 context-1m 的 model-aware 版本。
|
||||
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming + "," + BetaRedactThinking + "," + BetaContextManagement + "," + BetaPromptCachingScope + "," + BetaEffort
|
||||
|
||||
// MessageBetaHeaderNoTools /v1/messages 在无工具时的 beta header
|
||||
// MessageBetaHeaderNoTools /v1/messages 在无工具时的 beta header(OAuth,不含 context-1m)
|
||||
//
|
||||
// NOTE: Claude Code OAuth credentials are scoped to Claude Code. When we "mimic"
|
||||
// Claude Code for non-Claude-Code clients, we must include the claude-code beta
|
||||
// even if the request doesn't use tools, otherwise upstream may reject the
|
||||
// request as a non-Claude-Code API request.
|
||||
const MessageBetaHeaderNoTools = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaContext1M + "," + BetaRedactThinking + "," + BetaContextManagement + "," + BetaPromptCachingScope + "," + BetaEffort
|
||||
const MessageBetaHeaderNoTools = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaRedactThinking + "," + BetaContextManagement + "," + BetaPromptCachingScope + "," + BetaEffort
|
||||
|
||||
// MessageBetaHeaderWithTools /v1/messages 在有工具时的 beta header
|
||||
const MessageBetaHeaderWithTools = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaContext1M + "," + BetaRedactThinking + "," + BetaContextManagement + "," + BetaPromptCachingScope + "," + BetaEffort
|
||||
// MessageBetaHeaderWithTools /v1/messages 在有工具时的 beta header(OAuth,不含 context-1m)
|
||||
const MessageBetaHeaderWithTools = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaRedactThinking + "," + BetaContextManagement + "," + BetaPromptCachingScope + "," + BetaEffort
|
||||
|
||||
// CountTokensBetaHeader count_tokens 请求使用的 anthropic-beta header
|
||||
const CountTokensBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaTokenCounting + "," + BetaContextManagement
|
||||
|
||||
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header(不需要 claude-code beta)
|
||||
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header(OAuth,不含 claude-code / context-1m)
|
||||
const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking + "," + BetaEffort
|
||||
|
||||
// APIKeyBetaHeader API-key 账号建议使用的 anthropic-beta header(不包含 oauth)
|
||||
const APIKeyBetaHeader = BetaClaudeCode + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming + "," + BetaContext1M + "," + BetaEffort + "," + BetaPromptCachingScope
|
||||
// APIKeyBetaHeader API-key 账号使用的 anthropic-beta header(不含 oauth / context-1m)
|
||||
// 使用 GetAPIKeyBetaHeader(modelID) 获取含 context-1m 的 model-aware 版本。
|
||||
const APIKeyBetaHeader = BetaClaudeCode + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming + "," + BetaEffort + "," + BetaPromptCachingScope
|
||||
|
||||
// APIKeyHaikuBetaHeader Haiku 模型在 API-key 账号下使用的 anthropic-beta header(不包含 oauth / claude-code)
|
||||
// APIKeyHaikuBetaHeader Haiku 模型在 API-key 账号下使用的 anthropic-beta header(不含 oauth / claude-code)
|
||||
const APIKeyHaikuBetaHeader = BetaInterleavedThinking + "," + BetaEffort
|
||||
|
||||
// ModelSupports1M 判断模型是否支持 1M context window。
|
||||
// 与 claude-code-2.1.88 bundle 中 modelSupports1M 逻辑保持一致:
|
||||
//
|
||||
// claude-sonnet-4 系列 和 claude-opus-4-6 支持 1M context。
|
||||
func ModelSupports1M(modelID string) bool {
|
||||
lower := strings.ToLower(strings.TrimSpace(modelID))
|
||||
return strings.Contains(lower, "claude-sonnet-4") || strings.Contains(lower, "opus-4-6")
|
||||
}
|
||||
|
||||
// GetOAuthBetaHeader 返回 OAuth 账号的 beta header。
|
||||
// 仅当模型支持 1M context 时才包含 context-1m-2025-08-07。
|
||||
func GetOAuthBetaHeader(modelID string) string {
|
||||
if ModelSupports1M(modelID) {
|
||||
return BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming + "," + BetaContext1M + "," + BetaRedactThinking + "," + BetaContextManagement + "," + BetaPromptCachingScope + "," + BetaEffort
|
||||
}
|
||||
return DefaultBetaHeader
|
||||
}
|
||||
|
||||
// GetAPIKeyBetaHeader 返回 API-key 账号的 beta header。
|
||||
// 仅当模型支持 1M context 时才包含 context-1m-2025-08-07。
|
||||
func GetAPIKeyBetaHeader(modelID string) string {
|
||||
if strings.Contains(strings.ToLower(modelID), "haiku") {
|
||||
return APIKeyHaikuBetaHeader
|
||||
}
|
||||
if ModelSupports1M(modelID) {
|
||||
return BetaClaudeCode + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming + "," + BetaContext1M + "," + BetaEffort + "," + BetaPromptCachingScope
|
||||
}
|
||||
return APIKeyBetaHeader
|
||||
}
|
||||
|
||||
// DefaultHeaders 是 Claude Code 客户端默认请求头。
|
||||
var DefaultHeaders = map[string]string{
|
||||
// Keep these in sync with recent Claude CLI traffic to reduce the chance
|
||||
@ -70,7 +104,7 @@ var DefaultHeaders = map[string]string{
|
||||
"X-Stainless-Retry-Count": "0",
|
||||
"X-Stainless-Timeout": "600",
|
||||
"X-App": "cli",
|
||||
"Anthropic-Dangerous-Direct-Browser-Access": "true",
|
||||
"anthropic-version": "2023-06-01",
|
||||
}
|
||||
|
||||
// ApplyFingerprintOverrides 用配置覆盖默认指纹值(每个实例可设不同值)
|
||||
|
||||
@ -1,13 +0,0 @@
|
||||
package lspool
|
||||
|
||||
import "time"
|
||||
|
||||
// Backend is the control-plane abstraction used by the HTTP upstream wrapper.
|
||||
// It may be backed by a local in-process Pool or by remote LS workers.
|
||||
type Backend interface {
|
||||
GetOrCreate(accountID, routingKey string, proxyURL ...string) (*Instance, error)
|
||||
SetAccountToken(accountID, accessToken, refreshToken string, expiresAt time.Time)
|
||||
SetAccountModelCredits(accountID string, useAICredits bool, availableCredits, minimumCreditAmountForUsage *int32)
|
||||
Stats() map[string]any
|
||||
Close()
|
||||
}
|
||||
@ -1,94 +0,0 @@
|
||||
// Package lspool provides LS-mode integration for the antigravity gateway.
|
||||
//
|
||||
// When LS mode is enabled (via ANTIGRAVITY_LS_MODE=true), requests to
|
||||
// streamGenerateContent are routed through a real Language Server instance
|
||||
// instead of directly to cloudcode-pa. This provides:
|
||||
//
|
||||
// - Authentic TLS fingerprint (Google's own Go binary)
|
||||
// - Real session management and Heartbeat
|
||||
// - Indistinguishable from a real IDE instance
|
||||
//
|
||||
// To enable: set environment variable ANTIGRAVITY_LS_MODE=true
|
||||
// To configure: set ANTIGRAVITY_APP_ROOT to the AntiGravity.app path
|
||||
package lspool
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
)
|
||||
|
||||
var (
|
||||
globalBackend Backend
|
||||
globalPoolOnce sync.Once
|
||||
lsModeEnabled bool
|
||||
)
|
||||
|
||||
func init() {
|
||||
lsModeEnabled = os.Getenv("ANTIGRAVITY_LS_MODE") == "true"
|
||||
}
|
||||
|
||||
// IsLSModeEnabled returns whether LS mode is active
|
||||
func IsLSModeEnabled() bool {
|
||||
return lsModeEnabled
|
||||
}
|
||||
|
||||
const (
|
||||
LSStrategyDirect = "direct"
|
||||
LSStrategyJSParity = "js-parity"
|
||||
)
|
||||
|
||||
// CurrentLSStrategy returns the active LS routing strategy.
|
||||
// Unknown values are treated as "direct" for safety.
|
||||
func CurrentLSStrategy() string {
|
||||
switch strings.ToLower(strings.TrimSpace(os.Getenv("ANTIGRAVITY_LS_STRATEGY"))) {
|
||||
case "", LSStrategyDirect:
|
||||
return LSStrategyDirect
|
||||
case LSStrategyJSParity:
|
||||
return LSStrategyJSParity
|
||||
default:
|
||||
return LSStrategyDirect
|
||||
}
|
||||
}
|
||||
|
||||
// GlobalPool returns the singleton LS pool instance
|
||||
// Creates it on first call if LS mode is enabled
|
||||
func GlobalPool(cfg *config.Config) Backend {
|
||||
if !lsModeEnabled {
|
||||
return nil
|
||||
}
|
||||
globalPoolOnce.Do(func() {
|
||||
manager, err := NewWorkerManagerFromConfig(cfg)
|
||||
if err != nil {
|
||||
slog.Default().Error("failed to initialize LS worker manager", "err", err)
|
||||
return
|
||||
}
|
||||
globalBackend = manager
|
||||
})
|
||||
return globalBackend
|
||||
}
|
||||
|
||||
// Shutdown closes the global pool
|
||||
func Shutdown() {
|
||||
if globalBackend != nil {
|
||||
globalBackend.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// StatusInfo returns the current LS pool status for diagnostics
|
||||
func StatusInfo() map[string]any {
|
||||
info := map[string]any{
|
||||
"ls_mode_enabled": lsModeEnabled,
|
||||
"build": "enhanced",
|
||||
"user_agent": "antigravity/1.107.0",
|
||||
}
|
||||
if lsModeEnabled && globalBackend != nil {
|
||||
stats := globalBackend.Stats()
|
||||
info["pool_total"] = stats["total"]
|
||||
info["pool_active"] = stats["active"]
|
||||
}
|
||||
return info
|
||||
}
|
||||
@ -1,864 +0,0 @@
|
||||
package lspool
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func readConnectFrame(r io.Reader) ([]byte, error) {
|
||||
header := make([]byte, 5)
|
||||
if _, err := io.ReadFull(r, header); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payloadLen := binary.BigEndian.Uint32(header[1:5])
|
||||
payload := make([]byte, payloadLen)
|
||||
if _, err := io.ReadFull(r, payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
func decodeProtoBytesField(data []byte, targetField int) []byte {
|
||||
i := 0
|
||||
for i < len(data) {
|
||||
tag, n := binary.Uvarint(data[i:])
|
||||
if n <= 0 {
|
||||
return nil
|
||||
}
|
||||
i += n
|
||||
fieldNum := int(tag >> 3)
|
||||
wireType := tag & 0x7
|
||||
switch wireType {
|
||||
case 0:
|
||||
_, n = binary.Uvarint(data[i:])
|
||||
if n <= 0 {
|
||||
return nil
|
||||
}
|
||||
i += n
|
||||
case 2:
|
||||
length, n := binary.Uvarint(data[i:])
|
||||
if n <= 0 {
|
||||
return nil
|
||||
}
|
||||
i += n
|
||||
if i+int(length) > len(data) {
|
||||
return nil
|
||||
}
|
||||
if fieldNum == targetField {
|
||||
return data[i : i+int(length)]
|
||||
}
|
||||
i += int(length)
|
||||
case 1:
|
||||
i += 8
|
||||
case 5:
|
||||
i += 4
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeProtoBytesFields(data []byte, targetField int) [][]byte {
|
||||
var values [][]byte
|
||||
i := 0
|
||||
for i < len(data) {
|
||||
tag, n := binary.Uvarint(data[i:])
|
||||
if n <= 0 {
|
||||
return values
|
||||
}
|
||||
i += n
|
||||
fieldNum := int(tag >> 3)
|
||||
wireType := tag & 0x7
|
||||
switch wireType {
|
||||
case 0:
|
||||
_, n = binary.Uvarint(data[i:])
|
||||
if n <= 0 {
|
||||
return values
|
||||
}
|
||||
i += n
|
||||
case 2:
|
||||
length, n := binary.Uvarint(data[i:])
|
||||
if n <= 0 {
|
||||
return values
|
||||
}
|
||||
i += n
|
||||
if i+int(length) > len(data) {
|
||||
return values
|
||||
}
|
||||
if fieldNum == targetField {
|
||||
values = append(values, append([]byte(nil), data[i:i+int(length)]...))
|
||||
}
|
||||
i += int(length)
|
||||
case 1:
|
||||
i += 8
|
||||
case 5:
|
||||
i += 4
|
||||
default:
|
||||
return values
|
||||
}
|
||||
}
|
||||
return values
|
||||
}
|
||||
|
||||
func decodeTopicRows(topic []byte) map[string]string {
|
||||
rows := make(map[string]string)
|
||||
for _, entry := range decodeProtoBytesFields(topic, 1) {
|
||||
key := decodeProtoString(entry, 1)
|
||||
row := decodeProtoBytesField(entry, 2)
|
||||
rows[key] = decodeProtoString(row, 1)
|
||||
}
|
||||
return rows
|
||||
}
|
||||
|
||||
func requireBase64PrimitiveValue(t *testing.T, got string, want []byte) {
|
||||
t.Helper()
|
||||
decoded, err := base64.StdEncoding.DecodeString(got)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, want, decoded)
|
||||
}
|
||||
|
||||
// TestMockExtensionServerTokenInjection verifies the token injection flow:
|
||||
// Extension → MockExtensionServer → LS subscribes uss-oauth → gets OAuthTokenInfo
|
||||
func TestMockExtensionServerTokenInjection(t *testing.T) {
|
||||
csrf := "test-csrf-token"
|
||||
srv, err := NewMockExtensionServer(csrf)
|
||||
require.NoError(t, err)
|
||||
defer srv.Close()
|
||||
|
||||
// 1. Set token for an account
|
||||
srv.SetToken("account-1", &TokenInfo{
|
||||
AccessToken: "ya29.test-access-token",
|
||||
RefreshToken: "1//test-refresh-token",
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
})
|
||||
|
||||
// 2. Verify token is stored
|
||||
srv.mu.RLock()
|
||||
info, ok := srv.tokens["account-1"]
|
||||
srv.mu.RUnlock()
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "ya29.test-access-token", info.AccessToken)
|
||||
require.Equal(t, "1//test-refresh-token", info.RefreshToken)
|
||||
require.False(t, info.ExpiresAt.IsZero())
|
||||
|
||||
// 3. Simulate LS subscribing to uss-oauth (HTTP request to mock server)
|
||||
req, _ := http.NewRequest("POST",
|
||||
fmt.Sprintf("http://127.0.0.1:%d/exa.extension_server_pb.ExtensionServerService/SubscribeToUnifiedStateSyncTopic", srv.Port()),
|
||||
bytes.NewReader(frameConnectMessage(encodeProtoString(1, "uss-oauth"))))
|
||||
req.Header.Set("x-codeium-csrf-token", csrf)
|
||||
req.Header.Set("Content-Type", "application/connect+proto")
|
||||
|
||||
// The stream handler will block, so run in background and cancel after we confirm connection
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err == nil {
|
||||
defer resp.Body.Close()
|
||||
require.Equal(t, 200, resp.StatusCode)
|
||||
require.Equal(t, "application/connect+proto", resp.Header.Get("Content-Type"))
|
||||
|
||||
// Read the first envelope frame (initial state)
|
||||
header := make([]byte, 5)
|
||||
n, readErr := resp.Body.Read(header)
|
||||
if readErr == nil && n == 5 {
|
||||
require.Equal(t, byte(0x00), header[0], "first byte should be 0x00 (data frame)")
|
||||
t.Logf("Received initial state frame: flags=%d, payload_len=%d", header[0], header[1:5])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestMockExtensionServerCSRF verifies CSRF token validation
|
||||
func TestMockExtensionServerCSRF(t *testing.T) {
|
||||
csrf := "correct-csrf"
|
||||
srv, err := NewMockExtensionServer(csrf)
|
||||
require.NoError(t, err)
|
||||
defer srv.Close()
|
||||
|
||||
base := fmt.Sprintf("http://127.0.0.1:%d/exa.extension_server_pb.ExtensionServerService/Heartbeat", srv.Port())
|
||||
|
||||
// Wrong CSRF → 403
|
||||
req, _ := http.NewRequest("POST", base, nil)
|
||||
req.Header.Set("x-codeium-csrf-token", "wrong-csrf")
|
||||
req.Header.Set("Content-Type", "application/proto")
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
require.Equal(t, 403, resp.StatusCode)
|
||||
|
||||
// Correct CSRF → 200
|
||||
req2, _ := http.NewRequest("POST", base, nil)
|
||||
req2.Header.Set("x-codeium-csrf-token", csrf)
|
||||
req2.Header.Set("Content-Type", "application/proto")
|
||||
resp2, err := http.DefaultClient.Do(req2)
|
||||
require.NoError(t, err)
|
||||
defer resp2.Body.Close()
|
||||
require.Equal(t, 200, resp2.StatusCode)
|
||||
}
|
||||
|
||||
// TestMockExtensionServerGetSecretValue verifies the fallback token path
|
||||
func TestMockExtensionServerGetSecretValue(t *testing.T) {
|
||||
csrf := "test-csrf"
|
||||
srv, err := NewMockExtensionServer(csrf)
|
||||
require.NoError(t, err)
|
||||
defer srv.Close()
|
||||
|
||||
srv.SetToken("acc", &TokenInfo{AccessToken: "ya29.secret-token"})
|
||||
|
||||
// GetSecretValue should return the token
|
||||
req, _ := http.NewRequest("POST",
|
||||
fmt.Sprintf("http://127.0.0.1:%d/exa.extension_server_pb.ExtensionServerService/GetSecretValue", srv.Port()),
|
||||
nil)
|
||||
req.Header.Set("x-codeium-csrf-token", csrf)
|
||||
req.Header.Set("Content-Type", "application/proto")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
require.Equal(t, 200, resp.StatusCode)
|
||||
}
|
||||
|
||||
// TestOAuthTokenInfoProto verifies the proto encoding matches real IDE format
|
||||
func TestOAuthTokenInfoProto(t *testing.T) {
|
||||
expiry := time.Date(2026, 3, 29, 19, 0, 0, 0, time.UTC)
|
||||
bin := buildOAuthTokenInfoBinary("ya29.test", "1//refresh", expiry)
|
||||
|
||||
// Verify fields are present by checking proto wire format
|
||||
require.True(t, len(bin) > 0, "proto should not be empty")
|
||||
|
||||
// Field 1 (access_token): tag=0x0a, value="ya29.test"
|
||||
require.Contains(t, string(bin), "ya29.test")
|
||||
// Field 2 (token_type): tag=0x12, value="Bearer"
|
||||
require.Contains(t, string(bin), "Bearer")
|
||||
// Field 3 (refresh_token): tag=0x1a, value="1//refresh"
|
||||
require.Contains(t, string(bin), "1//refresh")
|
||||
|
||||
// Without refresh_token
|
||||
binNoRefresh := buildOAuthTokenInfoBinary("ya29.test", "", expiry)
|
||||
require.NotContains(t, string(binNoRefresh), "1//refresh")
|
||||
}
|
||||
|
||||
// TestOAuthTokenInfoWithRealExpiry verifies expiry uses real time, not hardcoded
|
||||
func TestOAuthTokenInfoWithRealExpiry(t *testing.T) {
|
||||
future := time.Now().Add(2 * time.Hour)
|
||||
bin := buildOAuthTokenInfoBinary("token", "refresh", future)
|
||||
|
||||
// Zero expiry should default to ~1h
|
||||
binZero := buildOAuthTokenInfoBinary("token", "refresh", time.Time{})
|
||||
|
||||
// They should be different lengths or content (different expiry timestamps)
|
||||
// Both should be valid (non-empty)
|
||||
require.True(t, len(bin) > 0)
|
||||
require.True(t, len(binZero) > 0)
|
||||
}
|
||||
|
||||
// TestUSSTopicWithOAuth verifies the full USS topic proto structure
|
||||
func TestUSSTopicWithOAuth(t *testing.T) {
|
||||
expiry := time.Now().Add(1 * time.Hour)
|
||||
topic := buildUSSTopicWithOAuth("ya29.access", "1//refresh", expiry)
|
||||
|
||||
require.True(t, len(topic) > 0)
|
||||
// The topic should contain the sentinel key
|
||||
require.Contains(t, string(topic), "oauthTokenInfoSentinelKey")
|
||||
}
|
||||
|
||||
func TestUSSTopicWithModelCredits(t *testing.T) {
|
||||
available := int32(123)
|
||||
minimum := int32(50)
|
||||
topic := buildUSSTopicWithModelCredits(&ModelCreditsInfo{
|
||||
UseAICredits: true,
|
||||
AvailableCredits: &available,
|
||||
MinimumCreditAmountForUsage: &minimum,
|
||||
})
|
||||
|
||||
require.True(t, len(topic) > 0)
|
||||
require.Contains(t, string(topic), useAICreditsSentinelKey)
|
||||
require.Contains(t, string(topic), availableCreditsSentinelKey)
|
||||
require.Contains(t, string(topic), minimumCreditAmountForUsageKey)
|
||||
|
||||
rows := decodeTopicRows(topic)
|
||||
requireBase64PrimitiveValue(t, rows[useAICreditsSentinelKey], buildPrimitiveBoolBinary(true))
|
||||
requireBase64PrimitiveValue(t, rows[availableCreditsSentinelKey], buildPrimitiveInt32Binary(available))
|
||||
requireBase64PrimitiveValue(t, rows[minimumCreditAmountForUsageKey], buildPrimitiveInt32Binary(minimum))
|
||||
}
|
||||
|
||||
func TestMockExtensionServerModelCreditsDynamicUpdate(t *testing.T) {
|
||||
csrf := "test-csrf-token"
|
||||
srv, err := NewMockExtensionServer(csrf)
|
||||
require.NoError(t, err)
|
||||
defer srv.Close()
|
||||
|
||||
srv.SetModelCredits("account-1", &ModelCreditsInfo{})
|
||||
|
||||
req, _ := http.NewRequest("POST",
|
||||
fmt.Sprintf("http://127.0.0.1:%d/exa.extension_server_pb.ExtensionServerService/SubscribeToUnifiedStateSyncTopic", srv.Port()),
|
||||
bytes.NewReader(frameConnectMessage(encodeProtoString(1, "uss-modelCredits"))))
|
||||
req.Header.Set("x-codeium-csrf-token", csrf)
|
||||
req.Header.Set("Content-Type", "application/connect+proto")
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
require.Equal(t, 200, resp.StatusCode)
|
||||
|
||||
// Drain the initial_state frame first.
|
||||
_, err = readConnectFrame(resp.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
available := int32(77)
|
||||
minimum := int32(25)
|
||||
srv.SetModelCredits("account-1", &ModelCreditsInfo{
|
||||
UseAICredits: true,
|
||||
AvailableCredits: &available,
|
||||
MinimumCreditAmountForUsage: &minimum,
|
||||
})
|
||||
|
||||
values := make(map[string]string, 3)
|
||||
for len(values) < 3 {
|
||||
frame, readErr := readConnectFrame(resp.Body)
|
||||
require.NoError(t, readErr)
|
||||
applied := decodeProtoBytesField(frame, 2)
|
||||
require.NotEmpty(t, applied)
|
||||
key := decodeProtoString(applied, 1)
|
||||
row := decodeProtoBytesField(applied, 2)
|
||||
values[key] = decodeProtoString(row, 1)
|
||||
}
|
||||
|
||||
require.Contains(t, values, useAICreditsSentinelKey)
|
||||
require.Contains(t, values, availableCreditsSentinelKey)
|
||||
require.Contains(t, values, minimumCreditAmountForUsageKey)
|
||||
requireBase64PrimitiveValue(t, values[useAICreditsSentinelKey], buildPrimitiveBoolBinary(true))
|
||||
requireBase64PrimitiveValue(t, values[availableCreditsSentinelKey], buildPrimitiveInt32Binary(available))
|
||||
requireBase64PrimitiveValue(t, values[minimumCreditAmountForUsageKey], buildPrimitiveInt32Binary(minimum))
|
||||
}
|
||||
|
||||
// TestBuildInitialStateUpdate verifies the USS update wrapper
|
||||
func TestBuildInitialStateUpdate(t *testing.T) {
|
||||
topicData := buildEmptyTopic()
|
||||
update := buildInitialStateUpdate(topicData)
|
||||
// Should be a valid proto bytes field (field 1 = initial_state)
|
||||
require.True(t, len(update) >= 0) // empty topic is valid
|
||||
|
||||
topicData2 := buildUSSTopicWithOAuth("token", "refresh", time.Now().Add(1*time.Hour))
|
||||
update2 := buildInitialStateUpdate(topicData2)
|
||||
require.True(t, len(update2) > len(update), "non-empty topic should produce larger update")
|
||||
}
|
||||
|
||||
// TestPoolSetAccountTokenComplete verifies pool accepts full credential set
|
||||
func TestPoolSetAccountTokenComplete(t *testing.T) {
|
||||
csrf := "pool-csrf"
|
||||
srv, err := NewMockExtensionServer(csrf)
|
||||
require.NoError(t, err)
|
||||
defer srv.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
pool := &Pool{
|
||||
config: DefaultConfig(),
|
||||
instances: make(map[string][]*Instance),
|
||||
extServer: srv,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
expiry := time.Now().Add(1 * time.Hour)
|
||||
pool.SetAccountToken("acc-1", "ya29.full-token", "1//full-refresh", expiry)
|
||||
|
||||
srv.mu.RLock()
|
||||
info := srv.tokens["acc-1"]
|
||||
srv.mu.RUnlock()
|
||||
|
||||
require.NotNil(t, info)
|
||||
require.Equal(t, "ya29.full-token", info.AccessToken)
|
||||
require.Equal(t, "1//full-refresh", info.RefreshToken)
|
||||
require.False(t, info.ExpiresAt.IsZero())
|
||||
require.WithinDuration(t, expiry, info.ExpiresAt, time.Second)
|
||||
}
|
||||
|
||||
func TestPoolSetAccountModelCreditsComplete(t *testing.T) {
|
||||
csrf := "pool-csrf"
|
||||
srv, err := NewMockExtensionServer(csrf)
|
||||
require.NoError(t, err)
|
||||
defer srv.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
pool := &Pool{
|
||||
config: DefaultConfig(),
|
||||
instances: make(map[string][]*Instance),
|
||||
extServer: srv,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
available := int32(77)
|
||||
minimum := int32(25)
|
||||
pool.SetAccountModelCredits("acc-1", true, &available, &minimum)
|
||||
|
||||
srv.mu.RLock()
|
||||
info := srv.credits["acc-1"]
|
||||
srv.mu.RUnlock()
|
||||
|
||||
require.NotNil(t, info)
|
||||
require.True(t, info.UseAICredits)
|
||||
require.NotNil(t, info.AvailableCredits)
|
||||
require.Equal(t, available, *info.AvailableCredits)
|
||||
require.NotNil(t, info.MinimumCreditAmountForUsage)
|
||||
require.Equal(t, minimum, *info.MinimumCreditAmountForUsage)
|
||||
}
|
||||
|
||||
// TestUpstreamAdapterExtractsCredentials verifies internal LS headers are extracted and stripped.
|
||||
func TestUpstreamAdapterExtractsCredentials(t *testing.T) {
|
||||
// Create a mock upstream that records what it receives
|
||||
var receivedHeaders http.Header
|
||||
var mu sync.Mutex
|
||||
fallback := &recordingUpstreamWithCallback{}
|
||||
fallback.onDo = func(req *http.Request) {
|
||||
mu.Lock()
|
||||
receivedHeaders = req.Header.Clone()
|
||||
mu.Unlock()
|
||||
}
|
||||
|
||||
csrf := "test-csrf"
|
||||
srv, err := NewMockExtensionServer(csrf)
|
||||
require.NoError(t, err)
|
||||
defer srv.Close()
|
||||
|
||||
pool := &Pool{
|
||||
config: DefaultConfig(),
|
||||
instances: make(map[string][]*Instance),
|
||||
extServer: srv,
|
||||
}
|
||||
|
||||
upstream := NewLSPoolUpstream(pool, fallback)
|
||||
|
||||
// Non-streamGenerateContent request → should pass through to fallback
|
||||
req, _ := http.NewRequest("POST", "https://example.com/v1beta/models/gemini:generateContent", nil)
|
||||
req.Header.Set("Authorization", "Bearer ya29.test")
|
||||
req.Header.Set("X-Antigravity-Refresh-Token", "1//secret-refresh")
|
||||
req.Header.Set("X-Antigravity-Token-Expiry", "2026-03-29T19:00:00Z")
|
||||
req.Header.Set(useAICreditsHeader, "true")
|
||||
req.Header.Set(availableCreditsHeader, "42")
|
||||
req.Header.Set(minimumCreditAmountHeader, "50")
|
||||
|
||||
resp, err := upstream.Do(req, "", 1, 1)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
|
||||
// Internal headers should never leak to the direct upstream.
|
||||
mu.Lock()
|
||||
require.Empty(t, receivedHeaders.Get("X-Antigravity-Refresh-Token"))
|
||||
require.Empty(t, receivedHeaders.Get("X-Antigravity-Token-Expiry"))
|
||||
require.Empty(t, receivedHeaders.Get(useAICreditsHeader))
|
||||
require.Empty(t, receivedHeaders.Get(availableCreditsHeader))
|
||||
require.Empty(t, receivedHeaders.Get(minimumCreditAmountHeader))
|
||||
mu.Unlock()
|
||||
|
||||
srv.mu.RLock()
|
||||
tokenInfo := srv.tokens["1"]
|
||||
creditsInfo := srv.credits["1"]
|
||||
srv.mu.RUnlock()
|
||||
|
||||
require.NotNil(t, tokenInfo)
|
||||
require.Equal(t, "ya29.test", tokenInfo.AccessToken)
|
||||
require.NotNil(t, creditsInfo)
|
||||
require.True(t, creditsInfo.UseAICredits)
|
||||
require.NotNil(t, creditsInfo.AvailableCredits)
|
||||
require.Equal(t, int32(42), *creditsInfo.AvailableCredits)
|
||||
require.NotNil(t, creditsInfo.MinimumCreditAmountForUsage)
|
||||
require.Equal(t, int32(50), *creditsInfo.MinimumCreditAmountForUsage)
|
||||
}
|
||||
|
||||
// TestExtractPromptAndModelMultiTurn verifies multi-turn prompt extraction
|
||||
func TestExtractPromptAndModelMultiTurn(t *testing.T) {
|
||||
body := `{
|
||||
"model": "claude-sonnet-4-6",
|
||||
"request": {
|
||||
"systemInstruction": {"parts": [{"text": "You are helpful"}]},
|
||||
"contents": [
|
||||
{"role": "user", "parts": [{"text": "Hello"}]},
|
||||
{"role": "model", "parts": [{"text": "Hi there!"}]},
|
||||
{"role": "user", "parts": [{"text": "How are you?"}]}
|
||||
]
|
||||
}
|
||||
}`
|
||||
prompt, model := extractPromptAndModel([]byte(body))
|
||||
require.Equal(t, "claude-sonnet-4-6", model)
|
||||
require.Contains(t, prompt, "You are helpful")
|
||||
require.Contains(t, prompt, "Hello")
|
||||
require.Contains(t, prompt, "Hi there!")
|
||||
require.Contains(t, prompt, "How are you?")
|
||||
}
|
||||
|
||||
// TestExtractUsageFromTrajectory verifies token usage extraction
|
||||
func TestExtractUsageFromTrajectory(t *testing.T) {
|
||||
resp := `{
|
||||
"trajectory": {
|
||||
"steps": [{
|
||||
"type": "CORTEX_STEP_TYPE_PLANNER_RESPONSE",
|
||||
"status": "CORTEX_STEP_STATUS_DONE",
|
||||
"plannerResponse": {"response": "OK"},
|
||||
"metadata": {
|
||||
"modelUsage": {
|
||||
"inputTokens": "150",
|
||||
"outputTokens": "5"
|
||||
}
|
||||
}
|
||||
}]
|
||||
}
|
||||
}`
|
||||
usage := extractUsageFromTrajectory([]byte(resp))
|
||||
require.NotNil(t, usage)
|
||||
require.Equal(t, 150, usage["promptTokenCount"])
|
||||
require.Equal(t, 5, usage["candidatesTokenCount"])
|
||||
require.Equal(t, 155, usage["totalTokenCount"])
|
||||
}
|
||||
|
||||
// TestSSEChunkFormat verifies the Gemini SSE output format
|
||||
func TestSSEChunkFormat(t *testing.T) {
|
||||
chunk := buildGeminiSSEChunk("Hello world")
|
||||
require.True(t, len(chunk) > 0)
|
||||
require.Contains(t, chunk, "data: ")
|
||||
require.Contains(t, chunk, `"text":"Hello world"`)
|
||||
require.Contains(t, chunk, `"role":"model"`)
|
||||
require.True(t, chunk[len(chunk)-2:] == "\n\n")
|
||||
|
||||
// Verify it's valid JSON after stripping "data: " prefix
|
||||
jsonStr := chunk[len("data: ") : len(chunk)-2]
|
||||
var parsed map[string]any
|
||||
err := json.Unmarshal([]byte(jsonStr), &parsed)
|
||||
require.NoError(t, err)
|
||||
response := parsed["response"].(map[string]any)
|
||||
candidates := response["candidates"].([]any)
|
||||
require.Len(t, candidates, 1)
|
||||
}
|
||||
|
||||
// TestSSEFinalChunkFormat verifies the final SSE chunk with usage
|
||||
func TestSSEFinalChunkFormat(t *testing.T) {
|
||||
usage := map[string]any{
|
||||
"promptTokenCount": 100,
|
||||
"candidatesTokenCount": 50,
|
||||
"totalTokenCount": 150,
|
||||
}
|
||||
chunk := buildGeminiSSEFinalChunk(usage)
|
||||
require.Contains(t, chunk, "data: ")
|
||||
require.Contains(t, chunk, `"finishReason":"STOP"`)
|
||||
require.Contains(t, chunk, `"usageMetadata"`)
|
||||
}
|
||||
|
||||
func TestStreamCascadeResponsePollsImmediately(t *testing.T) {
|
||||
var getCalls atomic.Int32
|
||||
|
||||
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "test-csrf", r.Header.Get("x-codeium-csrf-token"))
|
||||
|
||||
if strings.HasSuffix(r.URL.Path, "/GetCascadeTrajectory") {
|
||||
getCalls.Add(1)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"status":"CASCADE_RUN_STATUS_IDLE","trajectory":{"steps":[{"type":"CORTEX_STEP_TYPE_PLANNER_RESPONSE","status":"CORTEX_STEP_STATUS_DONE","plannerResponse":{"response":"hello from ls"}}]}}`))
|
||||
return
|
||||
}
|
||||
|
||||
http.NotFound(w, r)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
inst := &Instance{
|
||||
AccountID: "42",
|
||||
CSRF: "test-csrf",
|
||||
Address: strings.TrimPrefix(server.URL, "https://"),
|
||||
client: server.Client(),
|
||||
healthy: true,
|
||||
lastUsed: time.Now(),
|
||||
}
|
||||
upstream := NewLSPoolUpstream(&Pool{}, &recordingUpstream{})
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
upstream.streamCascadeResponse(ctx, inst, "cid-1", pw, nil, nil)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
body, err := io.ReadAll(pr)
|
||||
require.NoError(t, err)
|
||||
<-done
|
||||
|
||||
require.GreaterOrEqual(t, getCalls.Load(), int32(1))
|
||||
require.Contains(t, string(body), "hello from ls")
|
||||
}
|
||||
|
||||
// TestRequestHasToolsEdgeCases verifies tool detection edge cases
|
||||
func TestRequestHasToolsEdgeCases(t *testing.T) {
|
||||
// null tools
|
||||
require.False(t, requestHasTools([]byte(`{"contents":[],"tools":null}`)))
|
||||
// tools with empty function declarations
|
||||
require.True(t, requestHasTools([]byte(`{"contents":[],"tools":[{"functionDeclarations":[]}]}`)))
|
||||
// deeply nested wrapped format
|
||||
require.True(t, requestHasTools([]byte(`{"model":"m","project":"p","request":{"contents":[],"tools":[{"codeExecution":{}}]}}`)))
|
||||
}
|
||||
|
||||
func TestJSParityRouteReusesCascadeSession(t *testing.T) {
|
||||
t.Setenv("ANTIGRAVITY_LS_STRATEGY", LSStrategyJSParity)
|
||||
|
||||
var startCalls atomic.Int32
|
||||
var sendCalls atomic.Int32
|
||||
var getCalls atomic.Int32
|
||||
var sendBodiesMu sync.Mutex
|
||||
var sendBodies []map[string]any
|
||||
|
||||
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "test-csrf", r.Header.Get("x-codeium-csrf-token"))
|
||||
|
||||
switch {
|
||||
case strings.HasSuffix(r.URL.Path, "/StartCascade"):
|
||||
startCalls.Add(1)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"cascadeId":"cid-1"}`))
|
||||
case strings.HasSuffix(r.URL.Path, "/SendUserCascadeMessage"):
|
||||
sendCalls.Add(1)
|
||||
var payload map[string]any
|
||||
err := json.NewDecoder(r.Body).Decode(&payload)
|
||||
require.NoError(t, err)
|
||||
sendBodiesMu.Lock()
|
||||
sendBodies = append(sendBodies, payload)
|
||||
sendBodiesMu.Unlock()
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"queued":false}`))
|
||||
case strings.HasSuffix(r.URL.Path, "/GetCascadeTrajectory"):
|
||||
call := getCalls.Add(1)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
text := "hello from ls"
|
||||
if call > 1 {
|
||||
text = "follow up from ls"
|
||||
}
|
||||
_, _ = w.Write([]byte(fmt.Sprintf(`{"status":"CASCADE_RUN_STATUS_IDLE","trajectory":{"steps":[{"type":"CORTEX_STEP_TYPE_PLANNER_RESPONSE","status":"CORTEX_STEP_STATUS_DONE","plannerResponse":{"response":"%s"}}]}}`, text)))
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
inst := &Instance{
|
||||
AccountID: "42",
|
||||
CSRF: "test-csrf",
|
||||
Address: strings.TrimPrefix(server.URL, "https://"),
|
||||
client: server.Client(),
|
||||
healthy: true,
|
||||
lastUsed: time.Now(),
|
||||
}
|
||||
inst.SetModelMappingReady(true)
|
||||
pool := &Pool{
|
||||
config: Config{ReplicasPerAccount: 1},
|
||||
instances: map[string][]*Instance{"42": []*Instance{inst}},
|
||||
}
|
||||
upstream := NewLSPoolUpstream(pool, &recordingUpstream{})
|
||||
|
||||
req1Body := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"session-a","contents":[{"role":"user","parts":[{"text":"hello"}]}]}}`)
|
||||
req1, err := http.NewRequest(http.MethodPost, "https://example.com/v1internal:streamGenerateContent?alt=sse", bytes.NewReader(req1Body))
|
||||
require.NoError(t, err)
|
||||
req1.Header.Set("Authorization", "Bearer downstream-a")
|
||||
|
||||
resp1, err := upstream.Do(req1, "", 42, 1)
|
||||
require.NoError(t, err)
|
||||
body1, err := io.ReadAll(resp1.Body)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, string(body1), `"text":"hello from ls"`)
|
||||
|
||||
req2Body := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"session-a","contents":[{"role":"user","parts":[{"text":"hello"}]},{"role":"model","parts":[{"text":"hello from ls"}]},{"role":"user","parts":[{"text":"follow up"}]}]}}`)
|
||||
req2, err := http.NewRequest(http.MethodPost, "https://example.com/v1internal:streamGenerateContent?alt=sse", bytes.NewReader(req2Body))
|
||||
require.NoError(t, err)
|
||||
req2.Header.Set("Authorization", "Bearer downstream-a")
|
||||
|
||||
resp2, err := upstream.Do(req2, "", 42, 1)
|
||||
require.NoError(t, err)
|
||||
body2, err := io.ReadAll(resp2.Body)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, string(body2), `"text":"follow up from ls"`)
|
||||
|
||||
require.Equal(t, int32(1), startCalls.Load(), "cascade should be reused for append-only transcript")
|
||||
require.Equal(t, int32(2), sendCalls.Load())
|
||||
|
||||
sendBodiesMu.Lock()
|
||||
require.Len(t, sendBodies, 2)
|
||||
firstSend := sendBodies[0]
|
||||
sendBodiesMu.Unlock()
|
||||
|
||||
require.Equal(t, "cid-1", firstSend["cascadeId"])
|
||||
require.Equal(t, false, firstSend["blocking"])
|
||||
metadata, ok := firstSend["metadata"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "antigravity", metadata["ideName"])
|
||||
require.Equal(t, "1.107.0", metadata["ideVersion"])
|
||||
require.NotContains(t, firstSend, "clientType")
|
||||
require.NotContains(t, firstSend, "messageOrigin")
|
||||
cascadeConfig, ok := firstSend["cascadeConfig"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
plannerConfig, ok := cascadeConfig["plannerConfig"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
requestedModel, ok := plannerConfig["requestedModel"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.NotEmpty(t, requestedModel["model"])
|
||||
require.Len(t, plannerConfig, 1)
|
||||
require.Len(t, cascadeConfig, 1)
|
||||
}
|
||||
|
||||
func TestJSParityRouteFallsBackOnSystemInstructionDrift(t *testing.T) {
|
||||
t.Setenv("ANTIGRAVITY_LS_STRATEGY", LSStrategyJSParity)
|
||||
|
||||
var startCalls atomic.Int32
|
||||
var sendCalls atomic.Int32
|
||||
|
||||
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "test-csrf", r.Header.Get("x-codeium-csrf-token"))
|
||||
|
||||
switch {
|
||||
case strings.HasSuffix(r.URL.Path, "/StartCascade"):
|
||||
startCalls.Add(1)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"cascadeId":"cid-1"}`))
|
||||
case strings.HasSuffix(r.URL.Path, "/SendUserCascadeMessage"):
|
||||
sendCalls.Add(1)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"queued":false}`))
|
||||
case strings.HasSuffix(r.URL.Path, "/GetCascadeTrajectory"):
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"status":"CASCADE_RUN_STATUS_IDLE","trajectory":{"steps":[{"type":"CORTEX_STEP_TYPE_PLANNER_RESPONSE","status":"CORTEX_STEP_STATUS_DONE","plannerResponse":{"response":"hello from ls"}}]}}`))
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
inst := &Instance{
|
||||
AccountID: "42",
|
||||
CSRF: "test-csrf",
|
||||
Address: strings.TrimPrefix(server.URL, "https://"),
|
||||
client: server.Client(),
|
||||
healthy: true,
|
||||
lastUsed: time.Now(),
|
||||
}
|
||||
inst.SetModelMappingReady(true)
|
||||
fallback := &recordingUpstream{}
|
||||
pool := &Pool{
|
||||
config: Config{ReplicasPerAccount: 1},
|
||||
instances: map[string][]*Instance{"42": []*Instance{inst}},
|
||||
}
|
||||
upstream := NewLSPoolUpstream(pool, fallback)
|
||||
|
||||
req1Body := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"session-a","systemInstruction":{"parts":[{"text":"You are helpful"}]},"contents":[{"role":"user","parts":[{"text":"hello"}]}]}}`)
|
||||
req1, err := http.NewRequest(http.MethodPost, "https://example.com/v1internal:streamGenerateContent?alt=sse", bytes.NewReader(req1Body))
|
||||
require.NoError(t, err)
|
||||
req1.Header.Set("Authorization", "Bearer downstream-a")
|
||||
|
||||
resp1, err := upstream.Do(req1, "", 42, 1)
|
||||
require.NoError(t, err)
|
||||
body1, err := io.ReadAll(resp1.Body)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, string(body1), `"text":"hello from ls"`)
|
||||
|
||||
req2Body := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"session-a","systemInstruction":{"parts":[{"text":"You are different"}]},"contents":[{"role":"user","parts":[{"text":"hello"}]},{"role":"model","parts":[{"text":"hello from ls"}]},{"role":"user","parts":[{"text":"follow up"}]}]}}`)
|
||||
req2, err := http.NewRequest(http.MethodPost, "https://example.com/v1internal:streamGenerateContent?alt=sse", bytes.NewReader(req2Body))
|
||||
require.NoError(t, err)
|
||||
req2.Header.Set("Authorization", "Bearer downstream-a")
|
||||
|
||||
resp2, err := upstream.Do(req2, "", 42, 1)
|
||||
require.NoError(t, err)
|
||||
body2, err := io.ReadAll(resp2.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "ok", string(body2))
|
||||
require.Equal(t, 1, fallback.doCalls)
|
||||
require.Equal(t, int32(1), startCalls.Load())
|
||||
require.Equal(t, int32(1), sendCalls.Load())
|
||||
}
|
||||
|
||||
func TestJSParityRouteErrorsWhenModelMappingPending(t *testing.T) {
|
||||
t.Setenv("ANTIGRAVITY_LS_STRATEGY", LSStrategyJSParity)
|
||||
|
||||
var startCalls atomic.Int32
|
||||
var sendCalls atomic.Int32
|
||||
|
||||
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case strings.HasSuffix(r.URL.Path, "/StartCascade"):
|
||||
startCalls.Add(1)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"cascadeId":"cid-1"}`))
|
||||
case strings.HasSuffix(r.URL.Path, "/SendUserCascadeMessage"):
|
||||
sendCalls.Add(1)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"queued":false}`))
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
inst := &Instance{
|
||||
AccountID: "42",
|
||||
CSRF: "test-csrf",
|
||||
Address: strings.TrimPrefix(server.URL, "https://"),
|
||||
client: server.Client(),
|
||||
healthy: true,
|
||||
lastUsed: time.Now(),
|
||||
}
|
||||
|
||||
fallback := &recordingUpstream{}
|
||||
pool := &Pool{
|
||||
config: Config{ReplicasPerAccount: 1},
|
||||
instances: map[string][]*Instance{"42": []*Instance{inst}},
|
||||
}
|
||||
upstream := NewLSPoolUpstream(pool, fallback)
|
||||
|
||||
reqBody := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"session-a","contents":[{"role":"user","parts":[{"text":"hello"}]}]}}`)
|
||||
req, err := http.NewRequest(http.MethodPost, "https://example.com/v1internal:streamGenerateContent?alt=sse", bytes.NewReader(reqBody))
|
||||
require.NoError(t, err)
|
||||
req.Header.Set("Authorization", "Bearer downstream-a")
|
||||
|
||||
resp, err := upstream.Do(req, "", 42, 1)
|
||||
require.Nil(t, resp)
|
||||
require.ErrorIs(t, err, errLSModelMapPending)
|
||||
require.Equal(t, int32(0), startCalls.Load())
|
||||
require.Equal(t, int32(0), sendCalls.Load())
|
||||
require.Equal(t, 0, fallback.doCalls)
|
||||
}
|
||||
|
||||
// recordingUpstreamWithCallback extends the base recordingUpstream with a callback
|
||||
type recordingUpstreamWithCallback struct {
|
||||
recordingUpstream
|
||||
onDo func(req *http.Request)
|
||||
}
|
||||
|
||||
func (r *recordingUpstreamWithCallback) Do(req *http.Request, proxyURL string, accountID int64, c int) (*http.Response, error) {
|
||||
if r.onDo != nil {
|
||||
r.onDo(req)
|
||||
}
|
||||
return r.recordingUpstream.Do(req, proxyURL, accountID, c)
|
||||
}
|
||||
@ -1,920 +0,0 @@
|
||||
// Package lspool provides a mock Extension Server that the LS binary connects
|
||||
// to at startup. The real IDE's extension.js runs a ConnectRPC HTTP/1.1 server
|
||||
// using connectNodeAdapter. We replicate that protocol here.
|
||||
//
|
||||
// Protocol details (from extension.js source):
|
||||
// - Transport: HTTP/1.1 on 127.0.0.1 (no TLS)
|
||||
// - Auth: x-codeium-csrf-token header on every request
|
||||
// - Unary request Content-Type: application/proto (binary protobuf, no envelope)
|
||||
// OR application/connect+proto (with 5-byte envelope)
|
||||
// - Unary response Content-Type: application/proto (raw binary protobuf, no envelope)
|
||||
// - Stream request Content-Type: application/connect+proto (with 5-byte envelope)
|
||||
// - Stream response Content-Type: application/connect+proto (envelope-framed messages)
|
||||
//
|
||||
// The LS sends requests with content-type "application/connect+proto" for BOTH
|
||||
// unary and streaming RPCs. ConnectRPC's content-type regex:
|
||||
//
|
||||
// /^application\/(connect\+)?(?:(json)(?:; ?charset=utf-?8)?|(proto))$/i
|
||||
//
|
||||
// If "connect+" prefix is present → stream mode; otherwise → unary mode.
|
||||
// However the LS Go client uses the Connect protocol client which always sends
|
||||
// "application/proto" for unary and "application/connect+proto" for streaming.
|
||||
//
|
||||
// We detect the RPC kind from the URL path and respond accordingly.
|
||||
package lspool
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
// ============================================================
|
||||
// Proto helpers — hand-encode minimal proto messages so we don't
|
||||
// need to import the full generated proto package.
|
||||
// ============================================================
|
||||
|
||||
// encodeProtoString writes a proto string field (wire type 2) to a byte slice.
|
||||
func encodeProtoString(fieldNum int, val string) []byte {
|
||||
tag := encodeVarint(uint64(fieldNum<<3 | 2))
|
||||
length := encodeVarint(uint64(len(val)))
|
||||
out := make([]byte, 0, len(tag)+len(length)+len(val))
|
||||
out = append(out, tag...)
|
||||
out = append(out, length...)
|
||||
out = append(out, []byte(val)...)
|
||||
return out
|
||||
}
|
||||
|
||||
// encodeProtoBytes writes a proto bytes/message field (wire type 2).
|
||||
func encodeProtoBytes(fieldNum int, val []byte) []byte {
|
||||
tag := encodeVarint(uint64(fieldNum<<3 | 2))
|
||||
length := encodeVarint(uint64(len(val)))
|
||||
out := make([]byte, 0, len(tag)+len(length)+len(val))
|
||||
out = append(out, tag...)
|
||||
out = append(out, length...)
|
||||
out = append(out, val...)
|
||||
return out
|
||||
}
|
||||
|
||||
// encodeProtoVarint writes a proto varint field (wire type 0).
|
||||
func encodeProtoVarint(fieldNum int, val uint64) []byte {
|
||||
tag := encodeVarint(uint64(fieldNum<<3 | 0))
|
||||
v := encodeVarint(val)
|
||||
out := make([]byte, 0, len(tag)+len(v))
|
||||
out = append(out, tag...)
|
||||
out = append(out, v...)
|
||||
return out
|
||||
}
|
||||
|
||||
// encodeProtoBool writes a proto bool field.
|
||||
func encodeProtoBool(fieldNum int, val bool) []byte {
|
||||
v := uint64(0)
|
||||
if val {
|
||||
v = 1
|
||||
}
|
||||
return encodeProtoVarint(fieldNum, v)
|
||||
}
|
||||
|
||||
func encodeVarint(v uint64) []byte {
|
||||
buf := make([]byte, binary.MaxVarintLen64)
|
||||
n := binary.PutUvarint(buf, v)
|
||||
return buf[:n]
|
||||
}
|
||||
|
||||
// decodeProtoString extracts a string field from raw proto bytes.
|
||||
func decodeProtoString(data []byte, targetField int) string {
|
||||
i := 0
|
||||
for i < len(data) {
|
||||
if i >= len(data) {
|
||||
break
|
||||
}
|
||||
tag, n := binary.Uvarint(data[i:])
|
||||
if n <= 0 {
|
||||
break
|
||||
}
|
||||
i += n
|
||||
fieldNum := int(tag >> 3)
|
||||
wireType := tag & 0x7
|
||||
|
||||
switch wireType {
|
||||
case 0: // varint
|
||||
_, n = binary.Uvarint(data[i:])
|
||||
if n <= 0 {
|
||||
return ""
|
||||
}
|
||||
i += n
|
||||
case 2: // length-delimited
|
||||
length, n := binary.Uvarint(data[i:])
|
||||
if n <= 0 {
|
||||
return ""
|
||||
}
|
||||
i += n
|
||||
if fieldNum == targetField {
|
||||
end := i + int(length)
|
||||
if end > len(data) {
|
||||
return ""
|
||||
}
|
||||
return string(data[i:end])
|
||||
}
|
||||
i += int(length)
|
||||
case 1: // 64-bit
|
||||
i += 8
|
||||
case 5: // 32-bit
|
||||
i += 4
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// ConnectRPC envelope helpers
|
||||
// ============================================================
|
||||
|
||||
// connectEnvelope wraps a proto payload in a ConnectRPC streaming envelope:
|
||||
// 1 byte flags + 4 byte big-endian length + payload
|
||||
func connectEnvelope(flags byte, payload []byte) []byte {
|
||||
frame := make([]byte, 5+len(payload))
|
||||
frame[0] = flags
|
||||
binary.BigEndian.PutUint32(frame[1:5], uint32(len(payload)))
|
||||
copy(frame[5:], payload)
|
||||
return frame
|
||||
}
|
||||
|
||||
// connectEndOfStream returns the end-of-stream trailer frame for ConnectRPC.
|
||||
// flags=0x02 signals end of stream. The payload is a JSON object with empty metadata.
|
||||
func connectEndOfStream() []byte {
|
||||
trailer := []byte("{}")
|
||||
return connectEnvelope(0x02, trailer)
|
||||
}
|
||||
|
||||
// unwrapConnectEnvelope strips the 5-byte envelope header from a ConnectRPC message.
|
||||
// Returns the raw proto payload. If the input is shorter than 5 bytes, returns as-is.
|
||||
func unwrapConnectEnvelope(body []byte) []byte {
|
||||
if len(body) < 5 {
|
||||
return body
|
||||
}
|
||||
// Check if it looks like an envelope: first byte should be 0x00 or 0x01
|
||||
if body[0] > 0x02 {
|
||||
return body // Not envelope-framed, return raw
|
||||
}
|
||||
plen := binary.BigEndian.Uint32(body[1:5])
|
||||
if int(plen)+5 > len(body) {
|
||||
return body // Length mismatch, return raw
|
||||
}
|
||||
return body[5 : 5+plen]
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// OAuthTokenInfo proto builder
|
||||
// ============================================================
|
||||
|
||||
// buildOAuthTokenInfoBinary creates binary-encoded OAuthTokenInfo proto.
|
||||
//
|
||||
// message OAuthTokenInfo {
|
||||
// string access_token = 1;
|
||||
// string token_type = 2;
|
||||
// string refresh_token = 3;
|
||||
// google.protobuf.Timestamp expiry = 4;
|
||||
// bool is_gcp_tos = 6;
|
||||
// }
|
||||
func buildOAuthTokenInfoBinary(accessToken, refreshToken string, expiresAt time.Time) []byte {
|
||||
var buf []byte
|
||||
buf = append(buf, encodeProtoString(1, accessToken)...)
|
||||
buf = append(buf, encodeProtoString(2, "Bearer")...)
|
||||
if refreshToken != "" {
|
||||
buf = append(buf, encodeProtoString(3, refreshToken)...)
|
||||
}
|
||||
// Use real expiry if provided, otherwise default to 1 hour from now
|
||||
expiry := expiresAt
|
||||
if expiry.IsZero() {
|
||||
expiry = time.Now().Add(1 * time.Hour)
|
||||
}
|
||||
ts := ×tamppb.Timestamp{
|
||||
Seconds: expiry.Unix(),
|
||||
}
|
||||
tsBytes, _ := proto.Marshal(ts)
|
||||
buf = append(buf, encodeProtoBytes(4, tsBytes)...)
|
||||
buf = append(buf, encodeProtoBool(6, true)...)
|
||||
return buf
|
||||
}
|
||||
|
||||
// buildUSSTopicWithOAuth creates a USS Topic proto with the OAuth token.
|
||||
//
|
||||
// message Topic { map<string, Row> data = 1; }
|
||||
// message Row { string value = 1; int64 e_tag = 2; }
|
||||
//
|
||||
// The key in the map is "oauthTokenInfoSentinelKey" and the Row.value is
|
||||
// base64(toBinary(OAuthTokenInfo)).
|
||||
func buildUSSTopicWithOAuth(accessToken, refreshToken string, expiresAt time.Time) []byte {
|
||||
tokenBin := buildOAuthTokenInfoBinary(accessToken, refreshToken, expiresAt)
|
||||
tokenB64 := base64.StdEncoding.EncodeToString(tokenBin)
|
||||
|
||||
// Row: value=tokenB64 (field 1), e_tag=1 (field 2)
|
||||
var row []byte
|
||||
row = append(row, encodeProtoString(1, tokenB64)...)
|
||||
row = append(row, encodeProtoVarint(2, 1)...)
|
||||
|
||||
// Map entry: key="oauthTokenInfoSentinelKey" (field 1), value=row (field 2)
|
||||
var entry []byte
|
||||
entry = append(entry, encodeProtoString(1, "oauthTokenInfoSentinelKey")...)
|
||||
entry = append(entry, encodeProtoBytes(2, row)...)
|
||||
|
||||
// Topic: data map entries use field 1
|
||||
var topic []byte
|
||||
topic = append(topic, encodeProtoBytes(1, entry)...)
|
||||
|
||||
return topic
|
||||
}
|
||||
|
||||
func buildPrimitiveBoolBinary(val bool) []byte {
|
||||
// Primitive.bool_value is field 13 in the proto definition
|
||||
return encodeProtoBool(13, val)
|
||||
}
|
||||
|
||||
func buildPrimitiveInt32Binary(val int32) []byte {
|
||||
// Primitive.int32_value is field 3 in the proto definition
|
||||
return encodeProtoVarint(3, uint64(uint32(val)))
|
||||
}
|
||||
|
||||
func encodeUSSBinaryValue(value []byte) string {
|
||||
return base64.StdEncoding.EncodeToString(value)
|
||||
}
|
||||
|
||||
func encodeUSSPrimitiveBoolValue(val bool) string {
|
||||
return encodeUSSBinaryValue(buildPrimitiveBoolBinary(val))
|
||||
}
|
||||
|
||||
func encodeUSSPrimitiveInt32Value(val int32) string {
|
||||
return encodeUSSBinaryValue(buildPrimitiveInt32Binary(val))
|
||||
}
|
||||
|
||||
func buildUSSTopicRow(key string, value string) []byte {
|
||||
row := buildUSSRowBinary(value)
|
||||
|
||||
var entry []byte
|
||||
entry = append(entry, encodeProtoString(1, key)...)
|
||||
entry = append(entry, encodeProtoBytes(2, row)...)
|
||||
return entry
|
||||
}
|
||||
|
||||
func buildUSSRowBinary(value string) []byte {
|
||||
var row []byte
|
||||
row = append(row, encodeProtoString(1, value)...)
|
||||
row = append(row, encodeProtoVarint(2, 1)...)
|
||||
return row
|
||||
}
|
||||
|
||||
func buildUSSTopicWithModelCredits(info *ModelCreditsInfo) []byte {
|
||||
if info == nil {
|
||||
info = &ModelCreditsInfo{}
|
||||
}
|
||||
|
||||
minimum := defaultMinimumCreditAmountForUsage
|
||||
if info.MinimumCreditAmountForUsage != nil {
|
||||
minimum = *info.MinimumCreditAmountForUsage
|
||||
}
|
||||
|
||||
entries := make([][]byte, 0, 3)
|
||||
entries = append(entries, buildUSSTopicRow(
|
||||
useAICreditsSentinelKey,
|
||||
encodeUSSPrimitiveBoolValue(info.UseAICredits),
|
||||
))
|
||||
// JS protocol: useAICreditsSentinelKey carries the toggle state.
|
||||
// availableCreditsSentinelKey is only present when credits are enabled.
|
||||
if info.UseAICredits {
|
||||
credits := int32(9999)
|
||||
if info.AvailableCredits != nil {
|
||||
credits = *info.AvailableCredits
|
||||
}
|
||||
entries = append(entries, buildUSSTopicRow(availableCreditsSentinelKey, encodeUSSPrimitiveInt32Value(credits)))
|
||||
}
|
||||
entries = append(entries, buildUSSTopicRow(minimumCreditAmountForUsageKey, encodeUSSPrimitiveInt32Value(minimum)))
|
||||
|
||||
var topic []byte
|
||||
for _, entry := range entries {
|
||||
topic = append(topic, encodeProtoBytes(1, entry)...)
|
||||
}
|
||||
return topic
|
||||
}
|
||||
|
||||
// buildEmptyTopic returns an empty USS Topic proto (for non-oauth topics).
|
||||
func buildEmptyTopic() []byte {
|
||||
return []byte{} // Empty message = no map entries
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// UnifiedStateSyncUpdate builder
|
||||
// ============================================================
|
||||
|
||||
// buildInitialStateUpdate creates a UnifiedStateSyncUpdate with initial_state set.
|
||||
//
|
||||
// message UnifiedStateSyncUpdate {
|
||||
// oneof update_type {
|
||||
// Topic initial_state = 1;
|
||||
// AppliedUpdate applied_update = 2;
|
||||
// }
|
||||
// }
|
||||
func buildInitialStateUpdate(topicData []byte) []byte {
|
||||
return encodeProtoBytes(1, topicData)
|
||||
}
|
||||
|
||||
func buildAppliedUpdate(key string, row []byte) []byte {
|
||||
var applied []byte
|
||||
applied = append(applied, encodeProtoString(1, key)...)
|
||||
if len(row) > 0 {
|
||||
applied = append(applied, encodeProtoBytes(2, row)...)
|
||||
}
|
||||
return encodeProtoBytes(2, applied)
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// MockExtensionServer
|
||||
// ============================================================
|
||||
|
||||
// MockExtensionServer provides a ConnectRPC-compatible HTTP server that the
|
||||
// Language Server binary connects to. It implements just enough of the
|
||||
// ExtensionServerService to keep the LS operational.
|
||||
type MockExtensionServer struct {
|
||||
listener net.Listener
|
||||
server *http.Server
|
||||
port int
|
||||
csrf string
|
||||
mu sync.RWMutex
|
||||
tokens map[string]*TokenInfo // account_id -> token info
|
||||
credits map[string]*ModelCreditsInfo // account_id -> model credits info
|
||||
subscribers map[string]map[int]*stateSubscriber
|
||||
nextSubID int
|
||||
lastAccountID string
|
||||
logger *slog.Logger
|
||||
|
||||
// Trajectory callback — when LS pushes trajectory updates, we forward them
|
||||
onTrajectoryUpdate func(topic, key string, data []byte)
|
||||
}
|
||||
|
||||
// TokenInfo holds OAuth token details for an account.
|
||||
type TokenInfo struct {
|
||||
AccessToken string
|
||||
RefreshToken string
|
||||
ExpiresAt time.Time // zero value means unknown; defaults to now+1h
|
||||
}
|
||||
|
||||
// ModelCreditsInfo mirrors the JS uss-modelCredits topic state.
|
||||
type ModelCreditsInfo struct {
|
||||
UseAICredits bool
|
||||
AvailableCredits *int32
|
||||
MinimumCreditAmountForUsage *int32
|
||||
}
|
||||
|
||||
type stateSubscriber struct {
|
||||
id int
|
||||
accountID string
|
||||
topic string
|
||||
updates chan []byte
|
||||
}
|
||||
|
||||
const (
|
||||
useAICreditsSentinelKey = "useAICreditsSentinelKey"
|
||||
availableCreditsSentinelKey = "availableCreditsSentinelKey"
|
||||
minimumCreditAmountForUsageKey = "minimumCreditAmountForUsageKey"
|
||||
defaultMinimumCreditAmountForUsage = int32(50)
|
||||
)
|
||||
|
||||
// NewMockExtensionServer creates a mock extension server with proper ConnectRPC handling.
|
||||
func NewMockExtensionServer(csrf string) (*MockExtensionServer, error) {
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("listen: %w", err)
|
||||
}
|
||||
|
||||
m := &MockExtensionServer{
|
||||
listener: listener,
|
||||
port: listener.Addr().(*net.TCPAddr).Port,
|
||||
csrf: csrf,
|
||||
tokens: make(map[string]*TokenInfo),
|
||||
credits: make(map[string]*ModelCreditsInfo),
|
||||
subscribers: make(map[string]map[int]*stateSubscriber),
|
||||
logger: slog.Default().With("component", "mock-ext-server"),
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
extService := "/exa.extension_server_pb.ExtensionServerService/"
|
||||
|
||||
// Register all RPCs the LS calls on the Extension Server.
|
||||
// Unary RPCs — return application/proto
|
||||
mux.HandleFunc(extService+"LanguageServerStarted", m.handleUnary(m.onLanguageServerStarted))
|
||||
mux.HandleFunc(extService+"Heartbeat", m.handleUnary(m.onHeartbeat))
|
||||
mux.HandleFunc(extService+"GetSecretValue", m.handleUnary(m.onGetSecretValue))
|
||||
mux.HandleFunc(extService+"StoreSecretValue", m.handleUnary(m.onStoreSecretValue))
|
||||
mux.HandleFunc(extService+"IsAgentManagerEnabled", m.handleUnary(m.onIsAgentManagerEnabled))
|
||||
mux.HandleFunc(extService+"PushUnifiedStateSyncUpdate", m.handleUnary(m.onPushUnifiedStateSyncUpdate))
|
||||
mux.HandleFunc(extService+"RecordError", m.handleUnary(m.onRecordError))
|
||||
mux.HandleFunc(extService+"LogEvent", m.handleUnary(m.onLogEvent))
|
||||
mux.HandleFunc(extService+"UpdateCascadeTrajectorySummaries", m.handleUnary(m.onUpdateTrajectorySummaries))
|
||||
mux.HandleFunc(extService+"BroadcastConversationDeletion", m.handleUnary(m.onDefault))
|
||||
mux.HandleFunc(extService+"WriteCascadeEdit", m.handleUnary(m.onDefault))
|
||||
mux.HandleFunc(extService+"OpenDiffZones", m.handleUnary(m.onDefault))
|
||||
mux.HandleFunc(extService+"HandleAsyncPostMessage", m.handleUnary(m.onDefault))
|
||||
mux.HandleFunc(extService+"OpenFilePointer", m.handleUnary(m.onDefault))
|
||||
mux.HandleFunc(extService+"OpenVirtualFile", m.handleUnary(m.onDefault))
|
||||
mux.HandleFunc(extService+"SaveDocument", m.handleUnary(m.onDefault))
|
||||
mux.HandleFunc(extService+"RestartUserStatusUpdater", m.handleUnary(m.onDefault))
|
||||
mux.HandleFunc(extService+"FocusIDEWindow", m.handleUnary(m.onDefault))
|
||||
mux.HandleFunc(extService+"SmartFocusConversation", m.handleUnary(m.onDefault))
|
||||
mux.HandleFunc(extService+"RunExtensionCode", m.handleUnary(m.onDefault))
|
||||
mux.HandleFunc(extService+"UpdateDetailedViewWithCascadeInput", m.handleUnary(m.onDefault))
|
||||
mux.HandleFunc(extService+"FindAllReferences", m.handleUnary(m.onDefault))
|
||||
mux.HandleFunc(extService+"GetDefinition", m.handleUnary(m.onDefault))
|
||||
mux.HandleFunc(extService+"GetLintErrors", m.handleUnary(m.onDefault))
|
||||
|
||||
// Server-streaming RPCs — return application/connect+proto
|
||||
mux.HandleFunc(extService+"SubscribeToUnifiedStateSyncTopic", m.handleStream(m.onSubscribeStateSyncTopic))
|
||||
mux.HandleFunc(extService+"ExecuteCommand", m.handleStream(m.onExecuteCommand))
|
||||
|
||||
// Catch-all for any unregistered RPCs
|
||||
mux.HandleFunc("/", m.handleCatchAll)
|
||||
|
||||
m.server = &http.Server{Handler: mux}
|
||||
|
||||
go func() {
|
||||
if err := m.server.Serve(listener); err != http.ErrServerClosed {
|
||||
m.logger.Error("extension server error", "err", err)
|
||||
}
|
||||
}()
|
||||
|
||||
m.logger.Info("mock extension server started", "port", m.port, "csrf_len", len(csrf))
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Port returns the listening port.
|
||||
func (m *MockExtensionServer) Port() int {
|
||||
return m.port
|
||||
}
|
||||
|
||||
// SetToken sets the OAuth token for an account.
|
||||
func (m *MockExtensionServer) SetToken(accountID string, info *TokenInfo) {
|
||||
m.mu.Lock()
|
||||
m.tokens[accountID] = info
|
||||
m.lastAccountID = accountID
|
||||
subscribers := m.snapshotSubscribersLocked("uss-oauth", accountID)
|
||||
m.mu.Unlock()
|
||||
|
||||
if info == nil {
|
||||
return
|
||||
}
|
||||
tokenBin := buildOAuthTokenInfoBinary(info.AccessToken, info.RefreshToken, info.ExpiresAt)
|
||||
tokenB64 := base64.StdEncoding.EncodeToString(tokenBin)
|
||||
m.publishTopicUpdate(subscribers, buildAppliedUpdate("oauthTokenInfoSentinelKey", buildUSSRowBinary(tokenB64)))
|
||||
}
|
||||
|
||||
// SetModelCredits sets the uss-modelCredits state for an account.
|
||||
func (m *MockExtensionServer) SetModelCredits(accountID string, info *ModelCreditsInfo) {
|
||||
if info == nil {
|
||||
info = &ModelCreditsInfo{}
|
||||
}
|
||||
copyInfo := *info
|
||||
m.mu.Lock()
|
||||
m.credits[accountID] = ©Info
|
||||
m.lastAccountID = accountID
|
||||
subscribers := m.snapshotSubscribersLocked("uss-modelCredits", accountID)
|
||||
m.mu.Unlock()
|
||||
|
||||
m.publishTopicUpdate(subscribers, buildModelCreditsAppliedUpdates(©Info)...)
|
||||
}
|
||||
|
||||
// SetTrajectoryCallback registers a callback for when the LS pushes trajectory data.
|
||||
func (m *MockExtensionServer) SetTrajectoryCallback(fn func(topic, key string, data []byte)) {
|
||||
m.onTrajectoryUpdate = fn
|
||||
}
|
||||
|
||||
func (m *MockExtensionServer) currentTokenLocked() *TokenInfo {
|
||||
if m.lastAccountID != "" {
|
||||
if info := m.tokens[m.lastAccountID]; info != nil {
|
||||
return info
|
||||
}
|
||||
}
|
||||
for _, info := range m.tokens {
|
||||
return info
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockExtensionServer) currentModelCreditsLocked() *ModelCreditsInfo {
|
||||
if m.lastAccountID != "" {
|
||||
if info := m.credits[m.lastAccountID]; info != nil {
|
||||
return info
|
||||
}
|
||||
}
|
||||
for _, info := range m.credits {
|
||||
return info
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockExtensionServer) tokenForAccountLocked(accountID string) *TokenInfo {
|
||||
if accountID != "" {
|
||||
if info := m.tokens[accountID]; info != nil {
|
||||
return info
|
||||
}
|
||||
}
|
||||
return m.currentTokenLocked()
|
||||
}
|
||||
|
||||
func (m *MockExtensionServer) creditsForAccountLocked(accountID string) *ModelCreditsInfo {
|
||||
if accountID != "" {
|
||||
if info := m.credits[accountID]; info != nil {
|
||||
return info
|
||||
}
|
||||
}
|
||||
return m.currentModelCreditsLocked()
|
||||
}
|
||||
|
||||
func (m *MockExtensionServer) snapshotSubscribersLocked(topic, accountID string) []*stateSubscriber {
|
||||
topicSubs := m.subscribers[topic]
|
||||
if len(topicSubs) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]*stateSubscriber, 0, len(topicSubs))
|
||||
for _, sub := range topicSubs {
|
||||
if sub == nil {
|
||||
continue
|
||||
}
|
||||
if accountID != "" && sub.accountID != "" && sub.accountID != accountID {
|
||||
continue
|
||||
}
|
||||
out = append(out, sub)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (m *MockExtensionServer) publishTopicUpdate(subscribers []*stateSubscriber, updates ...[]byte) {
|
||||
for _, sub := range subscribers {
|
||||
if sub == nil {
|
||||
continue
|
||||
}
|
||||
for _, update := range updates {
|
||||
if len(update) == 0 {
|
||||
continue
|
||||
}
|
||||
payload := append([]byte(nil), update...)
|
||||
select {
|
||||
case sub.updates <- payload:
|
||||
default:
|
||||
m.logger.Warn("dropping USS update", "topic", sub.topic, "account", sub.accountID)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func buildModelCreditsAppliedUpdates(info *ModelCreditsInfo) [][]byte {
|
||||
if info == nil {
|
||||
info = &ModelCreditsInfo{}
|
||||
}
|
||||
minimum := defaultMinimumCreditAmountForUsage
|
||||
if info.MinimumCreditAmountForUsage != nil {
|
||||
minimum = *info.MinimumCreditAmountForUsage
|
||||
}
|
||||
|
||||
updates := make([][]byte, 0, 3)
|
||||
updates = append(updates, buildAppliedUpdate(
|
||||
useAICreditsSentinelKey,
|
||||
buildUSSRowBinary(encodeUSSPrimitiveBoolValue(info.UseAICredits)),
|
||||
))
|
||||
|
||||
if info.UseAICredits {
|
||||
credits := int32(9999)
|
||||
if info.AvailableCredits != nil {
|
||||
credits = *info.AvailableCredits
|
||||
}
|
||||
updates = append(updates, buildAppliedUpdate(
|
||||
availableCreditsSentinelKey,
|
||||
buildUSSRowBinary(encodeUSSPrimitiveInt32Value(credits)),
|
||||
))
|
||||
} else {
|
||||
updates = append(updates, buildAppliedUpdate(availableCreditsSentinelKey, nil))
|
||||
}
|
||||
updates = append(updates, buildAppliedUpdate(
|
||||
minimumCreditAmountForUsageKey,
|
||||
buildUSSRowBinary(encodeUSSPrimitiveInt32Value(minimum)),
|
||||
))
|
||||
|
||||
return updates
|
||||
}
|
||||
|
||||
// Close shuts down the server.
|
||||
func (m *MockExtensionServer) Close() error {
|
||||
return m.server.Close()
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Middleware
|
||||
// ============================================================
|
||||
|
||||
type unaryHandler func(body []byte) []byte
|
||||
type streamHandler func(body []byte, w http.ResponseWriter, r *http.Request)
|
||||
|
||||
// handleUnary wraps a unary RPC handler with CSRF check and proper content-type.
|
||||
func (m *MockExtensionServer) handleUnary(handler unaryHandler) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// CSRF check
|
||||
if !m.checkCSRF(w, r) {
|
||||
return
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
m.logger.Error("read body", "err", err, "path", r.URL.Path)
|
||||
w.Header().Set("Content-Type", "application/proto")
|
||||
w.WriteHeader(200)
|
||||
return
|
||||
}
|
||||
|
||||
// The LS might send with envelope framing (application/connect+proto)
|
||||
// or without (application/proto). Detect and unwrap.
|
||||
ct := r.Header.Get("Content-Type")
|
||||
protoBody := body
|
||||
if strings.Contains(ct, "connect+proto") && len(body) >= 5 {
|
||||
protoBody = unwrapConnectEnvelope(body)
|
||||
}
|
||||
|
||||
m.logger.Debug("unary RPC", "path", r.URL.Path, "body_len", len(protoBody), "content_type", ct)
|
||||
|
||||
responseProto := handler(protoBody)
|
||||
|
||||
// Respond with proper unary ConnectRPC content-type.
|
||||
// If the request used "connect+proto", the response should be "application/proto"
|
||||
// for unary RPCs (ConnectRPC spec: unary uses application/proto, not connect+proto).
|
||||
w.Header().Set("Content-Type", "application/proto")
|
||||
w.WriteHeader(200)
|
||||
if len(responseProto) > 0 {
|
||||
w.Write(responseProto)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleStream wraps a server-streaming RPC handler with CSRF and content-type.
|
||||
func (m *MockExtensionServer) handleStream(handler streamHandler) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if !m.checkCSRF(w, r) {
|
||||
return
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
m.logger.Error("read body", "err", err, "path", r.URL.Path)
|
||||
return
|
||||
}
|
||||
|
||||
// Unwrap envelope from request
|
||||
ct := r.Header.Get("Content-Type")
|
||||
if strings.Contains(ct, "connect+proto") || strings.Contains(ct, "connect+json") {
|
||||
body = unwrapConnectEnvelope(body)
|
||||
}
|
||||
|
||||
m.logger.Debug("stream RPC", "path", r.URL.Path, "body_len", len(body))
|
||||
|
||||
// Set streaming response content-type
|
||||
w.Header().Set("Content-Type", "application/connect+proto")
|
||||
w.WriteHeader(200)
|
||||
|
||||
handler(body, w, r)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockExtensionServer) checkCSRF(w http.ResponseWriter, r *http.Request) bool {
|
||||
token := r.Header.Get("x-codeium-csrf-token")
|
||||
if m.csrf != "" && token != m.csrf {
|
||||
m.logger.Warn("CSRF mismatch", "path", r.URL.Path, "got", token[:min(8, len(token))])
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(403)
|
||||
w.Write([]byte("Invalid CSRF token"))
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func min(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Unary RPC Handlers — each receives raw proto request body,
|
||||
// returns raw proto response body.
|
||||
// ============================================================
|
||||
|
||||
func (m *MockExtensionServer) onLanguageServerStarted(body []byte) []byte {
|
||||
// LanguageServerStartedRequest has: https_port(1), http_port(2), lsp_port(3), csrf_token(4)
|
||||
// We just log the ports — they're informational.
|
||||
m.logger.Info("LanguageServerStarted",
|
||||
"body_len", len(body))
|
||||
// Return empty LanguageServerStartedResponse
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockExtensionServer) onHeartbeat(body []byte) []byte {
|
||||
// Return empty HeartbeatResponse
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockExtensionServer) onGetSecretValue(body []byte) []byte {
|
||||
// GetSecretValueRequest: key = field 1
|
||||
key := decodeProtoString(body, 1)
|
||||
m.logger.Debug("GetSecretValue", "key", key)
|
||||
|
||||
m.mu.RLock()
|
||||
var token string
|
||||
if info := m.currentTokenLocked(); info != nil {
|
||||
token = info.AccessToken
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
|
||||
// GetSecretValueResponse: value = field 1
|
||||
if token != "" {
|
||||
return encodeProtoString(1, token)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockExtensionServer) onStoreSecretValue(body []byte) []byte {
|
||||
key := decodeProtoString(body, 1)
|
||||
m.logger.Debug("StoreSecretValue", "key", key)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockExtensionServer) onIsAgentManagerEnabled(body []byte) []byte {
|
||||
// IsAgentManagerEnabledResponse: enabled = field 1 (bool)
|
||||
return encodeProtoBool(1, false)
|
||||
}
|
||||
|
||||
func (m *MockExtensionServer) onPushUnifiedStateSyncUpdate(body []byte) []byte {
|
||||
// PushUnifiedStateSyncUpdateRequest: update = field 1 (UpdateRequest message)
|
||||
// UpdateRequest: topic_name = field 1, applied_update = field 5, key = field 2
|
||||
m.logger.Debug("PushUnifiedStateSyncUpdate", "body_len", len(body))
|
||||
|
||||
// Extract topic name from the embedded UpdateRequest
|
||||
// The body is PushUnifiedStateSyncUpdateRequest, field 1 is the UpdateRequest
|
||||
// We need to dig into the nested message to get topic_name
|
||||
if m.onTrajectoryUpdate != nil {
|
||||
// For now, just notify that an update was pushed
|
||||
m.onTrajectoryUpdate("", "", body)
|
||||
}
|
||||
|
||||
// Return empty PushUnifiedStateSyncUpdateResponse
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockExtensionServer) onRecordError(body []byte) []byte {
|
||||
m.logger.Debug("RecordError", "body_len", len(body))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockExtensionServer) onLogEvent(body []byte) []byte {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockExtensionServer) onUpdateTrajectorySummaries(body []byte) []byte {
|
||||
m.logger.Debug("UpdateCascadeTrajectorySummaries", "body_len", len(body))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockExtensionServer) onDefault(body []byte) []byte {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Streaming RPC Handlers
|
||||
// ============================================================
|
||||
|
||||
func (m *MockExtensionServer) onSubscribeStateSyncTopic(body []byte, w http.ResponseWriter, r *http.Request) {
|
||||
// SubscribeToUnifiedStateSyncTopicRequest: topic = field 1
|
||||
topic := decodeProtoString(body, 1)
|
||||
m.logger.Info("SubscribeToUnifiedStateSyncTopic", "topic", topic)
|
||||
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
m.logger.Error("ResponseWriter does not support Flush")
|
||||
return
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
accountID := m.lastAccountID
|
||||
subID := m.nextSubID
|
||||
m.nextSubID++
|
||||
sub := &stateSubscriber{
|
||||
id: subID,
|
||||
accountID: accountID,
|
||||
topic: topic,
|
||||
updates: make(chan []byte, 16),
|
||||
}
|
||||
if m.subscribers[topic] == nil {
|
||||
m.subscribers[topic] = make(map[int]*stateSubscriber)
|
||||
}
|
||||
m.subscribers[topic][subID] = sub
|
||||
|
||||
// Build initial state based on topic
|
||||
var topicData []byte
|
||||
switch topic {
|
||||
case "uss-oauth":
|
||||
tokenInfo := m.tokenForAccountLocked(accountID)
|
||||
if tokenInfo != nil {
|
||||
topicData = buildUSSTopicWithOAuth(tokenInfo.AccessToken, tokenInfo.RefreshToken, tokenInfo.ExpiresAt)
|
||||
} else {
|
||||
topicData = buildEmptyTopic()
|
||||
}
|
||||
case "uss-modelCredits":
|
||||
creditsInfo := m.creditsForAccountLocked(accountID)
|
||||
if creditsInfo != nil {
|
||||
topicData = buildUSSTopicWithModelCredits(creditsInfo)
|
||||
} else {
|
||||
topicData = buildEmptyTopic()
|
||||
}
|
||||
default:
|
||||
// For all other topics (browserPreferences, enterprisePreferences, etc.),
|
||||
// return empty topic data.
|
||||
topicData = buildEmptyTopic()
|
||||
}
|
||||
m.mu.Unlock()
|
||||
defer func() {
|
||||
m.mu.Lock()
|
||||
if topicSubs := m.subscribers[topic]; topicSubs != nil {
|
||||
delete(topicSubs, subID)
|
||||
if len(topicSubs) == 0 {
|
||||
delete(m.subscribers, topic)
|
||||
}
|
||||
}
|
||||
m.mu.Unlock()
|
||||
}()
|
||||
|
||||
// Send initial state as envelope-framed message
|
||||
initialUpdate := buildInitialStateUpdate(topicData)
|
||||
frame := connectEnvelope(0x00, initialUpdate)
|
||||
w.Write(frame)
|
||||
flusher.Flush()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-r.Context().Done():
|
||||
m.logger.Debug("SubscribeToUnifiedStateSyncTopic stream closed", "topic", topic)
|
||||
return
|
||||
case update := <-sub.updates:
|
||||
if len(update) == 0 {
|
||||
continue
|
||||
}
|
||||
if _, err := w.Write(connectEnvelope(0x00, update)); err != nil {
|
||||
m.logger.Debug("SubscribeToUnifiedStateSyncTopic write failed", "topic", topic, "err", err)
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockExtensionServer) onExecuteCommand(body []byte, w http.ResponseWriter, r *http.Request) {
|
||||
m.logger.Debug("ExecuteCommand (mock)", "body_len", len(body))
|
||||
// Send end-of-stream immediately — we don't execute commands
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
w.Write(connectEndOfStream())
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Catch-all handler
|
||||
// ============================================================
|
||||
|
||||
func (m *MockExtensionServer) handleCatchAll(w http.ResponseWriter, r *http.Request) {
|
||||
if !m.checkCSRF(w, r) {
|
||||
return
|
||||
}
|
||||
m.logger.Debug("unhandled RPC (returning empty proto)", "path", r.URL.Path, "method", r.Method)
|
||||
|
||||
// Drain request body
|
||||
io.ReadAll(r.Body)
|
||||
|
||||
// Determine if this is likely a unary or streaming request based on content-type.
|
||||
ct := r.Header.Get("Content-Type")
|
||||
if strings.Contains(ct, "connect+") {
|
||||
// Could be streaming — respond with unary proto to be safe
|
||||
// (unary Connect requests can also use connect+ prefix in some client impls)
|
||||
w.Header().Set("Content-Type", "application/proto")
|
||||
} else {
|
||||
w.Header().Set("Content-Type", "application/proto")
|
||||
}
|
||||
w.WriteHeader(200)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,376 +0,0 @@
|
||||
package lspool
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBuildLSEnvKeepsExistingSSLValues(t *testing.T) {
|
||||
env := buildLSEnv([]string{
|
||||
"SSL_CERT_FILE=/custom/ca.pem",
|
||||
"SSL_CERT_DIR=/custom/certs",
|
||||
}, "/opt/antigravity", "")
|
||||
require.Contains(t, env, "ANTIGRAVITY_EDITOR_APP_ROOT=/opt/antigravity")
|
||||
require.Contains(t, env, "SSL_CERT_FILE=/custom/ca.pem")
|
||||
require.Contains(t, env, "SSL_CERT_DIR=/custom/certs")
|
||||
}
|
||||
|
||||
func TestBuildLSEnvClearsInheritedProxyWhenUnset(t *testing.T) {
|
||||
env := buildLSEnv([]string{
|
||||
"HTTPS_PROXY=http://old-proxy:8080",
|
||||
"HTTP_PROXY=http://old-proxy:8080",
|
||||
"ALL_PROXY=socks5://old-proxy:1080",
|
||||
"https_proxy=http://old-proxy:8080",
|
||||
"http_proxy=http://old-proxy:8080",
|
||||
"all_proxy=socks5://old-proxy:1080",
|
||||
}, "/opt/antigravity", "")
|
||||
|
||||
require.Contains(t, env, "HTTPS_PROXY=")
|
||||
require.Contains(t, env, "HTTP_PROXY=")
|
||||
require.Contains(t, env, "ALL_PROXY=")
|
||||
require.Contains(t, env, "https_proxy=")
|
||||
require.Contains(t, env, "http_proxy=")
|
||||
require.Contains(t, env, "all_proxy=")
|
||||
}
|
||||
|
||||
func TestShortAccountID(t *testing.T) {
|
||||
require.Equal(t, "9", shortAccountID("9"))
|
||||
require.Equal(t, "12345678", shortAccountID("12345678"))
|
||||
require.Equal(t, "12345678", shortAccountID("123456789"))
|
||||
}
|
||||
|
||||
func TestFrameConnectMessage(t *testing.T) {
|
||||
framed := frameConnectMessage([]byte(`{"x":1}`))
|
||||
require.Len(t, framed, 5+len(`{"x":1}`))
|
||||
require.Equal(t, byte(0), framed[0])
|
||||
require.Equal(t, uint32(len(`{"x":1}`)), binary.BigEndian.Uint32(framed[1:5]))
|
||||
require.Equal(t, `{"x":1}`, string(framed[5:]))
|
||||
}
|
||||
|
||||
func TestConnectEnvelope(t *testing.T) {
|
||||
payload := []byte("hello")
|
||||
env := connectEnvelope(0x00, payload)
|
||||
require.Len(t, env, 5+len(payload))
|
||||
require.Equal(t, byte(0x00), env[0])
|
||||
require.Equal(t, uint32(5), binary.BigEndian.Uint32(env[1:5]))
|
||||
require.Equal(t, "hello", string(env[5:]))
|
||||
}
|
||||
|
||||
func TestUnwrapConnectEnvelope(t *testing.T) {
|
||||
payload := []byte("test data")
|
||||
env := connectEnvelope(0x00, payload)
|
||||
unwrapped := unwrapConnectEnvelope(env)
|
||||
require.Equal(t, payload, unwrapped)
|
||||
short := []byte{1, 2}
|
||||
require.Equal(t, short, unwrapConnectEnvelope(short))
|
||||
}
|
||||
|
||||
func TestExtractPromptAndModel(t *testing.T) {
|
||||
body := `{"model":"gemini-2.5-pro","project":"p","request":{"contents":[{"role":"user","parts":[{"text":"hello world"}]}]}}`
|
||||
prompt, model := extractPromptAndModel([]byte(body))
|
||||
require.Equal(t, "hello world", prompt)
|
||||
require.Equal(t, "gemini-2.5-pro", model)
|
||||
|
||||
body2 := `{"contents":[{"role":"user","parts":[{"text":"test prompt"}]}]}`
|
||||
prompt2, _ := extractPromptAndModel([]byte(body2))
|
||||
require.Equal(t, "test prompt", prompt2)
|
||||
}
|
||||
|
||||
func TestResolveModelEnum(t *testing.T) {
|
||||
// Without dynamic mapping loaded, should return fallback (312 = gemini-2.5-flash)
|
||||
require.True(t, resolveModelEnum("gemini-2.5-flash") > 0)
|
||||
require.True(t, resolveModelEnum("models/gemini-2.5-flash") > 0)
|
||||
require.True(t, resolveModelEnum("claude-sonnet-4-6") > 0)
|
||||
require.True(t, resolveModelEnum("unknown-model") > 0)
|
||||
}
|
||||
|
||||
func TestBuildCascadeConfigIncludesRequestedModel(t *testing.T) {
|
||||
cfg := buildCascadeConfig("models/gemini-2.5-flash")
|
||||
require.NotNil(t, cfg)
|
||||
|
||||
plannerConfig, ok := cfg["plannerConfig"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
requestedModel, ok := plannerConfig["requestedModel"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.NotEmpty(t, requestedModel["model"])
|
||||
require.Len(t, plannerConfig, 1)
|
||||
}
|
||||
|
||||
func TestBuildCascadeConfigClaudeIncludesRequestedModel(t *testing.T) {
|
||||
cfg := buildCascadeConfig("claude-sonnet-4-6")
|
||||
require.NotNil(t, cfg)
|
||||
|
||||
plannerConfig, ok := cfg["plannerConfig"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
requestedModel, ok := plannerConfig["requestedModel"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.NotEmpty(t, requestedModel["model"])
|
||||
require.Len(t, plannerConfig, 1)
|
||||
}
|
||||
|
||||
func TestDoNonStreamGeneratePassesThrough(t *testing.T) {
|
||||
fallback := &recordingUpstream{}
|
||||
upstream := NewLSPoolUpstream(&Pool{}, fallback)
|
||||
req, _ := http.NewRequest("POST", "https://example.com/v1beta/models/gemini:generateContent", bytes.NewReader([]byte(`{}`)))
|
||||
resp, err := upstream.Do(req, "", 1, 1)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, 1, fallback.doCalls)
|
||||
}
|
||||
|
||||
func TestExtractPlannerResponseText(t *testing.T) {
|
||||
resp := `{"status":"CASCADE_RUN_STATUS_IDLE","trajectory":{"steps":[
|
||||
{"type":"CORTEX_STEP_TYPE_USER_INPUT","status":"CORTEX_STEP_STATUS_DONE"},
|
||||
{"type":"CORTEX_STEP_TYPE_PLANNER_RESPONSE","status":"CORTEX_STEP_STATUS_DONE",
|
||||
"plannerResponse":{"response":"Hello world"}}
|
||||
]}}`
|
||||
text, generating, status := extractPlannerResponseText([]byte(resp))
|
||||
require.Equal(t, "Hello world", text)
|
||||
require.False(t, generating)
|
||||
require.Equal(t, "CASCADE_RUN_STATUS_IDLE", status)
|
||||
}
|
||||
|
||||
func TestExtractPlannerResponseState_ErrorDetails(t *testing.T) {
|
||||
resp := `{
|
||||
"status":"CASCADE_RUN_STATUS_IDLE",
|
||||
"trajectory":{
|
||||
"steps":[
|
||||
{"type":"CORTEX_STEP_TYPE_USER_INPUT","status":"CORTEX_STEP_STATUS_DONE"}
|
||||
],
|
||||
"executorMetadata":{
|
||||
"terminationReason":"ERROR",
|
||||
"errorDetails":{
|
||||
"errorCode":429,
|
||||
"shortError":"Model quota reached",
|
||||
"details":"You have exhausted your capacity on this model. Your quota will reset after 1h59m40s."
|
||||
}
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
state := extractPlannerResponseState([]byte(resp))
|
||||
require.Equal(t, "CASCADE_RUN_STATUS_IDLE", state.Status)
|
||||
require.False(t, state.Generating)
|
||||
require.Empty(t, state.Text)
|
||||
require.Contains(t, state.ErrorMessage, "Model quota reached")
|
||||
require.Contains(t, state.ErrorMessage, "quota will reset after")
|
||||
}
|
||||
|
||||
func TestBuildGeminiSSEChunk(t *testing.T) {
|
||||
sse := buildGeminiSSEChunk("hello")
|
||||
require.Contains(t, sse, "data: ")
|
||||
require.Contains(t, sse, `"text":"hello"`)
|
||||
require.Contains(t, sse, `"role":"model"`)
|
||||
require.True(t, strings.HasSuffix(sse, "\n\n"))
|
||||
}
|
||||
|
||||
func TestRequestHasTools(t *testing.T) {
|
||||
// Wrapped format with tools
|
||||
require.True(t, requestHasTools([]byte(`{"model":"m","project":"p","request":{"contents":[],"tools":[{"functionDeclarations":[{"name":"get_weather"}]}]}}`)))
|
||||
|
||||
// Direct format with tools
|
||||
require.True(t, requestHasTools([]byte(`{"contents":[],"tools":[{"functionDeclarations":[{"name":"f"}]}]}`)))
|
||||
|
||||
// No tools
|
||||
require.False(t, requestHasTools([]byte(`{"model":"m","project":"p","request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`)))
|
||||
|
||||
// Empty tools array
|
||||
require.False(t, requestHasTools([]byte(`{"contents":[],"tools":[]}`)))
|
||||
}
|
||||
|
||||
func TestCurrentLSStrategy(t *testing.T) {
|
||||
t.Setenv("ANTIGRAVITY_LS_STRATEGY", "js-parity")
|
||||
require.Equal(t, LSStrategyJSParity, CurrentLSStrategy())
|
||||
|
||||
t.Setenv("ANTIGRAVITY_LS_STRATEGY", "unknown")
|
||||
require.Equal(t, LSStrategyDirect, CurrentLSStrategy())
|
||||
}
|
||||
|
||||
func TestIsPermanentModelMappingError(t *testing.T) {
|
||||
require.True(t, isPermanentModelMappingError(errors.New(`oauth2: "unauthorized_client" "Unauthorized"`)))
|
||||
require.False(t, isPermanentModelMappingError(errors.New("context deadline exceeded")))
|
||||
}
|
||||
|
||||
func TestPoolSetAccountTokenClearsModelMappingUnavailable(t *testing.T) {
|
||||
pool := &Pool{
|
||||
instances: map[string][]*Instance{
|
||||
"9": {
|
||||
{AccountID: "9", Replica: 0},
|
||||
},
|
||||
},
|
||||
}
|
||||
inst := pool.instances["9"][0]
|
||||
inst.SetModelMappingReady(true)
|
||||
inst.SetModelMappingUnavailable(`oauth2: "unauthorized_client" "Unauthorized"`)
|
||||
|
||||
pool.SetAccountToken("9", "ya29.new", "refresh", time.Now().Add(time.Hour))
|
||||
|
||||
require.False(t, inst.HasModelMappingReady())
|
||||
require.False(t, inst.HasModelMappingUnavailable())
|
||||
require.Empty(t, inst.ModelMappingUnavailableReason())
|
||||
}
|
||||
|
||||
func TestShouldFallbackDirectForModelMappingUnavailable(t *testing.T) {
|
||||
require.True(t, shouldFallbackDirect(fmt.Errorf("%w: oauth2 unauthorized_client", errLSModelMapDenied)))
|
||||
require.False(t, shouldFallbackDirect(errLSModelMapPending))
|
||||
}
|
||||
|
||||
func TestParseLSReplicaCountDefaultAndEnv(t *testing.T) {
|
||||
t.Setenv("ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT", "")
|
||||
require.Equal(t, 5, parseLSReplicaCount())
|
||||
|
||||
t.Setenv("ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT", "3")
|
||||
require.Equal(t, 3, parseLSReplicaCount())
|
||||
|
||||
t.Setenv("ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT", "0")
|
||||
require.Equal(t, 5, parseLSReplicaCount())
|
||||
}
|
||||
|
||||
func TestPoolGetUsesStickyReplicaSlot(t *testing.T) {
|
||||
pool := &Pool{
|
||||
config: Config{ReplicasPerAccount: 5},
|
||||
instances: map[string][]*Instance{
|
||||
"acc-1": {
|
||||
{AccountID: "acc-1", Replica: 0, healthy: true},
|
||||
{AccountID: "acc-1", Replica: 1, healthy: true},
|
||||
{AccountID: "acc-1", Replica: 2, healthy: true},
|
||||
{AccountID: "acc-1", Replica: 3, healthy: true},
|
||||
{AccountID: "acc-1", Replica: 4, healthy: true},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
routingKey := "acc-1:user-a:session-1"
|
||||
slot := replicaSlotIndex(routingKey, pool.replicaCount())
|
||||
inst := pool.Get("acc-1", routingKey)
|
||||
require.NotNil(t, inst)
|
||||
require.Equal(t, slot, inst.Replica)
|
||||
}
|
||||
|
||||
func TestPoolGetWithoutRoutingKeyPrefersLeastBusyReplica(t *testing.T) {
|
||||
busy := &Instance{AccountID: "acc-1", Replica: 0, healthy: true}
|
||||
atomic.StoreInt64(&busy.inflight, 4)
|
||||
idle := &Instance{AccountID: "acc-1", Replica: 1, healthy: true}
|
||||
atomic.StoreInt64(&idle.inflight, 1)
|
||||
|
||||
pool := &Pool{
|
||||
config: Config{ReplicasPerAccount: 5},
|
||||
instances: map[string][]*Instance{
|
||||
"acc-1": {busy, idle},
|
||||
},
|
||||
}
|
||||
|
||||
inst := pool.Get("acc-1", "")
|
||||
require.NotNil(t, inst)
|
||||
require.Equal(t, 1, inst.Replica)
|
||||
}
|
||||
|
||||
func TestWaitForInstanceReadyProbesImmediately(t *testing.T) {
|
||||
startedAt := time.Now()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
attempts, err := waitForInstanceReady(ctx, 200*time.Millisecond, func(context.Context) error {
|
||||
return nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, attempts)
|
||||
require.Less(t, time.Since(startedAt), 100*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestWaitForInstanceReadyRetriesUntilSuccess(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
calls := 0
|
||||
attempts, err := waitForInstanceReady(ctx, 10*time.Millisecond, func(context.Context) error {
|
||||
calls++
|
||||
if calls < 3 {
|
||||
return errors.New("not ready")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 3, attempts)
|
||||
require.Equal(t, 3, calls)
|
||||
}
|
||||
|
||||
func TestDecideJSParityRoute(t *testing.T) {
|
||||
body := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"s1","contents":[{"role":"user","parts":[{"text":"hello"}]}]}}`)
|
||||
parsed, err := parseGeminiRequest(body)
|
||||
require.NoError(t, err)
|
||||
decision := decideJSParityRoute(parsed, body)
|
||||
require.True(t, decision.UseLS)
|
||||
|
||||
imageBody := []byte(`{"model":"gemini-2.5-flash-image","request":{"sessionId":"s1","contents":[{"role":"user","parts":[{"text":"draw"}]}],"generationConfig":{"responseModalities":["TEXT","IMAGE"]}}}`)
|
||||
parsedImage, err := parseGeminiRequest(imageBody)
|
||||
require.NoError(t, err)
|
||||
decisionImage := decideJSParityRoute(parsedImage, imageBody)
|
||||
require.False(t, decisionImage.UseLS)
|
||||
require.Contains(t, strings.ToLower(decisionImage.Reason), "image")
|
||||
|
||||
noSessionBody := []byte(`{"model":"gemini-2.5-flash","request":{"contents":[{"role":"user","parts":[{"text":"hello"}]}]}}`)
|
||||
parsedNoSession, err := parseGeminiRequest(noSessionBody)
|
||||
require.NoError(t, err)
|
||||
require.False(t, decideJSParityRoute(parsedNoSession, noSessionBody).UseLS)
|
||||
}
|
||||
|
||||
func TestUserNamespacePrefersExplicitHeader(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set(userNamespaceHeader, "tenant-a")
|
||||
req.Header.Set("Authorization", "Bearer oauth-token")
|
||||
|
||||
nsWithExplicit := userNamespace(req)
|
||||
require.NotEqual(t, "anon", nsWithExplicit)
|
||||
|
||||
req.Header.Del(userNamespaceHeader)
|
||||
nsWithAuth := userNamespace(req)
|
||||
require.NotEqual(t, "anon", nsWithAuth)
|
||||
require.NotEqual(t, nsWithExplicit, nsWithAuth)
|
||||
}
|
||||
|
||||
func TestConversationPrefixEqual(t *testing.T) {
|
||||
prefix := []geminiConversationTurn{
|
||||
{Role: "user", Parts: []geminiConversationPart{{Kind: "text", Text: "hello"}}},
|
||||
{Role: "model", Parts: []geminiConversationPart{{Kind: "text", Text: "world"}}},
|
||||
}
|
||||
full := append(cloneConversationTurns(prefix), geminiConversationTurn{
|
||||
Role: "user",
|
||||
Parts: []geminiConversationPart{{Kind: "text", Text: "follow up"}},
|
||||
})
|
||||
require.True(t, conversationPrefixEqual(full, prefix))
|
||||
require.False(t, conversationPrefixEqual(prefix, full))
|
||||
}
|
||||
|
||||
func TestSystemTextCompatible(t *testing.T) {
|
||||
require.True(t, systemTextCompatible("You are helpful", ""))
|
||||
require.True(t, systemTextCompatible("You are helpful", "You are helpful"))
|
||||
require.False(t, systemTextCompatible("", "You are helpful"))
|
||||
require.False(t, systemTextCompatible("You are helpful", "You are different"))
|
||||
}
|
||||
|
||||
type recordingUpstream struct {
|
||||
doCalls int
|
||||
}
|
||||
|
||||
func (r *recordingUpstream) Do(req *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
|
||||
r.doCalls++
|
||||
return &http.Response{StatusCode: 200, Body: io.NopCloser(bytes.NewBufferString("ok")), Header: make(http.Header), Request: req}, nil
|
||||
}
|
||||
|
||||
func (r *recordingUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, c int, _ *tlsfingerprint.Profile) (*http.Response, error) {
|
||||
return r.Do(req, proxyURL, accountID, c)
|
||||
}
|
||||
@ -1,268 +0,0 @@
|
||||
package lspool
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
type lsProxyBridge struct {
|
||||
listener net.Listener
|
||||
server *http.Server
|
||||
url string
|
||||
upstream string
|
||||
}
|
||||
|
||||
type lsProxyBridgeManager struct {
|
||||
mu sync.Mutex
|
||||
bridges map[string]*lsProxyBridge
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
var globalLSProxyBridgeManager = &lsProxyBridgeManager{
|
||||
bridges: make(map[string]*lsProxyBridge),
|
||||
logger: slog.Default().With("component", "lspool-proxy-bridge"),
|
||||
}
|
||||
|
||||
var (
|
||||
lsProxyBridgeDialTimeout = 10 * time.Second
|
||||
lsProxyBridgeProbeTargets = []string{
|
||||
"cloudcode-pa.googleapis.com:443",
|
||||
"oauthaccountmanager.googleapis.com:443",
|
||||
}
|
||||
)
|
||||
|
||||
func prepareLSProxyURL(raw string) (string, error) {
|
||||
normalized, parsed, err := proxyurl.Parse(raw)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if parsed == nil {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
switch strings.ToLower(parsed.Scheme) {
|
||||
case "http", "https":
|
||||
return normalized, nil
|
||||
case "socks5", "socks5h":
|
||||
return globalLSProxyBridgeManager.ensure(normalized, parsed)
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported LS proxy scheme: %s", parsed.Scheme)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *lsProxyBridgeManager) ensure(key string, upstream *url.URL) (string, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if bridge := m.bridges[key]; bridge != nil {
|
||||
return bridge.url, nil
|
||||
}
|
||||
|
||||
bridge, err := newLSProxyBridge(upstream, m.logger)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
m.bridges[key] = bridge
|
||||
return bridge.url, nil
|
||||
}
|
||||
|
||||
func (m *lsProxyBridgeManager) closeAll() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
for key, bridge := range m.bridges {
|
||||
if bridge != nil {
|
||||
_ = bridge.server.Close()
|
||||
_ = bridge.listener.Close()
|
||||
}
|
||||
delete(m.bridges, key)
|
||||
}
|
||||
}
|
||||
|
||||
func closeAllLSProxyBridgesForTest() {
|
||||
globalLSProxyBridgeManager.closeAll()
|
||||
}
|
||||
|
||||
func newLSProxyBridge(upstream *url.URL, logger *slog.Logger) (*lsProxyBridge, error) {
|
||||
dialer, err := proxy.FromURL(upstream, proxy.Direct)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create SOCKS dialer: %w", err)
|
||||
}
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("listen LS proxy bridge: %w", err)
|
||||
}
|
||||
|
||||
bridge := &lsProxyBridge{
|
||||
listener: listener,
|
||||
url: "http://" + listener.Addr().String(),
|
||||
upstream: upstream.Redacted(),
|
||||
}
|
||||
|
||||
server := &http.Server{
|
||||
Handler: http.HandlerFunc(bridge.connectHandler(dialer, logger)),
|
||||
ReadHeaderTimeout: 10 * time.Second,
|
||||
IdleTimeout: 2 * time.Minute,
|
||||
}
|
||||
bridge.server = server
|
||||
|
||||
go func() {
|
||||
if err := server.Serve(listener); err != nil && err != http.ErrServerClosed {
|
||||
logger.Error("LS proxy bridge serve failed", "upstream", bridge.upstream, "err", err)
|
||||
}
|
||||
}()
|
||||
|
||||
logger.Info("LS proxy bridge started", "upstream", bridge.upstream, "listen", bridge.url)
|
||||
go bridge.probeConnectivity(dialer, logger)
|
||||
return bridge, nil
|
||||
}
|
||||
|
||||
func (b *lsProxyBridge) connectHandler(dialer proxy.Dialer, logger *slog.Logger) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodConnect {
|
||||
http.Error(w, "CONNECT only", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
targetAddr := strings.TrimSpace(r.Host)
|
||||
if targetAddr == "" {
|
||||
targetAddr = strings.TrimSpace(r.URL.Host)
|
||||
}
|
||||
if targetAddr == "" {
|
||||
http.Error(w, "missing target host", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if _, _, err := net.SplitHostPort(targetAddr); err != nil {
|
||||
targetAddr = net.JoinHostPort(targetAddr, "443")
|
||||
}
|
||||
|
||||
startedAt := time.Now()
|
||||
logger.Info("LS proxy bridge CONNECT", "upstream", b.upstream, "target", targetAddr)
|
||||
|
||||
dialCtx, cancel := context.WithTimeout(r.Context(), lsProxyBridgeDialTimeout)
|
||||
defer cancel()
|
||||
|
||||
targetConn, err := dialViaProxy(dialCtx, dialer, targetAddr)
|
||||
if err != nil {
|
||||
logger.Warn("LS proxy bridge dial failed",
|
||||
"upstream", b.upstream,
|
||||
"target", targetAddr,
|
||||
"elapsed", time.Since(startedAt).Truncate(time.Millisecond),
|
||||
"err", err)
|
||||
http.Error(w, "proxy dial failed", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
logger.Info("LS proxy bridge CONNECT established",
|
||||
"upstream", b.upstream,
|
||||
"target", targetAddr,
|
||||
"elapsed", time.Since(startedAt).Truncate(time.Millisecond))
|
||||
|
||||
hijacker, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
_ = targetConn.Close()
|
||||
http.Error(w, "hijack unsupported", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
clientConn, rw, err := hijacker.Hijack()
|
||||
if err != nil {
|
||||
_ = targetConn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := clientConn.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n")); err != nil {
|
||||
_ = targetConn.Close()
|
||||
_ = clientConn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
if rw != nil && rw.Reader.Buffered() > 0 {
|
||||
if _, err := io.CopyN(targetConn, rw, int64(rw.Reader.Buffered())); err != nil {
|
||||
_ = targetConn.Close()
|
||||
_ = clientConn.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
tunnelConns(clientConn, targetConn)
|
||||
}
|
||||
}
|
||||
|
||||
func dialViaProxy(ctx context.Context, dialer proxy.Dialer, targetAddr string) (net.Conn, error) {
|
||||
if contextDialer, ok := dialer.(proxy.ContextDialer); ok {
|
||||
return contextDialer.DialContext(ctx, "tcp", targetAddr)
|
||||
}
|
||||
|
||||
type dialResult struct {
|
||||
conn net.Conn
|
||||
err error
|
||||
}
|
||||
resultCh := make(chan dialResult, 1)
|
||||
go func() {
|
||||
conn, err := dialer.Dial("tcp", targetAddr)
|
||||
resultCh <- dialResult{conn: conn, err: err}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case result := <-resultCh:
|
||||
return result.conn, result.err
|
||||
}
|
||||
}
|
||||
|
||||
func (b *lsProxyBridge) probeConnectivity(dialer proxy.Dialer, logger *slog.Logger) {
|
||||
for _, targetAddr := range lsProxyBridgeProbeTargets {
|
||||
startedAt := time.Now()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), lsProxyBridgeDialTimeout)
|
||||
conn, err := dialViaProxy(ctx, dialer, targetAddr)
|
||||
cancel()
|
||||
if err != nil {
|
||||
logger.Warn("LS proxy bridge probe failed",
|
||||
"upstream", b.upstream,
|
||||
"target", targetAddr,
|
||||
"elapsed", time.Since(startedAt).Truncate(time.Millisecond),
|
||||
"err", err)
|
||||
continue
|
||||
}
|
||||
_ = conn.Close()
|
||||
logger.Info("LS proxy bridge probe succeeded",
|
||||
"upstream", b.upstream,
|
||||
"target", targetAddr,
|
||||
"elapsed", time.Since(startedAt).Truncate(time.Millisecond))
|
||||
}
|
||||
}
|
||||
|
||||
func tunnelConns(clientConn net.Conn, targetConn net.Conn) {
|
||||
var once sync.Once
|
||||
closeBoth := func() {
|
||||
_ = clientConn.Close()
|
||||
_ = targetConn.Close()
|
||||
}
|
||||
|
||||
go func() {
|
||||
_, _ = io.Copy(targetConn, clientConn)
|
||||
once.Do(closeBoth)
|
||||
}()
|
||||
go func() {
|
||||
_, _ = io.Copy(clientConn, targetConn)
|
||||
once.Do(closeBoth)
|
||||
}()
|
||||
}
|
||||
|
||||
func readConnectResponse(br *bufio.Reader) (*http.Response, error) {
|
||||
return http.ReadResponse(br, &http.Request{Method: http.MethodConnect})
|
||||
}
|
||||
@ -1,193 +0,0 @@
|
||||
package lspool
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPrepareLSProxyURLPassesThroughHTTPProxy(t *testing.T) {
|
||||
t.Cleanup(closeAllLSProxyBridgesForTest)
|
||||
|
||||
got, err := prepareLSProxyURL("http://proxy.example.com:8080")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "http://proxy.example.com:8080", got)
|
||||
}
|
||||
|
||||
func TestPrepareLSProxyURLBridgesSOCKS5ForLS(t *testing.T) {
|
||||
t.Cleanup(closeAllLSProxyBridgesForTest)
|
||||
|
||||
targetAddr, closeTarget := startBridgeEchoServer(t)
|
||||
defer closeTarget()
|
||||
|
||||
socksURL, closeSOCKS := startBridgeSOCKS5Server(t)
|
||||
defer closeSOCKS()
|
||||
|
||||
bridgeURL, err := prepareLSProxyURL(socksURL)
|
||||
require.NoError(t, err)
|
||||
require.True(t, strings.HasPrefix(bridgeURL, "http://127.0.0.1:"))
|
||||
|
||||
// Same SOCKS upstream should reuse the same local bridge.
|
||||
reusedURL, err := prepareLSProxyURL(socksURL)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, bridgeURL, reusedURL)
|
||||
|
||||
bridgeAddr := strings.TrimPrefix(bridgeURL, "http://")
|
||||
conn, err := net.Dial("tcp", bridgeAddr)
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
_, err = fmt.Fprintf(conn, "CONNECT %s HTTP/1.1\r\nHost: %s\r\n\r\n", targetAddr, targetAddr)
|
||||
require.NoError(t, err)
|
||||
|
||||
reader := bufio.NewReader(conn)
|
||||
resp, err := readConnectResponse(reader)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 200, resp.StatusCode)
|
||||
|
||||
_, err = conn.Write([]byte("ping"))
|
||||
require.NoError(t, err)
|
||||
|
||||
reply := make([]byte, 4)
|
||||
_, err = io.ReadFull(reader, reply)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "pong", string(reply))
|
||||
}
|
||||
|
||||
func startBridgeEchoServer(t *testing.T) (string, func()) {
|
||||
t.Helper()
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
for {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go func(c net.Conn) {
|
||||
defer c.Close()
|
||||
buf := make([]byte, 4)
|
||||
if _, err := io.ReadFull(c, buf); err != nil {
|
||||
return
|
||||
}
|
||||
if string(buf) == "ping" {
|
||||
_, _ = c.Write([]byte("pong"))
|
||||
}
|
||||
}(conn)
|
||||
}
|
||||
}()
|
||||
|
||||
return ln.Addr().String(), func() {
|
||||
_ = ln.Close()
|
||||
<-done
|
||||
}
|
||||
}
|
||||
|
||||
func startBridgeSOCKS5Server(t *testing.T) (string, func()) {
|
||||
t.Helper()
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
for {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go handleBridgeSOCKS5Conn(conn)
|
||||
}
|
||||
}()
|
||||
|
||||
return "socks5://" + ln.Addr().String(), func() {
|
||||
_ = ln.Close()
|
||||
<-done
|
||||
}
|
||||
}
|
||||
|
||||
func handleBridgeSOCKS5Conn(conn net.Conn) {
|
||||
header := make([]byte, 2)
|
||||
if _, err := io.ReadFull(conn, header); err != nil {
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
methods := make([]byte, int(header[1]))
|
||||
if _, err := io.ReadFull(conn, methods); err != nil {
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
_, _ = conn.Write([]byte{0x05, 0x00})
|
||||
|
||||
reqHeader := make([]byte, 4)
|
||||
if _, err := io.ReadFull(conn, reqHeader); err != nil {
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
if reqHeader[0] != 0x05 || reqHeader[1] != 0x01 {
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
targetHost, ok := readSOCKS5Addr(conn, reqHeader[3])
|
||||
if !ok {
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
portBuf := make([]byte, 2)
|
||||
if _, err := io.ReadFull(conn, portBuf); err != nil {
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
targetAddr := fmt.Sprintf("%s:%d", targetHost, binary.BigEndian.Uint16(portBuf))
|
||||
|
||||
targetConn, err := net.Dial("tcp", targetAddr)
|
||||
if err != nil {
|
||||
_, _ = conn.Write([]byte{0x05, 0x01, 0x00, 0x01, 0, 0, 0, 0, 0, 0})
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
_, _ = conn.Write([]byte{0x05, 0x00, 0x00, 0x01, 0, 0, 0, 0, 0, 0})
|
||||
tunnelConns(conn, targetConn)
|
||||
}
|
||||
|
||||
func readSOCKS5Addr(conn net.Conn, atyp byte) (string, bool) {
|
||||
switch atyp {
|
||||
case 0x01:
|
||||
buf := make([]byte, 4)
|
||||
if _, err := io.ReadFull(conn, buf); err != nil {
|
||||
return "", false
|
||||
}
|
||||
return net.IP(buf).String(), true
|
||||
case 0x03:
|
||||
lenBuf := make([]byte, 1)
|
||||
if _, err := io.ReadFull(conn, lenBuf); err != nil {
|
||||
return "", false
|
||||
}
|
||||
buf := make([]byte, int(lenBuf[0]))
|
||||
if _, err := io.ReadFull(conn, buf); err != nil {
|
||||
return "", false
|
||||
}
|
||||
return string(buf), true
|
||||
case 0x04:
|
||||
buf := make([]byte, 16)
|
||||
if _, err := io.ReadFull(conn, buf); err != nil {
|
||||
return "", false
|
||||
}
|
||||
return net.IP(buf).String(), true
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
@ -1,138 +0,0 @@
|
||||
package lspool
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
|
||||
)
|
||||
|
||||
type lsLaunchPlan struct {
|
||||
cmd *exec.Cmd
|
||||
effectiveProxyURL string
|
||||
proxyMode string
|
||||
cleanup func()
|
||||
}
|
||||
|
||||
func prepareLSLaunchPlan(binPath string, args []string, rawProxyURL string) (*lsLaunchPlan, error) {
|
||||
normalized, parsed, err := proxyurl.Parse(rawProxyURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
plan := &lsLaunchPlan{
|
||||
cmd: exec.Command(binPath, args...),
|
||||
proxyMode: "direct",
|
||||
}
|
||||
|
||||
if parsed == nil {
|
||||
return plan, nil
|
||||
}
|
||||
|
||||
switch strings.ToLower(parsed.Scheme) {
|
||||
case "http", "https":
|
||||
plan.effectiveProxyURL = normalized
|
||||
plan.proxyMode = "env-http-proxy"
|
||||
return plan, nil
|
||||
|
||||
case "socks5", "socks5h":
|
||||
if proxychainsPath, err := exec.LookPath("proxychains4"); err == nil {
|
||||
cfgPath, err := writeProxychainsConfig(parsed)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
plan.cmd = exec.Command(proxychainsPath, append([]string{"-f", cfgPath, binPath}, args...)...)
|
||||
plan.proxyMode = "proxychains4"
|
||||
plan.cleanup = func() {
|
||||
_ = os.Remove(cfgPath)
|
||||
}
|
||||
return plan, nil
|
||||
}
|
||||
|
||||
effectiveProxyURL, err := prepareLSProxyURL(normalized)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
plan.effectiveProxyURL = effectiveProxyURL
|
||||
plan.proxyMode = "http-connect-bridge"
|
||||
return plan, nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported LS proxy scheme: %s", parsed.Scheme)
|
||||
}
|
||||
}
|
||||
|
||||
func writeProxychainsConfig(proxyURL *url.URL) (string, error) {
|
||||
content, err := buildProxychainsConfig(proxyURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
file, err := os.CreateTemp("", "sub2api-proxychains-*.conf")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create proxychains config: %w", err)
|
||||
}
|
||||
|
||||
if _, err := file.WriteString(content); err != nil {
|
||||
_ = file.Close()
|
||||
_ = os.Remove(file.Name())
|
||||
return "", fmt.Errorf("write proxychains config: %w", err)
|
||||
}
|
||||
if err := file.Close(); err != nil {
|
||||
_ = os.Remove(file.Name())
|
||||
return "", fmt.Errorf("close proxychains config: %w", err)
|
||||
}
|
||||
|
||||
return file.Name(), nil
|
||||
}
|
||||
|
||||
func buildProxychainsConfig(proxyURL *url.URL) (string, error) {
|
||||
if proxyURL == nil {
|
||||
return "", fmt.Errorf("proxy url is nil")
|
||||
}
|
||||
if scheme := strings.ToLower(proxyURL.Scheme); scheme != "socks5" && scheme != "socks5h" {
|
||||
return "", fmt.Errorf("proxychains only supports socks5/socks5h, got %s", proxyURL.Scheme)
|
||||
}
|
||||
|
||||
host := strings.TrimSpace(proxyURL.Hostname())
|
||||
port := strings.TrimSpace(proxyURL.Port())
|
||||
if host == "" {
|
||||
return "", fmt.Errorf("proxy host is empty")
|
||||
}
|
||||
if port == "" {
|
||||
port = "1080"
|
||||
}
|
||||
|
||||
username := proxyURL.User.Username()
|
||||
password, _ := proxyURL.User.Password()
|
||||
if strings.ContainsAny(username, " \t\r\n") || strings.ContainsAny(password, " \t\r\n") {
|
||||
return "", fmt.Errorf("proxychains credentials cannot contain whitespace")
|
||||
}
|
||||
|
||||
var builder strings.Builder
|
||||
builder.WriteString("strict_chain\n")
|
||||
builder.WriteString("proxy_dns\n")
|
||||
builder.WriteString("remote_dns_subnet 224\n")
|
||||
builder.WriteString("tcp_connect_time_out 8000\n")
|
||||
builder.WriteString("tcp_read_time_out 15000\n")
|
||||
builder.WriteString("localnet 127.0.0.0/255.0.0.0\n")
|
||||
builder.WriteString("localnet ::1/128\n")
|
||||
builder.WriteString("[ProxyList]\n")
|
||||
builder.WriteString("socks5 ")
|
||||
builder.WriteString(host)
|
||||
builder.WriteString(" ")
|
||||
builder.WriteString(port)
|
||||
if username != "" {
|
||||
builder.WriteString(" ")
|
||||
builder.WriteString(username)
|
||||
if password != "" {
|
||||
builder.WriteString(" ")
|
||||
builder.WriteString(password)
|
||||
}
|
||||
}
|
||||
builder.WriteString("\n")
|
||||
return builder.String(), nil
|
||||
}
|
||||
@ -1,31 +0,0 @@
|
||||
package lspool
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBuildProxychainsConfigIncludesAuthAndLocalBypass(t *testing.T) {
|
||||
proxyURL, err := url.Parse("socks5h://testuser:testpass@192.0.2.1:1080")
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg, err := buildProxychainsConfig(proxyURL)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, cfg, "proxy_dns\n")
|
||||
require.Contains(t, cfg, "localnet 127.0.0.0/255.0.0.0\n")
|
||||
require.Contains(t, cfg, "localnet ::1/128\n")
|
||||
require.Contains(t, cfg, "[ProxyList]\n")
|
||||
require.Contains(t, cfg, "socks5 192.0.2.1 1080 testuser testpass\n")
|
||||
}
|
||||
|
||||
func TestBuildProxychainsConfigRejectsWhitespaceCredentials(t *testing.T) {
|
||||
proxyURL, err := url.Parse("socks5h://user:bad%20pass@127.0.0.1:1080")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = buildProxychainsConfig(proxyURL)
|
||||
require.Error(t, err)
|
||||
require.True(t, strings.Contains(err.Error(), "whitespace"))
|
||||
}
|
||||
@ -1,99 +0,0 @@
|
||||
package lspool
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
func (i *Instance) callWorkerUnary(ctx context.Context, service, method, mode string, body []byte) ([]byte, error) {
|
||||
endpoint, err := i.workerEndpoint("/rpc/unary", service, method, mode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("X-Worker-Token", i.workerToken)
|
||||
if mode == "json" {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
} else {
|
||||
req.Header.Set("Content-Type", "application/octet-stream")
|
||||
}
|
||||
|
||||
resp, err := i.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("worker rpc %s/%s: %w", service, method, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("worker rpc read response: %w", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return respBody, fmt.Errorf("worker rpc %s/%s HTTP %d: %s", service, method, resp.StatusCode, truncate(string(respBody), 200))
|
||||
}
|
||||
return respBody, nil
|
||||
}
|
||||
|
||||
func (i *Instance) callWorkerStream(ctx context.Context, service, method, mode string, body []byte) (*http.Response, error) {
|
||||
endpoint, err := i.workerEndpoint("/rpc/stream", service, method, mode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("X-Worker-Token", i.workerToken)
|
||||
if mode == "json" {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
} else {
|
||||
req.Header.Set("Content-Type", "application/octet-stream")
|
||||
}
|
||||
|
||||
resp, err := i.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("worker stream rpc %s/%s: %w", service, method, err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
defer resp.Body.Close()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("worker stream rpc %s/%s HTTP %d: %s", service, method, resp.StatusCode, truncate(string(body), 200))
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (i *Instance) workerEndpoint(path, service, method, mode string) (string, error) {
|
||||
base := url.URL{
|
||||
Scheme: "http",
|
||||
Host: i.Address,
|
||||
Path: path,
|
||||
}
|
||||
values := url.Values{}
|
||||
values.Set("service", service)
|
||||
values.Set("method", method)
|
||||
values.Set("mode", mode)
|
||||
if i.routingKey != "" {
|
||||
values.Set("routing_key", i.routingKey)
|
||||
}
|
||||
base.RawQuery = values.Encode()
|
||||
return base.String(), nil
|
||||
}
|
||||
|
||||
func marshalWorkerJSONBody(input any) ([]byte, error) {
|
||||
if input == nil {
|
||||
return []byte("{}"), nil
|
||||
}
|
||||
body, err := json.Marshal(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return body, nil
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,680 +0,0 @@
|
||||
package lspool
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
|
||||
"github.com/docker/docker/api/types/container"
|
||||
"github.com/docker/docker/api/types/filters"
|
||||
"github.com/docker/docker/api/types/network"
|
||||
"github.com/docker/docker/client"
|
||||
ocispec "github.com/opencontainers/image-spec/specs-go/v1"
|
||||
)
|
||||
|
||||
const (
|
||||
lsWorkerManagedByLabel = "managed-by"
|
||||
lsWorkerManagedByValue = "sub2api"
|
||||
lsWorkerAccountLabel = "account_id"
|
||||
lsWorkerProxyHashLabel = "proxy_hash"
|
||||
lsWorkerImageTagLabel = "image_tag"
|
||||
lsWorkerControlPort = 18081
|
||||
)
|
||||
|
||||
type workerManagerConfig struct {
|
||||
Image string
|
||||
Network string
|
||||
DockerSocket string
|
||||
IdleTTL time.Duration
|
||||
MaxActive int
|
||||
StartupTimeout time.Duration
|
||||
RequestTimeout time.Duration
|
||||
}
|
||||
|
||||
type dockerClient interface {
|
||||
ContainerList(ctx context.Context, options container.ListOptions) ([]container.Summary, error)
|
||||
ContainerCreate(ctx context.Context, config *container.Config, hostConfig *container.HostConfig, networkingConfig *network.NetworkingConfig, platform *ocispec.Platform, containerName string) (container.CreateResponse, error)
|
||||
ContainerStart(ctx context.Context, containerID string, options container.StartOptions) error
|
||||
ContainerInspect(ctx context.Context, containerID string) (container.InspectResponse, error)
|
||||
ContainerStop(ctx context.Context, containerID string, options container.StopOptions) error
|
||||
ContainerRemove(ctx context.Context, containerID string, options container.RemoveOptions) error
|
||||
Close() error
|
||||
}
|
||||
|
||||
type workerManager struct {
|
||||
cfg workerManagerConfig
|
||||
docker dockerClient
|
||||
http *http.Client
|
||||
|
||||
mu sync.Mutex
|
||||
workers map[string]*workerHandle
|
||||
state map[string]*workerAccountState
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
type workerHandle struct {
|
||||
Key string
|
||||
AccountID string
|
||||
ProxyURL string
|
||||
ProxyHash string
|
||||
ContainerID string
|
||||
Container string
|
||||
Address string
|
||||
AuthToken string
|
||||
LastUsed time.Time
|
||||
LastStateSHA string
|
||||
}
|
||||
|
||||
type workerAccountState struct {
|
||||
HasToken bool `json:"has_token"`
|
||||
AccessToken string `json:"access_token,omitempty"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
ExpiresAt *time.Time `json:"expires_at,omitempty"`
|
||||
HasModelCredits bool `json:"has_model_credits"`
|
||||
UseAICredits bool `json:"use_ai_credits"`
|
||||
AvailableCredits *int32 `json:"available_credits,omitempty"`
|
||||
MinimumCreditAmount *int32 `json:"minimum_credit_amount,omitempty"`
|
||||
}
|
||||
|
||||
func NewWorkerManagerFromConfig(cfg *config.Config) (Backend, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("config is nil")
|
||||
}
|
||||
|
||||
managerCfg := workerManagerConfig{
|
||||
Image: strings.TrimSpace(cfg.Gateway.AntigravityLSWorker.Image),
|
||||
Network: strings.TrimSpace(cfg.Gateway.AntigravityLSWorker.Network),
|
||||
DockerSocket: strings.TrimSpace(cfg.Gateway.AntigravityLSWorker.DockerSocket),
|
||||
IdleTTL: cfg.Gateway.AntigravityLSWorker.IdleTTL,
|
||||
MaxActive: cfg.Gateway.AntigravityLSWorker.MaxActive,
|
||||
StartupTimeout: cfg.Gateway.AntigravityLSWorker.StartupTimeout,
|
||||
RequestTimeout: cfg.Gateway.AntigravityLSWorker.RequestTimeout,
|
||||
}
|
||||
|
||||
if managerCfg.Image == "" {
|
||||
managerCfg.Image = "weishaw/sub2api-lsworker:latest"
|
||||
}
|
||||
if managerCfg.Network == "" {
|
||||
managerCfg.Network = "sub2api-network"
|
||||
}
|
||||
if managerCfg.DockerSocket == "" {
|
||||
managerCfg.DockerSocket = "unix:///var/run/docker.sock"
|
||||
}
|
||||
if managerCfg.IdleTTL <= 0 {
|
||||
managerCfg.IdleTTL = 15 * time.Minute
|
||||
}
|
||||
if managerCfg.MaxActive < 1 {
|
||||
managerCfg.MaxActive = 50
|
||||
}
|
||||
if managerCfg.StartupTimeout <= 0 {
|
||||
managerCfg.StartupTimeout = 45 * time.Second
|
||||
}
|
||||
if managerCfg.RequestTimeout <= 0 {
|
||||
managerCfg.RequestTimeout = 60 * time.Second
|
||||
}
|
||||
|
||||
opts := []client.Opt{client.WithAPIVersionNegotiation()}
|
||||
if managerCfg.DockerSocket != "" {
|
||||
opts = append(opts, client.WithHost(managerCfg.DockerSocket))
|
||||
} else {
|
||||
opts = append(opts, client.FromEnv)
|
||||
}
|
||||
|
||||
dockerClient, err := client.NewClientWithOpts(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create docker client: %w", err)
|
||||
}
|
||||
|
||||
return newWorkerManager(managerCfg, dockerClient)
|
||||
}
|
||||
|
||||
func newWorkerManager(cfg workerManagerConfig, docker dockerClient) (*workerManager, error) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
mgr := &workerManager{
|
||||
cfg: cfg,
|
||||
docker: docker,
|
||||
http: &http.Client{
|
||||
Timeout: cfg.RequestTimeout,
|
||||
Transport: &http.Transport{
|
||||
Proxy: nil,
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 5 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).DialContext,
|
||||
MaxIdleConnsPerHost: 8,
|
||||
},
|
||||
},
|
||||
workers: make(map[string]*workerHandle),
|
||||
state: make(map[string]*workerAccountState),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
logger: slog.Default().With("component", "lspool-worker-manager"),
|
||||
}
|
||||
if err := mgr.reconcileManagedContainers(ctx); err != nil {
|
||||
cancel()
|
||||
_ = docker.Close()
|
||||
return nil, err
|
||||
}
|
||||
go mgr.cleanupLoop()
|
||||
return mgr, nil
|
||||
}
|
||||
|
||||
func (m *workerManager) Close() {
|
||||
m.cancel()
|
||||
|
||||
m.mu.Lock()
|
||||
workers := make([]*workerHandle, 0, len(m.workers))
|
||||
for _, handle := range m.workers {
|
||||
workers = append(workers, handle)
|
||||
}
|
||||
m.workers = make(map[string]*workerHandle)
|
||||
m.mu.Unlock()
|
||||
|
||||
for _, handle := range workers {
|
||||
m.removeWorkerContainer(context.Background(), handle)
|
||||
}
|
||||
if m.docker != nil {
|
||||
_ = m.docker.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *workerManager) Stats() map[string]any {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return map[string]any{
|
||||
"accounts": len(m.state),
|
||||
"total": len(m.workers),
|
||||
"active": len(m.workers),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *workerManager) SetAccountToken(accountID, accessToken, refreshToken string, expiresAt time.Time) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
state := m.ensureStateLocked(accountID)
|
||||
state.HasToken = true
|
||||
state.AccessToken = accessToken
|
||||
state.RefreshToken = refreshToken
|
||||
if expiresAt.IsZero() {
|
||||
state.ExpiresAt = nil
|
||||
} else {
|
||||
ts := expiresAt.UTC()
|
||||
state.ExpiresAt = &ts
|
||||
}
|
||||
}
|
||||
|
||||
func (m *workerManager) SetAccountModelCredits(accountID string, useAICredits bool, availableCredits, minimumCreditAmountForUsage *int32) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
state := m.ensureStateLocked(accountID)
|
||||
state.HasModelCredits = true
|
||||
state.UseAICredits = useAICredits
|
||||
state.AvailableCredits = cloneInt32Ptr(availableCredits)
|
||||
state.MinimumCreditAmount = cloneInt32Ptr(minimumCreditAmountForUsage)
|
||||
}
|
||||
|
||||
func (m *workerManager) GetOrCreate(accountID, routingKey string, proxyURL ...string) (*Instance, error) {
|
||||
rawProxy := ""
|
||||
if len(proxyURL) > 0 {
|
||||
rawProxy = proxyURL[0]
|
||||
}
|
||||
normalizedProxy, parsedProxy, err := resolveWorkerProxy(rawProxy)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if parsedProxy == nil {
|
||||
return nil, fmt.Errorf("ls worker requires a socks5/socks5h proxy for account %s", accountID)
|
||||
}
|
||||
|
||||
replica := replicaSlotIndex(routingKey, parseLSReplicaCount())
|
||||
proxyHash := proxyHash(normalizedProxy)
|
||||
workerKey := buildWorkerKey(accountID, proxyHash)
|
||||
|
||||
m.mu.Lock()
|
||||
state := cloneWorkerAccountState(m.state[accountID])
|
||||
if state == nil || !state.HasToken || strings.TrimSpace(state.AccessToken) == "" {
|
||||
m.mu.Unlock()
|
||||
return nil, fmt.Errorf("ls worker missing access token for account %s", accountID)
|
||||
}
|
||||
|
||||
handle := m.workers[workerKey]
|
||||
if handle == nil {
|
||||
if len(m.workers) >= m.cfg.MaxActive {
|
||||
m.mu.Unlock()
|
||||
return nil, fmt.Errorf("ls worker limit reached (%d active)", m.cfg.MaxActive)
|
||||
}
|
||||
handle, err = m.createWorkerLocked(accountID, normalizedProxy, proxyHash, parsedProxy)
|
||||
if err != nil {
|
||||
m.mu.Unlock()
|
||||
return nil, err
|
||||
}
|
||||
m.workers[workerKey] = handle
|
||||
}
|
||||
handle.LastUsed = time.Now()
|
||||
m.mu.Unlock()
|
||||
|
||||
if err := m.waitForWorkerHealthy(handle); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := m.syncWorkerState(handle, state); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := m.waitForWorkerReady(handle, routingKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
inst := &Instance{
|
||||
AccountID: accountID,
|
||||
Replica: replica,
|
||||
Address: handle.Address,
|
||||
client: m.http,
|
||||
healthy: true,
|
||||
lastUsed: time.Now(),
|
||||
modelMapReady: 1,
|
||||
remote: true,
|
||||
workerToken: handle.AuthToken,
|
||||
routingKey: routingKey,
|
||||
}
|
||||
return inst, nil
|
||||
}
|
||||
|
||||
func (m *workerManager) cleanupLoop() {
|
||||
ticker := time.NewTicker(time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
m.collectIdleWorkers()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *workerManager) collectIdleWorkers() {
|
||||
now := time.Now()
|
||||
var expired []*workerHandle
|
||||
|
||||
m.mu.Lock()
|
||||
for key, handle := range m.workers {
|
||||
if handle == nil {
|
||||
delete(m.workers, key)
|
||||
continue
|
||||
}
|
||||
if now.Sub(handle.LastUsed) <= m.cfg.IdleTTL {
|
||||
continue
|
||||
}
|
||||
expired = append(expired, handle)
|
||||
delete(m.workers, key)
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
for _, handle := range expired {
|
||||
m.removeWorkerContainer(context.Background(), handle)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *workerManager) reconcileManagedContainers(ctx context.Context) error {
|
||||
args := filters.NewArgs()
|
||||
args.Add("label", fmt.Sprintf("%s=%s", lsWorkerManagedByLabel, lsWorkerManagedByValue))
|
||||
|
||||
containers, err := m.docker.ContainerList(ctx, container.ListOptions{
|
||||
All: true,
|
||||
Filters: args,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("list managed ls workers: %w", err)
|
||||
}
|
||||
|
||||
for _, summary := range containers {
|
||||
handle := &workerHandle{
|
||||
ContainerID: summary.ID,
|
||||
Container: strings.TrimPrefix(firstContainerName(summary.Names), "/"),
|
||||
}
|
||||
if err := m.removeWorkerContainer(ctx, handle); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *workerManager) createWorkerLocked(accountID, proxyURL, proxyHash string, parsedProxy *url.URL) (*workerHandle, error) {
|
||||
containerName := fmt.Sprintf("sub2api-ls-%s-%s", accountID, proxyHash[:8])
|
||||
authToken := generateUUID()
|
||||
|
||||
proxyHost := parsedProxy.Hostname()
|
||||
proxyPort := parsedProxy.Port()
|
||||
if proxyPort == "" {
|
||||
proxyPort = "1080"
|
||||
}
|
||||
proxyUser := parsedProxy.User.Username()
|
||||
proxyPass, _ := parsedProxy.User.Password()
|
||||
|
||||
labels := map[string]string{
|
||||
lsWorkerManagedByLabel: lsWorkerManagedByValue,
|
||||
lsWorkerAccountLabel: accountID,
|
||||
lsWorkerProxyHashLabel: proxyHash,
|
||||
lsWorkerImageTagLabel: m.cfg.Image,
|
||||
}
|
||||
|
||||
env := []string{
|
||||
"ANTIGRAVITY_APP_ROOT=/app/ls",
|
||||
fmt.Sprintf("LSWORKER_ACCOUNT_ID=%s", accountID),
|
||||
fmt.Sprintf("LSWORKER_AUTH_TOKEN=%s", authToken),
|
||||
fmt.Sprintf("LSWORKER_LISTEN_ADDR=0.0.0.0:%d", lsWorkerControlPort),
|
||||
fmt.Sprintf("LSWORKER_NETWORK_READY_FILE=%s", "/run/lsworker/network-ready"),
|
||||
fmt.Sprintf("LSWORKER_PROXY_URL=%s", proxyURL),
|
||||
fmt.Sprintf("LSWORKER_PROXY_HOST=%s", proxyHost),
|
||||
fmt.Sprintf("LSWORKER_PROXY_PORT=%s", proxyPort),
|
||||
fmt.Sprintf("LSWORKER_PROXY_USER=%s", proxyUser),
|
||||
fmt.Sprintf("LSWORKER_PROXY_PASS=%s", proxyPass),
|
||||
fmt.Sprintf("LSWORKER_CONTROL_PORT=%d", lsWorkerControlPort),
|
||||
fmt.Sprintf("ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT=%d", parseLSReplicaCount()),
|
||||
}
|
||||
if tz := strings.TrimSpace(os.Getenv("TZ")); tz != "" {
|
||||
env = append(env, "TZ="+tz)
|
||||
}
|
||||
|
||||
createResp, err := m.docker.ContainerCreate(m.ctx, &container.Config{
|
||||
Image: m.cfg.Image,
|
||||
Labels: labels,
|
||||
Env: env,
|
||||
}, &container.HostConfig{
|
||||
CapAdd: []string{"NET_ADMIN"},
|
||||
}, &network.NetworkingConfig{
|
||||
EndpointsConfig: map[string]*network.EndpointSettings{
|
||||
m.cfg.Network: {},
|
||||
},
|
||||
}, nil, containerName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create ls worker container: %w", err)
|
||||
}
|
||||
|
||||
if err := m.docker.ContainerStart(m.ctx, createResp.ID, container.StartOptions{}); err != nil {
|
||||
_ = m.docker.ContainerRemove(m.ctx, createResp.ID, container.RemoveOptions{Force: true})
|
||||
return nil, fmt.Errorf("start ls worker container: %w", err)
|
||||
}
|
||||
|
||||
inspect, err := m.docker.ContainerInspect(m.ctx, createResp.ID)
|
||||
if err != nil {
|
||||
_ = m.docker.ContainerRemove(m.ctx, createResp.ID, container.RemoveOptions{Force: true})
|
||||
return nil, fmt.Errorf("inspect ls worker container: %w", err)
|
||||
}
|
||||
|
||||
address, err := workerAddressFromInspect(inspect, m.cfg.Network)
|
||||
if err != nil {
|
||||
_ = m.docker.ContainerRemove(m.ctx, createResp.ID, container.RemoveOptions{Force: true})
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m.logger.Info("created ls worker",
|
||||
"account", shortAccountID(accountID),
|
||||
"container", containerName,
|
||||
"address", address,
|
||||
"proxy_hash", proxyHash[:8])
|
||||
|
||||
return &workerHandle{
|
||||
Key: buildWorkerKey(accountID, proxyHash),
|
||||
AccountID: accountID,
|
||||
ProxyURL: proxyURL,
|
||||
ProxyHash: proxyHash,
|
||||
ContainerID: createResp.ID,
|
||||
Container: containerName,
|
||||
Address: address,
|
||||
AuthToken: authToken,
|
||||
LastUsed: time.Now(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func workerAddressFromInspect(inspect container.InspectResponse, networkName string) (string, error) {
|
||||
if inspect.NetworkSettings == nil {
|
||||
return "", fmt.Errorf("ls worker inspect missing network settings")
|
||||
}
|
||||
if endpoint, ok := inspect.NetworkSettings.Networks[networkName]; ok && endpoint != nil && strings.TrimSpace(endpoint.IPAddress) != "" {
|
||||
return net.JoinHostPort(strings.TrimSpace(endpoint.IPAddress), strconv.Itoa(lsWorkerControlPort)), nil
|
||||
}
|
||||
for _, endpoint := range inspect.NetworkSettings.Networks {
|
||||
if endpoint != nil && strings.TrimSpace(endpoint.IPAddress) != "" {
|
||||
return net.JoinHostPort(strings.TrimSpace(endpoint.IPAddress), strconv.Itoa(lsWorkerControlPort)), nil
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("ls worker missing IP address on network %s", networkName)
|
||||
}
|
||||
|
||||
func firstContainerName(names []string) string {
|
||||
if len(names) == 0 {
|
||||
return ""
|
||||
}
|
||||
return names[0]
|
||||
}
|
||||
|
||||
func (m *workerManager) waitForWorkerHealthy(handle *workerHandle) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), m.cfg.StartupTimeout)
|
||||
defer cancel()
|
||||
|
||||
for {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, workerURL(handle, "/healthz", nil), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("X-Worker-Token", handle.AuthToken)
|
||||
resp, err := m.http.Do(req)
|
||||
if err == nil {
|
||||
_ = resp.Body.Close()
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("worker %s failed health check: %w", handle.Container, ctx.Err())
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *workerManager) waitForWorkerReady(handle *workerHandle, routingKey string) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), m.cfg.StartupTimeout)
|
||||
defer cancel()
|
||||
|
||||
values := url.Values{}
|
||||
if strings.TrimSpace(routingKey) != "" {
|
||||
values.Set("routing_key", routingKey)
|
||||
}
|
||||
|
||||
var (
|
||||
lastStatus int
|
||||
lastBody string
|
||||
)
|
||||
|
||||
for {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, workerURL(handle, "/readyz", values), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("X-Worker-Token", handle.AuthToken)
|
||||
resp, err := m.http.Do(req)
|
||||
if err == nil {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
_ = resp.Body.Close()
|
||||
lastStatus = resp.StatusCode
|
||||
lastBody = truncate(string(body), 200)
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
return nil
|
||||
}
|
||||
if isWorkerModelMappingUnavailable(resp.StatusCode, lastBody) {
|
||||
return fmt.Errorf("%w: worker %s %s", errLSModelMapDenied, handle.Container, strings.TrimSpace(lastBody))
|
||||
}
|
||||
if len(body) > 0 && shouldWarnWorkerNotReady(resp.StatusCode, lastBody) {
|
||||
m.logger.Warn("ls worker not ready yet", "container", handle.Container, "status", resp.StatusCode, "body", truncate(string(body), 200))
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if lastStatus > 0 || lastBody != "" {
|
||||
return fmt.Errorf("worker %s not ready for routing key %q (last_status=%d last_body=%q): %w", handle.Container, routingKey, lastStatus, lastBody, ctx.Err())
|
||||
}
|
||||
return fmt.Errorf("worker %s not ready for routing key %q: %w", handle.Container, routingKey, ctx.Err())
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func shouldWarnWorkerNotReady(status int, body string) bool {
|
||||
if status == http.StatusServiceUnavailable {
|
||||
normalized := strings.ToLower(strings.TrimSpace(body))
|
||||
if strings.Contains(normalized, "model mapping not ready") {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func isWorkerModelMappingUnavailable(status int, body string) bool {
|
||||
if status != http.StatusServiceUnavailable {
|
||||
return false
|
||||
}
|
||||
normalized := strings.ToLower(strings.TrimSpace(body))
|
||||
return strings.Contains(normalized, errLSModelMapDenied.Error())
|
||||
}
|
||||
|
||||
func (m *workerManager) syncWorkerState(handle *workerHandle, state *workerAccountState) error {
|
||||
if state == nil {
|
||||
return fmt.Errorf("ls worker state is nil")
|
||||
}
|
||||
body, err := json.Marshal(state)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal worker state: %w", err)
|
||||
}
|
||||
|
||||
sum := fmt.Sprintf("%x", sha256.Sum256(body))
|
||||
if handle.LastStateSHA == sum {
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), m.cfg.RequestTimeout)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, workerURL(handle, "/account/state", nil), bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("X-Worker-Token", handle.AuthToken)
|
||||
|
||||
resp, err := m.http.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sync worker state: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("sync worker state HTTP %d: %s", resp.StatusCode, truncate(string(respBody), 200))
|
||||
}
|
||||
handle.LastStateSHA = sum
|
||||
return nil
|
||||
}
|
||||
|
||||
func workerURL(handle *workerHandle, path string, values url.Values) string {
|
||||
base := url.URL{
|
||||
Scheme: "http",
|
||||
Host: handle.Address,
|
||||
Path: path,
|
||||
}
|
||||
if values != nil {
|
||||
base.RawQuery = values.Encode()
|
||||
}
|
||||
return base.String()
|
||||
}
|
||||
|
||||
func (m *workerManager) removeWorkerContainer(ctx context.Context, handle *workerHandle) error {
|
||||
if handle == nil || strings.TrimSpace(handle.ContainerID) == "" {
|
||||
return nil
|
||||
}
|
||||
timeout := 5
|
||||
_ = m.docker.ContainerStop(ctx, handle.ContainerID, container.StopOptions{Timeout: &timeout})
|
||||
if err := m.docker.ContainerRemove(ctx, handle.ContainerID, container.RemoveOptions{Force: true}); err != nil {
|
||||
return fmt.Errorf("remove ls worker container %s: %w", handle.ContainerID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *workerManager) ensureStateLocked(accountID string) *workerAccountState {
|
||||
state := m.state[accountID]
|
||||
if state == nil {
|
||||
state = &workerAccountState{}
|
||||
m.state[accountID] = state
|
||||
}
|
||||
return state
|
||||
}
|
||||
|
||||
func resolveWorkerProxy(proxyURL string) (string, *url.URL, error) {
|
||||
resolved := resolveLSProxy(proxyURL)
|
||||
normalized, parsed, err := proxyurl.Parse(resolved)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
if parsed == nil {
|
||||
return "", nil, nil
|
||||
}
|
||||
switch strings.ToLower(parsed.Scheme) {
|
||||
case "socks5", "socks5h":
|
||||
return normalized, parsed, nil
|
||||
default:
|
||||
return "", nil, fmt.Errorf("ls worker only supports socks5/socks5h proxies, got %s", parsed.Scheme)
|
||||
}
|
||||
}
|
||||
|
||||
func proxyHash(proxyURL string) string {
|
||||
if strings.TrimSpace(proxyURL) == "" {
|
||||
return "direct"
|
||||
}
|
||||
sum := sha256.Sum256([]byte(strings.TrimSpace(proxyURL)))
|
||||
return fmt.Sprintf("%x", sum[:])
|
||||
}
|
||||
|
||||
func buildWorkerKey(accountID, proxyHash string) string {
|
||||
return accountID + ":" + proxyHash
|
||||
}
|
||||
|
||||
func cloneInt32Ptr(v *int32) *int32 {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
cp := *v
|
||||
return &cp
|
||||
}
|
||||
|
||||
func cloneWorkerAccountState(state *workerAccountState) *workerAccountState {
|
||||
if state == nil {
|
||||
return nil
|
||||
}
|
||||
cp := *state
|
||||
cp.AvailableCredits = cloneInt32Ptr(state.AvailableCredits)
|
||||
cp.MinimumCreditAmount = cloneInt32Ptr(state.MinimumCreditAmount)
|
||||
if state.ExpiresAt != nil {
|
||||
ts := *state.ExpiresAt
|
||||
cp.ExpiresAt = &ts
|
||||
}
|
||||
return &cp
|
||||
}
|
||||
@ -1,335 +0,0 @@
|
||||
package lspool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/docker/docker/api/types/container"
|
||||
"github.com/docker/docker/api/types/filters"
|
||||
"github.com/docker/docker/api/types/network"
|
||||
ocispec "github.com/opencontainers/image-spec/specs-go/v1"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type fakeDockerClient struct {
|
||||
mu sync.Mutex
|
||||
|
||||
listResp []container.Summary
|
||||
listCalls int
|
||||
createCalls int
|
||||
startCalls int
|
||||
stopCalls int
|
||||
removeCalls int
|
||||
inspectCalls int
|
||||
removedIDs []string
|
||||
createdConfigs []*container.Config
|
||||
inspectResp container.InspectResponse
|
||||
}
|
||||
|
||||
func (f *fakeDockerClient) ContainerList(ctx context.Context, options container.ListOptions) ([]container.Summary, error) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.listCalls++
|
||||
return append([]container.Summary(nil), f.listResp...), nil
|
||||
}
|
||||
|
||||
func (f *fakeDockerClient) ContainerCreate(ctx context.Context, cfg *container.Config, hostConfig *container.HostConfig, networkingConfig *network.NetworkingConfig, platform *ocispec.Platform, containerName string) (container.CreateResponse, error) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.createCalls++
|
||||
f.createdConfigs = append(f.createdConfigs, cfg)
|
||||
return container.CreateResponse{ID: "worker-created"}, nil
|
||||
}
|
||||
|
||||
func (f *fakeDockerClient) ContainerStart(ctx context.Context, containerID string, options container.StartOptions) error {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.startCalls++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeDockerClient) ContainerInspect(ctx context.Context, containerID string) (container.InspectResponse, error) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.inspectCalls++
|
||||
return f.inspectResp, nil
|
||||
}
|
||||
|
||||
func (f *fakeDockerClient) ContainerStop(ctx context.Context, containerID string, options container.StopOptions) error {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.stopCalls++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeDockerClient) ContainerRemove(ctx context.Context, containerID string, options container.RemoveOptions) error {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.removeCalls++
|
||||
f.removedIDs = append(f.removedIDs, containerID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeDockerClient) Close() error { return nil }
|
||||
|
||||
func TestResolveWorkerProxyRejectsHTTP(t *testing.T) {
|
||||
_, _, err := resolveWorkerProxy("http://127.0.0.1:7890")
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "only supports socks5/socks5h")
|
||||
}
|
||||
|
||||
func TestProxyHashUsesNormalizedProxy(t *testing.T) {
|
||||
normalized, _, err := resolveWorkerProxy("socks5://user:pass@127.0.0.1:1080")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "socks5h://user:pass@127.0.0.1:1080", normalized)
|
||||
|
||||
hash1 := proxyHash(normalized)
|
||||
hash2 := proxyHash("socks5h://user:pass@127.0.0.1:1080")
|
||||
require.Equal(t, hash1, hash2)
|
||||
}
|
||||
|
||||
func TestWorkerManagerRequiresToken(t *testing.T) {
|
||||
fakeDocker := &fakeDockerClient{}
|
||||
manager, err := newWorkerManager(workerManagerConfig{
|
||||
Image: "worker:latest",
|
||||
Network: "sub2api-network",
|
||||
DockerSocket: "unix:///var/run/docker.sock",
|
||||
IdleTTL: time.Minute,
|
||||
MaxActive: 2,
|
||||
StartupTimeout: time.Second,
|
||||
RequestTimeout: time.Second,
|
||||
}, fakeDocker)
|
||||
require.NoError(t, err)
|
||||
defer manager.Close()
|
||||
|
||||
_, err = manager.GetOrCreate("9", "rk-1", "socks5h://user:pass@127.0.0.1:1080")
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "missing access token")
|
||||
}
|
||||
|
||||
func TestWorkerManagerReusesExistingHandleAndDedupesStateSync(t *testing.T) {
|
||||
var mu sync.Mutex
|
||||
var healthCalls int
|
||||
var readyCalls int
|
||||
var stateCalls int
|
||||
var stateBodies [][]byte
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/healthz":
|
||||
mu.Lock()
|
||||
healthCalls++
|
||||
mu.Unlock()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
case "/readyz":
|
||||
mu.Lock()
|
||||
readyCalls++
|
||||
mu.Unlock()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ready"))
|
||||
case "/account/state":
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
mu.Lock()
|
||||
stateCalls++
|
||||
stateBodies = append(stateBodies, body)
|
||||
mu.Unlock()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
fakeDocker := &fakeDockerClient{}
|
||||
manager, err := newWorkerManager(workerManagerConfig{
|
||||
Image: "worker:latest",
|
||||
Network: "sub2api-network",
|
||||
DockerSocket: "unix:///var/run/docker.sock",
|
||||
IdleTTL: time.Minute,
|
||||
MaxActive: 4,
|
||||
StartupTimeout: time.Second,
|
||||
RequestTimeout: time.Second,
|
||||
}, fakeDocker)
|
||||
require.NoError(t, err)
|
||||
defer manager.Close()
|
||||
|
||||
accountID := "9"
|
||||
proxyURL := "socks5h://user:pass@127.0.0.1:1080"
|
||||
hash := proxyHash(proxyURL)
|
||||
key := buildWorkerKey(accountID, hash)
|
||||
|
||||
manager.SetAccountToken(accountID, "ya29.test", "refresh", time.Now().Add(time.Hour))
|
||||
manager.mu.Lock()
|
||||
manager.workers[key] = &workerHandle{
|
||||
Key: key,
|
||||
AccountID: accountID,
|
||||
ProxyURL: proxyURL,
|
||||
ProxyHash: hash,
|
||||
ContainerID: "existing-worker",
|
||||
Container: "sub2api-ls-9-test",
|
||||
Address: strings.TrimPrefix(server.URL, "http://"),
|
||||
AuthToken: "worker-token",
|
||||
LastUsed: time.Now(),
|
||||
}
|
||||
manager.mu.Unlock()
|
||||
|
||||
inst1, err := manager.GetOrCreate(accountID, "rk-1", proxyURL)
|
||||
require.NoError(t, err)
|
||||
require.True(t, inst1.remote)
|
||||
require.Equal(t, replicaSlotIndex("rk-1", parseLSReplicaCount()), inst1.Replica)
|
||||
|
||||
inst2, err := manager.GetOrCreate(accountID, "rk-1", proxyURL)
|
||||
require.NoError(t, err)
|
||||
require.True(t, inst2.remote)
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
require.GreaterOrEqual(t, healthCalls, 2)
|
||||
require.GreaterOrEqual(t, readyCalls, 2)
|
||||
require.Equal(t, 1, stateCalls, "state sync should be skipped when the payload hash is unchanged")
|
||||
require.Len(t, stateBodies, 1)
|
||||
|
||||
var synced workerAccountState
|
||||
require.NoError(t, json.Unmarshal(stateBodies[0], &synced))
|
||||
require.True(t, synced.HasToken)
|
||||
require.Equal(t, "ya29.test", synced.AccessToken)
|
||||
}
|
||||
|
||||
func TestWorkerManagerMaxActiveStopsNewWorkerCreation(t *testing.T) {
|
||||
fakeDocker := &fakeDockerClient{}
|
||||
manager, err := newWorkerManager(workerManagerConfig{
|
||||
Image: "worker:latest",
|
||||
Network: "sub2api-network",
|
||||
DockerSocket: "unix:///var/run/docker.sock",
|
||||
IdleTTL: time.Minute,
|
||||
MaxActive: 1,
|
||||
StartupTimeout: time.Second,
|
||||
RequestTimeout: time.Second,
|
||||
}, fakeDocker)
|
||||
require.NoError(t, err)
|
||||
defer manager.Close()
|
||||
|
||||
manager.SetAccountToken("9", "ya29.test", "refresh", time.Now().Add(time.Hour))
|
||||
manager.mu.Lock()
|
||||
manager.workers["existing"] = &workerHandle{ContainerID: "existing", Container: "existing", LastUsed: time.Now()}
|
||||
manager.mu.Unlock()
|
||||
|
||||
_, err = manager.GetOrCreate("9", "rk-new", "socks5h://user:pass@127.0.0.1:1080")
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "limit reached")
|
||||
require.Equal(t, 0, fakeDocker.createCalls)
|
||||
}
|
||||
|
||||
func TestWorkerManagerReconcileRemovesManagedContainers(t *testing.T) {
|
||||
fakeDocker := &fakeDockerClient{
|
||||
listResp: []container.Summary{
|
||||
{
|
||||
ID: "old-worker-1",
|
||||
Names: []string{"/sub2api-ls-9-deadbeef"},
|
||||
},
|
||||
{
|
||||
ID: "old-worker-2",
|
||||
Names: []string{"/sub2api-ls-10-beadfeed"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := newWorkerManager(workerManagerConfig{
|
||||
Image: "worker:latest",
|
||||
Network: "sub2api-network",
|
||||
DockerSocket: "unix:///var/run/docker.sock",
|
||||
IdleTTL: time.Minute,
|
||||
MaxActive: 4,
|
||||
StartupTimeout: time.Second,
|
||||
RequestTimeout: time.Second,
|
||||
}, fakeDocker)
|
||||
require.NoError(t, err)
|
||||
defer manager.Close()
|
||||
|
||||
require.Equal(t, 1, fakeDocker.listCalls)
|
||||
require.ElementsMatch(t, []string{"old-worker-1", "old-worker-2"}, fakeDocker.removedIDs)
|
||||
}
|
||||
|
||||
func TestFakeDockerClientImplementsFilterAwareList(t *testing.T) {
|
||||
fakeDocker := &fakeDockerClient{}
|
||||
_, err := fakeDocker.ContainerList(context.Background(), container.ListOptions{Filters: filters.NewArgs()})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestShouldWarnWorkerNotReadySuppressesModelMappingPending(t *testing.T) {
|
||||
require.False(t, shouldWarnWorkerNotReady(http.StatusServiceUnavailable, "worker model mapping not ready for replica 0"))
|
||||
require.True(t, shouldWarnWorkerNotReady(http.StatusServiceUnavailable, "worker access token not configured"))
|
||||
require.True(t, shouldWarnWorkerNotReady(http.StatusBadGateway, "upstream failed"))
|
||||
}
|
||||
|
||||
func TestWorkerManagerWaitForWorkerReadyStopsOnModelMappingUnavailable(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "/readyz", r.URL.Path)
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
_, _ = w.Write([]byte(`model mapping unavailable for replica 0: oauth2: "unauthorized_client" "Unauthorized"`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
manager, err := newWorkerManager(workerManagerConfig{
|
||||
Image: "worker:latest",
|
||||
Network: "sub2api-network",
|
||||
DockerSocket: "unix:///var/run/docker.sock",
|
||||
IdleTTL: time.Minute,
|
||||
MaxActive: 1,
|
||||
StartupTimeout: time.Second,
|
||||
RequestTimeout: time.Second,
|
||||
}, &fakeDockerClient{})
|
||||
require.NoError(t, err)
|
||||
defer manager.Close()
|
||||
|
||||
handle := &workerHandle{
|
||||
Container: "sub2api-ls-test",
|
||||
Address: strings.TrimPrefix(server.URL, "http://"),
|
||||
AuthToken: "worker-token",
|
||||
}
|
||||
|
||||
err = manager.waitForWorkerReady(handle, "")
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, errLSModelMapDenied)
|
||||
}
|
||||
|
||||
func TestWorkerManagerWaitForWorkerReadyIncludesLastBodyOnTimeout(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "/readyz", r.URL.Path)
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
_, _ = w.Write([]byte("worker model mapping not ready for replica 0\n"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
manager, err := newWorkerManager(workerManagerConfig{
|
||||
Image: "worker:latest",
|
||||
Network: "sub2api-network",
|
||||
DockerSocket: "unix:///var/run/docker.sock",
|
||||
IdleTTL: time.Minute,
|
||||
MaxActive: 1,
|
||||
StartupTimeout: 100 * time.Millisecond,
|
||||
RequestTimeout: time.Second,
|
||||
}, &fakeDockerClient{})
|
||||
require.NoError(t, err)
|
||||
defer manager.Close()
|
||||
|
||||
handle := &workerHandle{
|
||||
Container: "sub2api-ls-test",
|
||||
Address: strings.TrimPrefix(server.URL, "http://"),
|
||||
AuthToken: "worker-token",
|
||||
}
|
||||
|
||||
err = manager.waitForWorkerReady(handle, "")
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), `last_status=503`)
|
||||
require.Contains(t, err.Error(), `last_body="worker model mapping not ready for replica 0`)
|
||||
}
|
||||
@ -1,374 +0,0 @@
|
||||
package lspool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type WorkerServerConfig struct {
|
||||
AccountID string
|
||||
AuthToken string
|
||||
ListenAddr string
|
||||
AppRoot string
|
||||
NetworkReadyFile string
|
||||
MaxIdleTime time.Duration
|
||||
HealthInterval time.Duration
|
||||
}
|
||||
|
||||
type WorkerServer struct {
|
||||
cfg WorkerServerConfig
|
||||
pool *Pool
|
||||
logger *slog.Logger
|
||||
|
||||
mu sync.RWMutex
|
||||
state workerAccountState
|
||||
}
|
||||
|
||||
func NewWorkerServer(cfg WorkerServerConfig) (*WorkerServer, error) {
|
||||
if strings.TrimSpace(cfg.AccountID) == "" {
|
||||
return nil, fmt.Errorf("worker account id is required")
|
||||
}
|
||||
if strings.TrimSpace(cfg.AuthToken) == "" {
|
||||
return nil, fmt.Errorf("worker auth token is required")
|
||||
}
|
||||
if strings.TrimSpace(cfg.ListenAddr) == "" {
|
||||
cfg.ListenAddr = fmt.Sprintf("0.0.0.0:%d", lsWorkerControlPort)
|
||||
}
|
||||
if strings.TrimSpace(cfg.AppRoot) == "" {
|
||||
cfg.AppRoot = "/app/ls"
|
||||
}
|
||||
if cfg.MaxIdleTime <= 0 {
|
||||
cfg.MaxIdleTime = 15 * time.Minute
|
||||
}
|
||||
if cfg.HealthInterval <= 0 {
|
||||
cfg.HealthInterval = 30 * time.Second
|
||||
}
|
||||
|
||||
poolCfg := DefaultConfig()
|
||||
poolCfg.AppRoot = cfg.AppRoot
|
||||
poolCfg.MaxIdleTime = cfg.MaxIdleTime
|
||||
poolCfg.HealthCheckInterval = cfg.HealthInterval
|
||||
|
||||
return &WorkerServer{
|
||||
cfg: cfg,
|
||||
pool: NewPool(poolCfg),
|
||||
logger: slog.Default().With("component", "lsworker"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func NewWorkerServerFromEnv() (*WorkerServer, error) {
|
||||
maxIdleTime := 15 * time.Minute
|
||||
if raw := strings.TrimSpace(os.Getenv("LSWORKER_POOL_MAX_IDLE_TIME")); raw != "" {
|
||||
if parsed, err := time.ParseDuration(raw); err == nil {
|
||||
maxIdleTime = parsed
|
||||
}
|
||||
}
|
||||
healthInterval := 30 * time.Second
|
||||
if raw := strings.TrimSpace(os.Getenv("LSWORKER_POOL_HEALTH_INTERVAL")); raw != "" {
|
||||
if parsed, err := time.ParseDuration(raw); err == nil {
|
||||
healthInterval = parsed
|
||||
}
|
||||
}
|
||||
|
||||
return NewWorkerServer(WorkerServerConfig{
|
||||
AccountID: strings.TrimSpace(os.Getenv("LSWORKER_ACCOUNT_ID")),
|
||||
AuthToken: strings.TrimSpace(os.Getenv("LSWORKER_AUTH_TOKEN")),
|
||||
ListenAddr: strings.TrimSpace(os.Getenv("LSWORKER_LISTEN_ADDR")),
|
||||
AppRoot: strings.TrimSpace(os.Getenv("ANTIGRAVITY_APP_ROOT")),
|
||||
NetworkReadyFile: strings.TrimSpace(os.Getenv("LSWORKER_NETWORK_READY_FILE")),
|
||||
MaxIdleTime: maxIdleTime,
|
||||
HealthInterval: healthInterval,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *WorkerServer) Close() {
|
||||
if s.pool != nil {
|
||||
s.pool.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *WorkerServer) Handler() http.Handler {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/healthz", s.handleHealthz)
|
||||
mux.HandleFunc("/readyz", s.handleReadyz)
|
||||
mux.HandleFunc("/account/state", s.handleAccountState)
|
||||
mux.HandleFunc("/rpc/unary", s.handleRPCUnary)
|
||||
mux.HandleFunc("/rpc/stream", s.handleRPCStream)
|
||||
return mux
|
||||
}
|
||||
|
||||
func (s *WorkerServer) handleHealthz(w http.ResponseWriter, r *http.Request) {
|
||||
if !s.authorize(w, r) {
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
}
|
||||
|
||||
func (s *WorkerServer) handleReadyz(w http.ResponseWriter, r *http.Request) {
|
||||
if !s.authorize(w, r) {
|
||||
return
|
||||
}
|
||||
routingKey := strings.TrimSpace(r.URL.Query().Get("routing_key"))
|
||||
inst, err := s.ensureReady(r.Context(), routingKey)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(fmt.Sprintf("ready replica=%d", inst.Replica)))
|
||||
}
|
||||
|
||||
func (s *WorkerServer) handleAccountState(w http.ResponseWriter, r *http.Request) {
|
||||
if !s.authorize(w, r) {
|
||||
return
|
||||
}
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
defer r.Body.Close()
|
||||
var payload workerAccountState
|
||||
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
|
||||
http.Error(w, "invalid account state payload", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.state = *cloneWorkerAccountState(&payload)
|
||||
s.mu.Unlock()
|
||||
s.applyState()
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
}
|
||||
|
||||
func (s *WorkerServer) handleRPCUnary(w http.ResponseWriter, r *http.Request) {
|
||||
if !s.authorize(w, r) {
|
||||
return
|
||||
}
|
||||
service, method, mode, routingKey, ok := parseRPCRequest(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
inst, err := s.ensureReady(r.Context(), routingKey)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, "read request body failed", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if len(body) == 0 {
|
||||
body = []byte("{}")
|
||||
}
|
||||
|
||||
var respBody []byte
|
||||
switch mode {
|
||||
case "json":
|
||||
var input any
|
||||
if err := json.Unmarshal(body, &input); err != nil {
|
||||
http.Error(w, "invalid json rpc body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
respBody, err = inst.CallUnaryJSON(r.Context(), service, method, input)
|
||||
case "proto":
|
||||
respBody, err = inst.CallRPC(r.Context(), service, method, body)
|
||||
default:
|
||||
http.Error(w, "unsupported rpc mode", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(respBody)
|
||||
}
|
||||
|
||||
func (s *WorkerServer) handleRPCStream(w http.ResponseWriter, r *http.Request) {
|
||||
if !s.authorize(w, r) {
|
||||
return
|
||||
}
|
||||
service, method, mode, routingKey, ok := parseRPCRequest(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
inst, err := s.ensureReady(r.Context(), routingKey)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, "read request body failed", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var resp *http.Response
|
||||
switch mode {
|
||||
case "json":
|
||||
var input any
|
||||
if len(body) == 0 {
|
||||
body = []byte("{}")
|
||||
}
|
||||
if err := json.Unmarshal(body, &input); err != nil {
|
||||
http.Error(w, "invalid json rpc body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
resp, err = inst.StreamRPCJSON(r.Context(), service, method, input)
|
||||
case "proto":
|
||||
resp, err = inst.StreamRPC(r.Context(), service, method, body)
|
||||
default:
|
||||
http.Error(w, "unsupported rpc mode", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
for key, values := range resp.Header {
|
||||
for _, value := range values {
|
||||
w.Header().Add(key, value)
|
||||
}
|
||||
}
|
||||
w.WriteHeader(resp.StatusCode)
|
||||
_, _ = io.Copy(w, resp.Body)
|
||||
}
|
||||
|
||||
func (s *WorkerServer) authorize(w http.ResponseWriter, r *http.Request) bool {
|
||||
if subtleHeaderEqual(r.Header.Get("X-Worker-Token"), s.cfg.AuthToken) {
|
||||
return true
|
||||
}
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
return false
|
||||
}
|
||||
|
||||
func subtleHeaderEqual(left, right string) bool {
|
||||
if left == "" || right == "" {
|
||||
return false
|
||||
}
|
||||
return left == right
|
||||
}
|
||||
|
||||
func parseRPCRequest(w http.ResponseWriter, r *http.Request) (service, method, mode, routingKey string, ok bool) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return "", "", "", "", false
|
||||
}
|
||||
query := r.URL.Query()
|
||||
service = strings.TrimSpace(query.Get("service"))
|
||||
method = strings.TrimSpace(query.Get("method"))
|
||||
mode = strings.ToLower(strings.TrimSpace(query.Get("mode")))
|
||||
routingKey = strings.TrimSpace(query.Get("routing_key"))
|
||||
if service == "" || method == "" {
|
||||
http.Error(w, "missing rpc target", http.StatusBadRequest)
|
||||
return "", "", "", "", false
|
||||
}
|
||||
if mode == "" {
|
||||
mode = "proto"
|
||||
}
|
||||
return service, method, mode, routingKey, true
|
||||
}
|
||||
|
||||
func (s *WorkerServer) ensureReady(ctx context.Context, routingKey string) (*Instance, error) {
|
||||
if path := strings.TrimSpace(s.cfg.NetworkReadyFile); path != "" {
|
||||
if _, err := os.Stat(path); err != nil {
|
||||
return nil, fmt.Errorf("worker network not ready: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
s.applyState()
|
||||
s.mu.RLock()
|
||||
state := cloneWorkerAccountState(&s.state)
|
||||
s.mu.RUnlock()
|
||||
if state == nil || !state.HasToken || strings.TrimSpace(state.AccessToken) == "" {
|
||||
return nil, fmt.Errorf("worker access token not configured")
|
||||
}
|
||||
|
||||
inst, err := s.pool.GetOrCreate(s.cfg.AccountID, routingKey, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if inst.HasModelMappingUnavailable() {
|
||||
return nil, fmt.Errorf("%w for replica %d: %s", errLSModelMapDenied, inst.Replica, inst.ModelMappingUnavailableReason())
|
||||
}
|
||||
if inst.HasModelMappingReady() {
|
||||
return inst, nil
|
||||
}
|
||||
|
||||
modelCtx, cancel := context.WithTimeout(ctx, lsModelConfigTimeout)
|
||||
defer cancel()
|
||||
_ = modelCtx
|
||||
if !RefreshModelMapping(inst) {
|
||||
if inst.HasModelMappingUnavailable() {
|
||||
return nil, fmt.Errorf("%w for replica %d: %s", errLSModelMapDenied, inst.Replica, inst.ModelMappingUnavailableReason())
|
||||
}
|
||||
return nil, fmt.Errorf("worker model mapping not ready for replica %d", inst.Replica)
|
||||
}
|
||||
return inst, nil
|
||||
}
|
||||
|
||||
func (s *WorkerServer) applyState() {
|
||||
s.mu.RLock()
|
||||
state := cloneWorkerAccountState(&s.state)
|
||||
s.mu.RUnlock()
|
||||
if state == nil {
|
||||
return
|
||||
}
|
||||
if state.HasToken {
|
||||
expiresAt := time.Time{}
|
||||
if state.ExpiresAt != nil {
|
||||
expiresAt = state.ExpiresAt.UTC()
|
||||
}
|
||||
s.pool.SetAccountToken(s.cfg.AccountID, state.AccessToken, state.RefreshToken, expiresAt)
|
||||
}
|
||||
if state.HasModelCredits {
|
||||
s.pool.SetAccountModelCredits(s.cfg.AccountID, state.UseAICredits, state.AvailableCredits, state.MinimumCreditAmount)
|
||||
}
|
||||
}
|
||||
|
||||
func workerHTTPServer(listenAddr string, handler http.Handler) *http.Server {
|
||||
return &http.Server{
|
||||
Addr: listenAddr,
|
||||
Handler: handler,
|
||||
ReadHeaderTimeout: 10 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
func workerExitCode(err error) int {
|
||||
if err == nil {
|
||||
return 0
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
func parseWorkerControlPort() int {
|
||||
raw := strings.TrimSpace(os.Getenv("LSWORKER_CONTROL_PORT"))
|
||||
if raw == "" {
|
||||
return lsWorkerControlPort
|
||||
}
|
||||
port, err := strconv.Atoi(raw)
|
||||
if err != nil || port < 1 {
|
||||
return lsWorkerControlPort
|
||||
}
|
||||
return port
|
||||
}
|
||||
@ -1,11 +1,20 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const (
|
||||
anthropicEventLoggingURL = "https://api.anthropic.com/api/event_logging/batch"
|
||||
eventLoggingForwardTimeout = 8 * time.Second
|
||||
)
|
||||
|
||||
// RegisterCommonRoutes 注册通用路由(健康检查、状态等)
|
||||
func RegisterCommonRoutes(r *gin.Engine) {
|
||||
// 健康检查
|
||||
@ -13,8 +22,36 @@ func RegisterCommonRoutes(r *gin.Engine) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
})
|
||||
|
||||
// Claude Code 遥测日志(忽略,直接返回200)
|
||||
// Claude Code 遥测日志:清理敏感字段后转发给 Anthropic。
|
||||
// 删除 baseUrl/gateway 字段防止网关地址暴露(见 FINGERPRINT_SECURITY_REPORT.md §GAP-1/2)。
|
||||
// 转发而非丢弃,避免"高流量零遥测"异常被检测。
|
||||
r.POST("/api/event_logging/batch", func(c *gin.Context) {
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil || len(body) == 0 {
|
||||
c.Status(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
sanitized := sanitizeEventBatch(body)
|
||||
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), eventLoggingForwardTimeout)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, anthropicEventLoggingURL, bytes.NewReader(sanitized))
|
||||
if err != nil {
|
||||
c.Status(http.StatusOK)
|
||||
return
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
// 透传客户端的 Authorization header(OAuth Bearer token)
|
||||
if auth := c.GetHeader("Authorization"); auth != "" {
|
||||
req.Header.Set("Authorization", auth)
|
||||
}
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err == nil {
|
||||
resp.Body.Close()
|
||||
}
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
|
||||
@ -4,6 +4,9 @@ import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/tidwall/gjson"
|
||||
@ -12,7 +15,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
)
|
||||
|
||||
// Attribution block constants matching real Claude Code 2.1.88.
|
||||
// Attribution block constants matching real Claude Code 2.1.89.
|
||||
// Source: src/constants/system.ts + src/utils/fingerprint.ts
|
||||
const (
|
||||
// fingerprintSalt must match the hardcoded salt in the real CLI.
|
||||
@ -81,11 +84,10 @@ func extractFirstUserMessageText(body []byte) string {
|
||||
// Source: extracted/src/constants/system.ts:73-95
|
||||
func buildAttributionBlock(cliVersion, fingerprint string) string {
|
||||
version := cliVersion + "." + fingerprint
|
||||
// 注意:cch 字段由 Bun 的 NATIVE_CLIENT_ATTESTATION 编译时 feature 控制。
|
||||
// npm 安装版本(非原生二进制)此 feature 为 false,所以不包含 cch 字段。
|
||||
// 只有原生二进制安装(Bun 打包)才会有 cch,且其值会被 Bun 的 Zig 层替换为真实 hash。
|
||||
// 我们模拟 npm 安装版本的行为:不包含 cch。
|
||||
return fmt.Sprintf("x-anthropic-billing-header: cc_version=%s; cc_entrypoint=cli;", version)
|
||||
// 2.1.89 起 cch=00000 出现在所有安装模式(含 npm 版),不再只限于原生二进制。
|
||||
// 原生二进制由 Bun 的 Zig 层在运行时将 00000 替换为真实 attestation hash;
|
||||
// 普通安装版保持 00000 占位符不变。
|
||||
return fmt.Sprintf("x-anthropic-billing-header: cc_version=%s; cc_entrypoint=cli; cch=00000;", version)
|
||||
}
|
||||
|
||||
// injectAttributionBlock prepends the x-anthropic-billing-header attribution block
|
||||
@ -163,20 +165,89 @@ func injectAttributionBlock(body []byte, cliVersion string) []byte {
|
||||
}
|
||||
}
|
||||
|
||||
// generateSessionIDForAccount generates a deterministic per-account session UUID
|
||||
// that remains stable within a process-like timeframe.
|
||||
// Uses instanceSalt + accountID to ensure uniqueness across sub2api instances.
|
||||
func generateSessionIDForAccount(instanceSalt string, accountID int64) string {
|
||||
// Use a per-account stable UUID (like real CLI's per-process UUID).
|
||||
// We use accountID as the base — each account gets a different "session".
|
||||
seed := fmt.Sprintf("session:%s:%d", instanceSalt, accountID)
|
||||
hash := sha256.Sum256([]byte(seed))
|
||||
sessionUUID, err := uuid.FromBytes(hash[:16])
|
||||
if err != nil {
|
||||
return uuid.New().String()
|
||||
}
|
||||
// Set UUID v4 variant
|
||||
sessionUUID[6] = (sessionUUID[6] & 0x0f) | 0x40
|
||||
sessionUUID[8] = (sessionUUID[8] & 0x3f) | 0x80
|
||||
return sessionUUID.String()
|
||||
// cliSessionEntry holds a cached session UUID with an expiration time.
|
||||
type cliSessionEntry struct {
|
||||
id string
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
// cliSessionCache stores per-account session UUIDs that rotate on a TTL.
|
||||
// Real CLI creates a new random UUID per process invocation; we approximate
|
||||
// this by rotating every 30-60 minutes (jittered per account).
|
||||
var (
|
||||
cliSessionCache = make(map[int64]cliSessionEntry)
|
||||
cliSessionCacheMu sync.Mutex
|
||||
)
|
||||
|
||||
// sessionTTLBase is the base TTL for session ID rotation.
|
||||
const sessionTTLBase = 30 * time.Minute
|
||||
|
||||
// generateSessionIDForAccount returns a per-account session UUID that rotates
|
||||
// periodically. Each account gets a random TTL jitter (0-30 min on top of
|
||||
// the 30 min base) so accounts don't all rotate simultaneously.
|
||||
func generateSessionIDForAccount(instanceSalt string, accountID int64) string {
|
||||
cliSessionCacheMu.Lock()
|
||||
defer cliSessionCacheMu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
if entry, ok := cliSessionCache[accountID]; ok && now.Before(entry.expiresAt) {
|
||||
return entry.id
|
||||
}
|
||||
|
||||
// Compute per-account jitter from a hash so the same account always gets
|
||||
// the same jitter within a process (avoids re-rolling on every rotation).
|
||||
jitterSeed := fmt.Sprintf("jitter:%s:%d", instanceSalt, accountID)
|
||||
h := sha256.Sum256([]byte(jitterSeed))
|
||||
jitterMinutes := int(h[0]) % 31 // 0-30 minutes
|
||||
ttl := sessionTTLBase + time.Duration(jitterMinutes)*time.Minute
|
||||
|
||||
newID := uuid.New().String()
|
||||
cliSessionCache[accountID] = cliSessionEntry{
|
||||
id: newID,
|
||||
expiresAt: now.Add(ttl),
|
||||
}
|
||||
return newID
|
||||
}
|
||||
|
||||
// reUserHome matches /Users/<username>/ or /home/<username>/ path segments.
|
||||
// Captures the prefix (/Users/ or /home/) so we can preserve it while replacing the username.
|
||||
var reUserHome = regexp.MustCompile(`(/(Users|home)/)[^/\s"']+/`)
|
||||
|
||||
// reEnvLine matches lines of the form "Key: value" for the environment block
|
||||
// fields injected by Claude Code's CLAUDE.md / sysprompt machinery.
|
||||
var reEnvLine = regexp.MustCompile(`(?m)^(Platform|Shell|OS Version|Working directory):.*$`)
|
||||
|
||||
// canonicalEnvValues maps environment block keys to their canonical replacements.
|
||||
// Values mirror cc-gateway's prompt_env config and represent a stock macOS dev machine.
|
||||
var canonicalEnvValues = map[string]string{
|
||||
"Platform": "Platform: darwin",
|
||||
"Shell": "Shell: zsh",
|
||||
"OS Version": "OS Version: Darwin 24.4.0",
|
||||
"Working directory": "Working directory: /Users/user/project",
|
||||
}
|
||||
|
||||
// NormalizeSystemPromptEnv rewrites environment-specific fields in a system
|
||||
// prompt text block to canonical values, preventing real machine fingerprinting.
|
||||
//
|
||||
// Handles two classes of leakage (matching cc-gateway rewriter.ts:rewritePromptText):
|
||||
// 1. "Platform: Windows / Linux / Darwin 25.x" → canonical darwin/zsh/Darwin 24.4.0
|
||||
// 2. "/Users/alice/" or "/home/bob/" → "/Users/user/"
|
||||
//
|
||||
// Only called on system prompt text blocks, never on user message content.
|
||||
func NormalizeSystemPromptEnv(text string) string {
|
||||
// Replace env-info lines with canonical values
|
||||
text = reEnvLine.ReplaceAllStringFunc(text, func(line string) string {
|
||||
for key, canonical := range canonicalEnvValues {
|
||||
if len(line) >= len(key) && line[:len(key)] == key {
|
||||
return canonical
|
||||
}
|
||||
}
|
||||
return line
|
||||
})
|
||||
|
||||
// Redact real usernames in home directory paths
|
||||
// e.g. /Users/alice/project -> /Users/user/project
|
||||
text = reUserHome.ReplaceAllString(text, "${1}user/")
|
||||
|
||||
return text
|
||||
}
|
||||
|
||||
@ -895,6 +895,9 @@ func sanitizeSystemText(text string) string {
|
||||
"You are OpenCode, the best coding agent on the planet.",
|
||||
strings.TrimSpace(claudeCodeSystemPrompt),
|
||||
)
|
||||
// Normalize environment block fields (Platform/Shell/OS Version/Working directory)
|
||||
// to canonical values so different client machines don't create fingerprint divergence.
|
||||
text = NormalizeSystemPromptEnv(text)
|
||||
return text
|
||||
}
|
||||
|
||||
@ -5773,7 +5776,7 @@ func (s *GatewayService) getBetaHeader(modelID string, clientBetaHeader string)
|
||||
return claude.HaikuBetaHeader
|
||||
}
|
||||
|
||||
return claude.DefaultBetaHeader
|
||||
return claude.GetOAuthBetaHeader(modelID)
|
||||
}
|
||||
|
||||
func requestNeedsBetaFeatures(body []byte) bool {
|
||||
@ -5790,10 +5793,7 @@ func requestNeedsBetaFeatures(body []byte) bool {
|
||||
|
||||
func defaultAPIKeyBetaHeader(body []byte) string {
|
||||
modelID := gjson.GetBytes(body, "model").String()
|
||||
if strings.Contains(strings.ToLower(modelID), "haiku") {
|
||||
return claude.APIKeyHaikuBetaHeader
|
||||
}
|
||||
return claude.APIKeyBetaHeader
|
||||
return claude.GetAPIKeyBetaHeader(modelID)
|
||||
}
|
||||
|
||||
func applyClaudeOAuthHeaderDefaults(req *http.Request) {
|
||||
|
||||
@ -26,7 +26,7 @@ var (
|
||||
|
||||
// 默认指纹值(当客户端未提供时使用)
|
||||
var defaultFingerprint = Fingerprint{
|
||||
UserAgent: "claude-cli/2.1.88 (external, cli)",
|
||||
UserAgent: "claude-cli/2.1.89 (external, cli)",
|
||||
StainlessLang: "js",
|
||||
StainlessPackageVersion: "0.74.0",
|
||||
StainlessOS: "MacOS",
|
||||
|
||||
@ -1,225 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/lspool"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultLSPoolBootstrapConcurrency = 4
|
||||
)
|
||||
|
||||
type lsBootstrapAccountReader interface {
|
||||
ListByPlatform(ctx context.Context, platform string) ([]Account, error)
|
||||
}
|
||||
|
||||
// LSPoolBootstrapService pre-creates LS workers for eligible Antigravity accounts on startup.
|
||||
type LSPoolBootstrapService struct {
|
||||
accountReader lsBootstrapAccountReader
|
||||
backend lspool.Backend
|
||||
cfg *config.Config
|
||||
logger *slog.Logger
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
once sync.Once
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
func NewLSPoolBootstrapService(accountReader lsBootstrapAccountReader, backend lspool.Backend, cfg *config.Config) *LSPoolBootstrapService {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &LSPoolBootstrapService{
|
||||
accountReader: accountReader,
|
||||
backend: backend,
|
||||
cfg: cfg,
|
||||
logger: slog.Default().With("component", "service.lspool_bootstrap"),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
// ProvideLSPoolBootstrapService creates and starts the LS pool bootstrap worker.
|
||||
func ProvideLSPoolBootstrapService(accountRepo AccountRepository, cfg *config.Config) *LSPoolBootstrapService {
|
||||
svc := NewLSPoolBootstrapService(accountRepo, lspool.GlobalPool(cfg), cfg)
|
||||
svc.Start()
|
||||
return svc
|
||||
}
|
||||
|
||||
func (s *LSPoolBootstrapService) Start() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.once.Do(func() {
|
||||
if s.backend == nil {
|
||||
if lspool.IsLSModeEnabled() {
|
||||
s.logger.Warn("startup bootstrap skipped: ls backend unavailable")
|
||||
}
|
||||
return
|
||||
}
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
s.bootstrap(s.ctx)
|
||||
}()
|
||||
})
|
||||
}
|
||||
|
||||
func (s *LSPoolBootstrapService) Stop() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.cancel()
|
||||
s.wg.Wait()
|
||||
}
|
||||
|
||||
func (s *LSPoolBootstrapService) bootstrap(ctx context.Context) {
|
||||
if s.backend == nil || s.accountReader == nil {
|
||||
return
|
||||
}
|
||||
|
||||
accounts, err := s.accountReader.ListByPlatform(ctx, PlatformAntigravity)
|
||||
if err != nil {
|
||||
s.logger.Warn("load antigravity accounts for ls bootstrap failed", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
candidates := make([]Account, 0, len(accounts))
|
||||
for i := range accounts {
|
||||
if shouldBootstrapLSPoolAccount(&accounts[i], now) {
|
||||
candidates = append(candidates, accounts[i])
|
||||
}
|
||||
}
|
||||
|
||||
if len(candidates) == 0 {
|
||||
s.logger.Info("startup bootstrap skipped: no eligible antigravity accounts")
|
||||
return
|
||||
}
|
||||
|
||||
s.logger.Info("starting ls worker bootstrap",
|
||||
"accounts_total", len(accounts),
|
||||
"accounts_eligible", len(candidates),
|
||||
"concurrency", s.bootstrapConcurrency())
|
||||
|
||||
var (
|
||||
mu sync.Mutex
|
||||
started int
|
||||
failed int
|
||||
)
|
||||
sem := make(chan struct{}, s.bootstrapConcurrency())
|
||||
var wg sync.WaitGroup
|
||||
|
||||
loop:
|
||||
for i := range candidates {
|
||||
account := candidates[i]
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
break loop
|
||||
case sem <- struct{}{}:
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func(account Account) {
|
||||
defer wg.Done()
|
||||
defer func() { <-sem }()
|
||||
|
||||
if err := s.bootstrapAccount(&account); err != nil {
|
||||
mu.Lock()
|
||||
failed++
|
||||
mu.Unlock()
|
||||
s.logger.Warn("bootstrap ls worker failed", "account_id", account.ID, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
started++
|
||||
mu.Unlock()
|
||||
s.logger.Info("bootstrap ls worker ready", "account_id", account.ID)
|
||||
}(account)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
s.logger.Info("ls worker bootstrap completed",
|
||||
"accounts_total", len(accounts),
|
||||
"accounts_eligible", len(candidates),
|
||||
"workers_ready", started,
|
||||
"workers_failed", failed,
|
||||
"canceled", ctx.Err() != nil)
|
||||
}
|
||||
|
||||
func (s *LSPoolBootstrapService) bootstrapAccount(account *Account) error {
|
||||
if s.backend == nil {
|
||||
return fmt.Errorf("ls backend unavailable")
|
||||
}
|
||||
if account == nil {
|
||||
return fmt.Errorf("account is nil")
|
||||
}
|
||||
|
||||
accountKey := strconv.FormatInt(account.ID, 10)
|
||||
accessToken := strings.TrimSpace(account.GetCredential("access_token"))
|
||||
if accessToken == "" {
|
||||
return fmt.Errorf("missing access token")
|
||||
}
|
||||
refreshToken := strings.TrimSpace(account.GetCredential("refresh_token"))
|
||||
|
||||
expiresAt := time.Time{}
|
||||
if ts := account.GetCredentialAsTime("expires_at"); ts != nil {
|
||||
expiresAt = ts.UTC()
|
||||
}
|
||||
|
||||
s.backend.SetAccountToken(accountKey, accessToken, refreshToken, expiresAt)
|
||||
availableCredits, minimumCreditAmount := resolveLSPoolModelCreditsState(account)
|
||||
s.backend.SetAccountModelCredits(accountKey, account.IsOveragesEnabled(), availableCredits, minimumCreditAmount)
|
||||
|
||||
proxyURL := ""
|
||||
if account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
if _, err := s.backend.GetOrCreate(accountKey, "", proxyURL); err != nil {
|
||||
return fmt.Errorf("get or create ls worker: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *LSPoolBootstrapService) bootstrapConcurrency() int {
|
||||
parallelism := defaultLSPoolBootstrapConcurrency
|
||||
if s.cfg != nil && s.cfg.Gateway.AntigravityLSWorker.MaxActive > 0 && s.cfg.Gateway.AntigravityLSWorker.MaxActive < parallelism {
|
||||
parallelism = s.cfg.Gateway.AntigravityLSWorker.MaxActive
|
||||
}
|
||||
if parallelism < 1 {
|
||||
return 1
|
||||
}
|
||||
return parallelism
|
||||
}
|
||||
|
||||
func shouldBootstrapLSPoolAccount(account *Account, now time.Time) bool {
|
||||
if account == nil {
|
||||
return false
|
||||
}
|
||||
if account.Platform != PlatformAntigravity {
|
||||
return false
|
||||
}
|
||||
if account.Type != AccountTypeOAuth {
|
||||
return false
|
||||
}
|
||||
if account.Status != StatusActive || !account.Schedulable {
|
||||
return false
|
||||
}
|
||||
if account.AutoPauseOnExpired && account.ExpiresAt != nil && !now.Before(*account.ExpiresAt) {
|
||||
return false
|
||||
}
|
||||
if strings.TrimSpace(account.GetCredential("access_token")) == "" {
|
||||
return false
|
||||
}
|
||||
return strings.TrimSpace(account.GetCredential("project_id")) != ""
|
||||
}
|
||||
@ -1,262 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/lspool"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type fakeLSBootstrapAccountReader struct {
|
||||
mu sync.Mutex
|
||||
accounts []Account
|
||||
err error
|
||||
platforms []string
|
||||
}
|
||||
|
||||
func (f *fakeLSBootstrapAccountReader) ListByPlatform(_ context.Context, platform string) ([]Account, error) {
|
||||
f.mu.Lock()
|
||||
f.platforms = append(f.platforms, platform)
|
||||
accounts := append([]Account(nil), f.accounts...)
|
||||
err := f.err
|
||||
f.mu.Unlock()
|
||||
return accounts, err
|
||||
}
|
||||
|
||||
type fakeLSPoolBackend struct {
|
||||
mu sync.Mutex
|
||||
tokenCalls map[string]fakeLSPoolTokenCall
|
||||
creditCalls map[string]fakeLSPoolCreditCall
|
||||
getCalls []fakeLSPoolGetCall
|
||||
getErrs map[string]error
|
||||
}
|
||||
|
||||
type fakeLSPoolTokenCall struct {
|
||||
AccessToken string
|
||||
RefreshToken string
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
type fakeLSPoolCreditCall struct {
|
||||
UseAICredits bool
|
||||
AvailableCredits *int32
|
||||
MinimumCreditAmount *int32
|
||||
}
|
||||
|
||||
type fakeLSPoolGetCall struct {
|
||||
AccountID string
|
||||
RoutingKey string
|
||||
ProxyURL string
|
||||
}
|
||||
|
||||
func newFakeLSPoolBackend() *fakeLSPoolBackend {
|
||||
return &fakeLSPoolBackend{
|
||||
tokenCalls: make(map[string]fakeLSPoolTokenCall),
|
||||
creditCalls: make(map[string]fakeLSPoolCreditCall),
|
||||
getErrs: make(map[string]error),
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeLSPoolBackend) GetOrCreate(accountID, routingKey string, proxyURL ...string) (*lspool.Instance, error) {
|
||||
rawProxy := ""
|
||||
if len(proxyURL) > 0 {
|
||||
rawProxy = proxyURL[0]
|
||||
}
|
||||
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.getCalls = append(f.getCalls, fakeLSPoolGetCall{
|
||||
AccountID: accountID,
|
||||
RoutingKey: routingKey,
|
||||
ProxyURL: rawProxy,
|
||||
})
|
||||
if err := f.getErrs[accountID]; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &lspool.Instance{AccountID: accountID}, nil
|
||||
}
|
||||
|
||||
func (f *fakeLSPoolBackend) SetAccountToken(accountID, accessToken, refreshToken string, expiresAt time.Time) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.tokenCalls[accountID] = fakeLSPoolTokenCall{
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
ExpiresAt: expiresAt,
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeLSPoolBackend) SetAccountModelCredits(accountID string, useAICredits bool, availableCredits, minimumCreditAmountForUsage *int32) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.creditCalls[accountID] = fakeLSPoolCreditCall{
|
||||
UseAICredits: useAICredits,
|
||||
AvailableCredits: copyInt32Ptr(availableCredits),
|
||||
MinimumCreditAmount: copyInt32Ptr(minimumCreditAmountForUsage),
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeLSPoolBackend) Stats() map[string]any { return nil }
|
||||
|
||||
func (f *fakeLSPoolBackend) Close() {}
|
||||
|
||||
func copyInt32Ptr(v *int32) *int32 {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
cp := *v
|
||||
return &cp
|
||||
}
|
||||
|
||||
func TestLSPoolBootstrapServiceBootstrapEligibleAccounts(t *testing.T) {
|
||||
expiresAt := time.Now().Add(2 * time.Hour).UTC().Truncate(time.Second)
|
||||
expiredAt := time.Now().Add(-2 * time.Hour)
|
||||
reader := &fakeLSBootstrapAccountReader{
|
||||
accounts: []Account{
|
||||
{
|
||||
ID: 101,
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "token-101",
|
||||
"refresh_token": "refresh-101",
|
||||
"expires_at": expiresAt.Format(time.RFC3339),
|
||||
"project_id": "proj-101",
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"allow_overages": true,
|
||||
"ai_credits": []any{
|
||||
map[string]any{
|
||||
"credit_type": "GOOGLE_ONE_AI",
|
||||
"amount": 120,
|
||||
"minimum_balance": 55,
|
||||
},
|
||||
},
|
||||
},
|
||||
Proxy: &Proxy{
|
||||
Protocol: "socks5h",
|
||||
Host: "127.0.0.1",
|
||||
Port: 1080,
|
||||
Username: "alice",
|
||||
Password: "secret",
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: 102,
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Schedulable: false,
|
||||
Credentials: map[string]any{"access_token": "token-102", "project_id": "proj-102"},
|
||||
},
|
||||
{
|
||||
ID: 103,
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Credentials: map[string]any{"access_token": "token-103"},
|
||||
},
|
||||
{
|
||||
ID: 104,
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
AutoPauseOnExpired: true,
|
||||
ExpiresAt: &expiredAt,
|
||||
Credentials: map[string]any{"access_token": "token-104", "project_id": "proj-104"},
|
||||
},
|
||||
{
|
||||
ID: 106,
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeUpstream,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Credentials: map[string]any{"access_token": "token-106", "project_id": "proj-106"},
|
||||
},
|
||||
{
|
||||
ID: 105,
|
||||
Platform: PlatformOpenAI,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Credentials: map[string]any{"access_token": "token-105"},
|
||||
},
|
||||
},
|
||||
}
|
||||
backend := newFakeLSPoolBackend()
|
||||
svc := NewLSPoolBootstrapService(reader, backend, &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
AntigravityLSWorker: config.GatewayAntigravityLSWorkerConfig{MaxActive: 3},
|
||||
},
|
||||
})
|
||||
|
||||
svc.bootstrap(context.Background())
|
||||
|
||||
require.Equal(t, []string{PlatformAntigravity}, reader.platforms)
|
||||
|
||||
require.Len(t, backend.getCalls, 1)
|
||||
require.Equal(t, fakeLSPoolGetCall{
|
||||
AccountID: "101",
|
||||
RoutingKey: "",
|
||||
ProxyURL: "socks5h://alice:secret@127.0.0.1:1080",
|
||||
}, backend.getCalls[0])
|
||||
|
||||
tokenCall, ok := backend.tokenCalls["101"]
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "token-101", tokenCall.AccessToken)
|
||||
require.Equal(t, "refresh-101", tokenCall.RefreshToken)
|
||||
require.Equal(t, expiresAt, tokenCall.ExpiresAt)
|
||||
|
||||
creditCall, ok := backend.creditCalls["101"]
|
||||
require.True(t, ok)
|
||||
require.True(t, creditCall.UseAICredits)
|
||||
require.NotNil(t, creditCall.AvailableCredits)
|
||||
require.Equal(t, int32(120), *creditCall.AvailableCredits)
|
||||
require.NotNil(t, creditCall.MinimumCreditAmount)
|
||||
require.Equal(t, int32(55), *creditCall.MinimumCreditAmount)
|
||||
|
||||
require.NotContains(t, backend.tokenCalls, "102")
|
||||
require.NotContains(t, backend.tokenCalls, "103")
|
||||
require.NotContains(t, backend.tokenCalls, "104")
|
||||
require.NotContains(t, backend.tokenCalls, "106")
|
||||
}
|
||||
|
||||
func TestLSPoolBootstrapServiceBootstrapContinuesOnWorkerFailure(t *testing.T) {
|
||||
reader := &fakeLSBootstrapAccountReader{
|
||||
accounts: []Account{
|
||||
{
|
||||
ID: 201,
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Credentials: map[string]any{"access_token": "token-201", "project_id": "proj-201"},
|
||||
},
|
||||
{
|
||||
ID: 202,
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Credentials: map[string]any{"access_token": "token-202", "project_id": "proj-202"},
|
||||
},
|
||||
},
|
||||
}
|
||||
backend := newFakeLSPoolBackend()
|
||||
backend.getErrs["201"] = errors.New("create failed")
|
||||
|
||||
svc := NewLSPoolBootstrapService(reader, backend, &config.Config{})
|
||||
svc.bootstrap(context.Background())
|
||||
|
||||
require.Len(t, backend.getCalls, 2)
|
||||
require.Contains(t, backend.tokenCalls, "201")
|
||||
require.Contains(t, backend.tokenCalls, "202")
|
||||
}
|
||||
@ -1,21 +0,0 @@
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIDXTCCAkWgAwIBAgIUVoRddTlTFh3+shRe6g4kSLo2n0MwDQYJKoZIhvcNAQEL
|
||||
BQAwSTESMBAGA1UEAwwJbG9jYWxob3N0MRYwFAYDVQQKDA1FTkFCTEVTIEhUVFAy
|
||||
MRswGQYDVQQLDBJidW5kbGVkIG9uIHB1cnBvc2UwHhcNMjUwOTA0MjA1NTA0WhcN
|
||||
MjYwOTA0MjA1NTA0WjBJMRIwEAYDVQQDDAlsb2NhbGhvc3QxFjAUBgNVBAoMDUVO
|
||||
QUJMRVMgSFRUUDIxGzAZBgNVBAsMEmJ1bmRsZWQgb24gcHVycG9zZTCCASIwDQYJ
|
||||
KoZIhvcNAQEBBQADggEPADCCAQoCggEBAJVpU6IyIMgwB6CJHkOeEAgYtzvyH6fM
|
||||
lkZSbemTrD9RCWZ4Fati1/6vbbMyWsM2XNJQMhJo0JTEoLDddN1iV/xGJCO/3dgw
|
||||
4+wLqqEeck4R1pHygCkb40TycmyygSWsidkEUH0xp51nCapIdPr/WL6O+Gbpl6DA
|
||||
onerUmWIO39VG2SpV7x3iXZOSbIGMsOiNZBmGwBZcL8ZejBIDjwvNjnX/d2tejH5
|
||||
/Mo4KVEXl5jsqaNbDIkhSs5BXtCMhoi1dqt75M8FyuNZd50AGFSa9Lj6pHTpwepD
|
||||
k2x4h+czPcvscF7TQG31TK1VYFPUThDim+by0+LQKkpy/UGVWnbC4dsCAwEAAaM9
|
||||
MDswGgYDVR0RBBMwEYIJbG9jYWxob3N0hwR/AAABMB0GA1UdDgQWBBSonSKmHCVt
|
||||
yBoVH1xEb3vtCng80DANBgkqhkiG9w0BAQsFAAOCAQEAinBO/uYe8ExHeiskt2P/
|
||||
Oxkd5sHSY9deLVuyX/TFnUEfktMfYKM2Juy+MfH4vfrcEhYkYJJcm25UGrtiT0Jh
|
||||
bUooDkR53549Xzg/70HU/ls1eNIe0zYqmS12H5W4Q1LAWTVpePscB4dgOrps6xIk
|
||||
Q4nlF7dst93E3swAe81rgCEd7VZEZy5VQcE9K+CIZXaAUJwUAsAtJbrP+5JMe9pt
|
||||
q52Zq5ZVkBS+4xeaMrasN0iTgsS4Lxo2a0GFDIJ84V66oeX7a5SXfSNn7rMVIDai
|
||||
KNZ2Cf2xNXUwq25Z6tjpQCqwYn3SE8b/Yi6fFZmy5D8kmY7dMh8ghVOc7rD+Vsk6
|
||||
/Q==
|
||||
-----END CERTIFICATE-----
|
||||
Binary file not shown.
Binary file not shown.
@ -1,70 +0,0 @@
|
||||
#!/bin/sh
|
||||
set -eu
|
||||
|
||||
PROXY_HOST="${LSWORKER_PROXY_HOST:-}"
|
||||
PROXY_PORT="${LSWORKER_PROXY_PORT:-1080}"
|
||||
PROXY_USER="${LSWORKER_PROXY_USER:-}"
|
||||
PROXY_PASS="${LSWORKER_PROXY_PASS:-}"
|
||||
CONTROL_PORT="${LSWORKER_CONTROL_PORT:-18081}"
|
||||
REDSOCKS_PORT="${LSWORKER_REDSOCKS_PORT:-12345}"
|
||||
NETWORK_READY_FILE="${LSWORKER_NETWORK_READY_FILE:-/run/lsworker/network-ready}"
|
||||
|
||||
mkdir -p "$(dirname "${NETWORK_READY_FILE}")"
|
||||
|
||||
if [ -z "${PROXY_HOST}" ]; then
|
||||
echo "LSWORKER_PROXY_HOST is required" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
PROXY_IP="$(getent ahostsv4 "${PROXY_HOST}" | awk 'NR==1 {print $1}')"
|
||||
if [ -z "${PROXY_IP}" ]; then
|
||||
echo "failed to resolve proxy host: ${PROXY_HOST}" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
cat >/tmp/redsocks.conf <<EOF
|
||||
base {
|
||||
log_debug = off;
|
||||
log_info = on;
|
||||
daemon = off;
|
||||
redirector = iptables;
|
||||
}
|
||||
|
||||
redsocks {
|
||||
local_ip = 0.0.0.0;
|
||||
local_port = ${REDSOCKS_PORT};
|
||||
ip = ${PROXY_IP};
|
||||
port = ${PROXY_PORT};
|
||||
type = socks5;
|
||||
EOF
|
||||
|
||||
if [ -n "${PROXY_USER}" ]; then
|
||||
printf ' login = "%s";\n' "${PROXY_USER}" >>/tmp/redsocks.conf
|
||||
fi
|
||||
if [ -n "${PROXY_PASS}" ]; then
|
||||
printf ' password = "%s";\n' "${PROXY_PASS}" >>/tmp/redsocks.conf
|
||||
fi
|
||||
|
||||
cat >>/tmp/redsocks.conf <<EOF
|
||||
}
|
||||
EOF
|
||||
|
||||
redsocks -c /tmp/redsocks.conf >/tmp/redsocks.log 2>&1 &
|
||||
REDSOCKS_PID="$!"
|
||||
trap 'kill "${REDSOCKS_PID}" >/dev/null 2>&1 || true' EXIT
|
||||
|
||||
sleep 1
|
||||
|
||||
iptables -t nat -N REDSOCKS 2>/dev/null || true
|
||||
iptables -t nat -F REDSOCKS
|
||||
iptables -t nat -A REDSOCKS -d 127.0.0.0/8 -j RETURN
|
||||
iptables -t nat -A REDSOCKS -d 127.0.0.11/32 -j RETURN
|
||||
iptables -t nat -A REDSOCKS -d "${PROXY_IP}/32" -j RETURN
|
||||
iptables -t nat -A REDSOCKS -p tcp --dport "${CONTROL_PORT}" -j RETURN
|
||||
iptables -t nat -A REDSOCKS -p tcp -j REDIRECT --to-ports "${REDSOCKS_PORT}"
|
||||
iptables -t nat -D OUTPUT -p tcp -j REDSOCKS 2>/dev/null || true
|
||||
iptables -t nat -A OUTPUT -p tcp -j REDSOCKS
|
||||
|
||||
touch "${NETWORK_READY_FILE}"
|
||||
|
||||
exec gosu sub2api /app/lsworker
|
||||
@ -1,52 +0,0 @@
|
||||
ARG GOLANG_IMAGE=golang:1.26.1-alpine
|
||||
ARG DEBIAN_IMAGE=debian:bookworm-slim
|
||||
|
||||
FROM ${GOLANG_IMAGE} AS builder
|
||||
|
||||
WORKDIR /app/backend
|
||||
RUN apk add --no-cache git ca-certificates tzdata
|
||||
|
||||
COPY backend/go.mod backend/go.sum ./
|
||||
RUN go mod download
|
||||
|
||||
COPY backend/ ./
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build -trimpath -ldflags="-s -w" -o /app/lsworker ./cmd/lsworker
|
||||
|
||||
FROM ${DEBIAN_IMAGE}
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
ca-certificates \
|
||||
curl \
|
||||
gosu \
|
||||
iproute2 \
|
||||
iptables \
|
||||
redsocks \
|
||||
tzdata \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN groupadd -g 1000 sub2api && \
|
||||
useradd -u 1000 -g sub2api -m -s /bin/sh sub2api
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY --from=builder /app/lsworker /app/lsworker
|
||||
COPY deploy/ls-bin/language_server_linux_* /tmp/ls-bin/
|
||||
COPY deploy/ls-bin/cert.pem /app/ls/extensions/antigravity/dist/languageServer/
|
||||
|
||||
ARG TARGETARCH
|
||||
RUN mkdir -p /app/ls/extensions/antigravity/bin /run/lsworker && \
|
||||
if [ "${TARGETARCH:-amd64}" = "arm64" ]; then \
|
||||
cp /tmp/ls-bin/language_server_linux_arm /app/ls/extensions/antigravity/bin/language_server_linux_arm; \
|
||||
else \
|
||||
cp /tmp/ls-bin/language_server_linux_x64 /app/ls/extensions/antigravity/bin/language_server_linux_x64; \
|
||||
fi && \
|
||||
chmod +x /app/lsworker /app/ls/extensions/antigravity/bin/language_server_linux_* && \
|
||||
chown -R sub2api:sub2api /app /run/lsworker && \
|
||||
rm -rf /tmp/ls-bin
|
||||
|
||||
COPY deploy/lsworker-entrypoint.sh /app/lsworker-entrypoint.sh
|
||||
RUN chmod +x /app/lsworker-entrypoint.sh
|
||||
|
||||
EXPOSE 18081
|
||||
|
||||
ENTRYPOINT ["/app/lsworker-entrypoint.sh"]
|
||||
Loading…
x
Reference in New Issue
Block a user