Compare commits
16 Commits
b6e1c64c25
...
b0ed2eefb6
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b0ed2eefb6 | ||
|
|
78f91da858 | ||
|
|
1a6a077743 | ||
|
|
1182647a59 | ||
|
|
b285fb7b2f | ||
|
|
2f817dd248 | ||
|
|
0df29af0ab | ||
|
|
71bafae881 | ||
|
|
35b0d85d0d | ||
|
|
95210a1023 | ||
|
|
e301fbc46f | ||
|
|
d6e2d1ee7f | ||
|
|
8e54eaa002 | ||
|
|
20151b3347 | ||
|
|
6694dcad14 | ||
|
|
648e617f4e |
46
Dockerfile
46
Dockerfile
@ -9,6 +9,7 @@
|
||||
ARG NODE_IMAGE=node:24-alpine
|
||||
ARG GOLANG_IMAGE=golang:1.26.1-alpine
|
||||
ARG ALPINE_IMAGE=alpine:3.21
|
||||
ARG DEBIAN_IMAGE=debian:bookworm-slim
|
||||
ARG POSTGRES_IMAGE=postgres:18-alpine
|
||||
ARG GOPROXY=https://goproxy.cn,direct
|
||||
ARG GOSUMDB=sum.golang.google.cn
|
||||
@ -63,10 +64,12 @@ COPY --from=frontend-builder /app/backend/internal/web/dist ./internal/web/dist
|
||||
|
||||
# Build the binary (BuildType=release for CI builds, embed frontend)
|
||||
# Version precedence: build arg VERSION > cmd/server/VERSION
|
||||
ARG TARGETARCH
|
||||
ARG TARGETOS=linux
|
||||
RUN VERSION_VALUE="${VERSION}" && \
|
||||
if [ -z "${VERSION_VALUE}" ]; then VERSION_VALUE="$(tr -d '\r\n' < ./cmd/server/VERSION)"; fi && \
|
||||
DATE_VALUE="${DATE:-$(date -u +%Y-%m-%dT%H:%M:%SZ)}" && \
|
||||
CGO_ENABLED=0 GOOS=linux go build \
|
||||
CGO_ENABLED=0 GOOS=${TARGETOS} GOARCH=${TARGETARCH} go build \
|
||||
-tags embed \
|
||||
-ldflags="-s -w -X main.Version=${VERSION_VALUE} -X main.Commit=${COMMIT} -X main.Date=${DATE_VALUE} -X main.BuildType=release" \
|
||||
-trimpath \
|
||||
@ -79,9 +82,9 @@ RUN VERSION_VALUE="${VERSION}" && \
|
||||
FROM ${POSTGRES_IMAGE} AS pg-client
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Stage 4: Final Runtime Image
|
||||
# Stage 4: Final Runtime Image (Debian for glibc — LS binary requires it)
|
||||
# -----------------------------------------------------------------------------
|
||||
FROM ${ALPINE_IMAGE}
|
||||
FROM ${DEBIAN_IMAGE}
|
||||
|
||||
# Labels
|
||||
LABEL maintainer="Wei-Shaw <github.com/Wei-Shaw>"
|
||||
@ -89,27 +92,25 @@ LABEL description="Sub2API - AI API Gateway Platform"
|
||||
LABEL org.opencontainers.image.source="https://github.com/Wei-Shaw/sub2api"
|
||||
|
||||
# Install runtime dependencies
|
||||
RUN apk add --no-cache \
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
ca-certificates \
|
||||
curl \
|
||||
wget \
|
||||
gosu \
|
||||
proxychains4 \
|
||||
tzdata \
|
||||
su-exec \
|
||||
libpq \
|
||||
zstd-libs \
|
||||
lz4-libs \
|
||||
krb5-libs \
|
||||
libldap \
|
||||
libedit \
|
||||
&& rm -rf /var/cache/apk/*
|
||||
libpq5 \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Copy pg_dump and psql from the same postgres image used in docker-compose
|
||||
# This ensures version consistency between backup tools and the database server
|
||||
COPY --from=pg-client /usr/local/bin/pg_dump /usr/local/bin/pg_dump
|
||||
COPY --from=pg-client /usr/local/bin/psql /usr/local/bin/psql
|
||||
COPY --from=pg-client /usr/local/lib/libpq.so.5* /usr/local/lib/
|
||||
RUN ldconfig
|
||||
|
||||
# Create non-root user
|
||||
RUN addgroup -g 1000 sub2api && \
|
||||
adduser -u 1000 -G sub2api -s /bin/sh -D sub2api
|
||||
RUN groupadd -g 1000 sub2api && \
|
||||
useradd -u 1000 -g sub2api -m -s /bin/sh sub2api
|
||||
|
||||
# Set working directory
|
||||
WORKDIR /app
|
||||
@ -118,6 +119,21 @@ WORKDIR /app
|
||||
COPY --from=backend-builder --chown=sub2api:sub2api /app/sub2api /app/sub2api
|
||||
COPY --from=backend-builder --chown=sub2api:sub2api /app/backend/resources /app/resources
|
||||
|
||||
# Copy Language Server binary and cert (for LS pool mode)
|
||||
# Enable with: ANTIGRAVITY_LS_MODE=true ANTIGRAVITY_APP_ROOT=/app/ls
|
||||
# TARGETARCH is set automatically by buildx (amd64 or arm64)
|
||||
ARG TARGETARCH
|
||||
COPY --chown=sub2api:sub2api deploy/ls-bin/language_server_linux_* /tmp/ls-bin/
|
||||
COPY --chown=sub2api:sub2api deploy/ls-bin/cert.pem /app/ls/extensions/antigravity/dist/languageServer/
|
||||
RUN mkdir -p /app/ls/extensions/antigravity/bin && \
|
||||
if [ "$TARGETARCH" = "arm64" ]; then \
|
||||
cp /tmp/ls-bin/language_server_linux_arm /app/ls/extensions/antigravity/bin/language_server_linux_arm; \
|
||||
else \
|
||||
cp /tmp/ls-bin/language_server_linux_x64 /app/ls/extensions/antigravity/bin/language_server_linux_x64; \
|
||||
fi && \
|
||||
chmod +x /app/ls/extensions/antigravity/bin/language_server_linux_* && \
|
||||
rm -rf /tmp/ls-bin
|
||||
|
||||
# Create data directory
|
||||
RUN mkdir -p /app/data && chown sub2api:sub2api /app/data
|
||||
|
||||
|
||||
898
antigravity/node-tls-proxy/proxy.js
Normal file
898
antigravity/node-tls-proxy/proxy.js
Normal file
@ -0,0 +1,898 @@
|
||||
'use strict';
|
||||
|
||||
const http = require('http');
|
||||
const https = require('https');
|
||||
const http2 = require('http2');
|
||||
const net = require('net');
|
||||
const crypto = require('crypto');
|
||||
// os 模块不引用 — 避免暴露真实主机信息
|
||||
|
||||
// ─── 配置 ───────────────────────────────────────────────
|
||||
const UPSTREAM_HOST = process.env.UPSTREAM_HOST || 'api.anthropic.com';
|
||||
const LISTEN_PORT = parseInt(process.env.PROXY_PORT || '3456', 10);
|
||||
const LISTEN_HOST = process.env.PROXY_HOST || '127.0.0.1';
|
||||
const UPSTREAM_PROXY = process.env.UPSTREAM_PROXY || '';
|
||||
const CONNECT_TIMEOUT = parseInt(process.env.CONNECT_TIMEOUT || '30000', 10);
|
||||
const IDLE_TIMEOUT = parseInt(process.env.IDLE_TIMEOUT || '600000', 10);
|
||||
const TELEMETRY_ENABLED = process.env.TELEMETRY_ENABLED !== 'false'; // 默认开启
|
||||
const DD_API_KEY = process.env.DD_API_KEY || 'pubbbf48e6d78dae54bceaa4acf463299bf';
|
||||
const CLI_VERSION = process.env.CLI_VERSION || '2.1.88';
|
||||
const BUILD_TIME = process.env.BUILD_TIME || '2026-03-31T01:39:46Z';
|
||||
// 伪装的 Node 版本(CLI 2.1.88 打包的 Bun 报告的 Node 兼容版本)
|
||||
const FAKE_NODE_VERSION = process.env.FAKE_NODE_VERSION || 'v24.3.0';
|
||||
|
||||
const log = (level, msg, extra = {}) => {
|
||||
const entry = { time: new Date().toISOString(), level, msg, ...extra };
|
||||
process.stderr.write(JSON.stringify(entry) + '\n');
|
||||
};
|
||||
|
||||
const HEALTH_PATH = '/__health';
|
||||
const h2Hosts = new Set();
|
||||
|
||||
// Strip userinfo (user:pass) from proxy URL for safe logging
|
||||
function redactProxyURL(raw) {
|
||||
if (!raw) return '';
|
||||
try {
|
||||
const u = new URL(raw);
|
||||
u.username = '';
|
||||
u.password = '';
|
||||
return u.toString();
|
||||
} catch { return '<redacted-proxy-url>'; }
|
||||
}
|
||||
const h2Sessions = new Map();
|
||||
|
||||
// ─── 虚拟主机身份生成 ─────────────────────────────────────
|
||||
// 每个账号基于 seed 生成全局唯一的主机身份,看起来像一台真实的个人开发机
|
||||
// 匹配 CLI 的 OTEL detectResources: hostDetector + processDetector + serviceInstanceIdDetector
|
||||
//
|
||||
// 设计原则:
|
||||
// 1. 同一账号(seed)永远产出同一台"机器"的特征
|
||||
// 2. 不同账号的特征互不相同(无共享池、无碰撞)
|
||||
// 3. 每个字段都像人手动设置的,不是程序生成的
|
||||
|
||||
// ─── macOS 主机身份词表 ──────────────────────────────────────────
|
||||
// macOS 用户 hostname 习惯: "alex-MBP", "sam-MacBook-Pro" 等
|
||||
const MBP_NAMES = ['alex','sam','chris','max','lee','kai','jamie','taylor','morgan','casey',
|
||||
'drew','avery','riley','blake','jordan','ryan','parker','quinn','reese','cameron'];
|
||||
const MBP_SUFFIX = ['-MBP','-MacBook','-MacBook-Pro','-MacBook-Air',"s-MBP","s-MacBook","s-MacBook-Pro"];
|
||||
|
||||
function generateHostIdentity(seed) {
|
||||
const h = (s) => crypto.createHash('sha256').update(seed + ':' + s).digest();
|
||||
|
||||
// ── hostname: macOS 风格 ──
|
||||
const hb = h('hostname');
|
||||
const name = MBP_NAMES[hb.readUInt8(0) % MBP_NAMES.length];
|
||||
const sfx = MBP_SUFFIX[hb.readUInt8(1) % MBP_SUFFIX.length];
|
||||
const hostname = `${name}${sfx}`;
|
||||
|
||||
// ── username: 取自 hostname 名字(真实 Mac 行为) ──
|
||||
const username = name;
|
||||
|
||||
// ── terminal: macOS 常见终端分布 ──
|
||||
const termRoll = h('terminal').readUInt8(0) % 100;
|
||||
const terminal = termRoll < 75 ? 'xterm-256color' :
|
||||
termRoll < 88 ? 'screen-256color' :
|
||||
termRoll < 96 ? 'alacritty' : 'kitty';
|
||||
|
||||
// ── shell: macOS 默认 zsh(Catalina+);部分用 bash/fish ──
|
||||
const shellRoll = h('shell').readUInt8(0) % 100;
|
||||
const shell = shellRoll < 65 ? '/bin/zsh' :
|
||||
shellRoll < 82 ? '/usr/local/bin/zsh' :
|
||||
shellRoll < 93 ? '/bin/bash' : '/opt/homebrew/bin/fish';
|
||||
|
||||
// ── host.id: macOS IOPlatformUUID 格式(大写 UUID) ──
|
||||
const mid = h('machine-id');
|
||||
const machineId = [
|
||||
mid.slice(0,4).toString('hex').toUpperCase(),
|
||||
mid.slice(4,6).toString('hex').toUpperCase(),
|
||||
mid.slice(6,8).toString('hex').toUpperCase(),
|
||||
mid.slice(8,10).toString('hex').toUpperCase(),
|
||||
mid.slice(10,16).toString('hex').toUpperCase(),
|
||||
].join('-');
|
||||
|
||||
// ── PID: macOS GUI 应用 PID 通常较小 ──
|
||||
const pid = 500 + Math.floor(Math.random() * 8000);
|
||||
|
||||
// ── macOS 版本: 13(Ventura)/14(Sonoma)/15(Sequoia) ──
|
||||
const kb = h('kernel');
|
||||
const macosMajor = 13 + (kb.readUInt8(0) % 3);
|
||||
const macosMinor = kb.readUInt8(1) % 8;
|
||||
const macosPatch = kb.readUInt8(2) % 5;
|
||||
// Darwin 内核: macOS 13=22.x, 14=23.x, 15=24.x
|
||||
const darwinMajor = 22 + (macosMajor - 13);
|
||||
const darwinMinor = kb.readUInt8(3) % 7;
|
||||
const darwinPatch = kb.readUInt8(4) % 5;
|
||||
const osVersion = `${macosMajor}.${macosMinor}.${macosPatch}`;
|
||||
|
||||
// ── arch: Apple Silicon arm64 占 70%,Intel x64 占 30% ──
|
||||
const arch = h('arch').readUInt8(0) % 100 < 70 ? 'arm64' : 'x64';
|
||||
|
||||
// ── 可执行文件路径: macOS 常见安装位置 ──
|
||||
const pathRoll = h('execpath').readUInt8(0) % 100;
|
||||
const executablePath = pathRoll < 50 ? `/Users/${username}/.claude/local/claude` :
|
||||
pathRoll < 80 ? '/usr/local/bin/claude' :
|
||||
pathRoll < 95 ? `/Users/${username}/.local/bin/claude` :
|
||||
'/opt/homebrew/bin/claude';
|
||||
|
||||
return {
|
||||
hostname, username, terminal, shell, machineId, pid, arch,
|
||||
osType: 'Darwin',
|
||||
osVersion,
|
||||
kernelRelease: `${darwinMajor}.${darwinMinor}.${darwinPatch}`,
|
||||
serviceInstanceId: crypto.randomUUID(),
|
||||
executablePath,
|
||||
executableName: 'claude',
|
||||
command: 'claude',
|
||||
commandArgs: [],
|
||||
runtimeName: 'nodejs',
|
||||
runtimeVersion: FAKE_NODE_VERSION.replace('v', ''),
|
||||
ripgrepVersion: (() => {
|
||||
const rv = h('ripgrep');
|
||||
return ['14.1.1','14.1.0','14.0.2','13.0.0','13.0.1','14.0.1','14.0.0'][rv.readUInt8(0) % 7];
|
||||
})(),
|
||||
ripgrepPath: (() => {
|
||||
const rp = h('rgpath');
|
||||
return [
|
||||
'/opt/homebrew/bin/rg',
|
||||
'/usr/local/bin/rg',
|
||||
`/Users/${username}/.cargo/bin/rg`,
|
||||
'/usr/local/opt/ripgrep/bin/rg',
|
||||
][rp.readUInt8(0) % 4];
|
||||
})(),
|
||||
mcpServerCount: 1 + (h('mcp').readUInt8(0) % 5),
|
||||
mcpFailCount: h('mcp').readUInt8(1) % 3,
|
||||
};
|
||||
}
|
||||
|
||||
// ─── 遥测模拟 ────────────────────────────────────────────
|
||||
|
||||
// 每个 device_id 的会话状态
|
||||
const sessionStates = new Map();
|
||||
|
||||
function getOrCreateSession(deviceId) {
|
||||
if (sessionStates.has(deviceId)) return sessionStates.get(deviceId);
|
||||
const hostId = generateHostIdentity(deviceId);
|
||||
const state = {
|
||||
sessionId: crypto.randomUUID(),
|
||||
deviceId,
|
||||
hostId,
|
||||
startTime: Date.now(),
|
||||
requestCount: 0,
|
||||
// 追踪 ripgrep 是否已上报
|
||||
ripgrepReported: false,
|
||||
};
|
||||
sessionStates.set(deviceId, state);
|
||||
return state;
|
||||
}
|
||||
|
||||
function generateDeviceId(accountSeed) {
|
||||
return crypto.createHash('sha256').update(`device:${accountSeed}`).digest('hex');
|
||||
}
|
||||
|
||||
// ─── OTEL Resource Attributes (匹配 CLI 的 detectResources) ───
|
||||
|
||||
function buildEnvBlock(hostId) {
|
||||
const platformStr = 'darwin';
|
||||
return {
|
||||
platform: platformStr,
|
||||
node_version: FAKE_NODE_VERSION,
|
||||
terminal: hostId.terminal,
|
||||
package_managers: 'npm,pnpm',
|
||||
runtimes: 'deno,node',
|
||||
is_running_with_bun: true,
|
||||
is_ci: false,
|
||||
is_claubbit: false,
|
||||
is_github_action: false,
|
||||
is_claude_code_action: false,
|
||||
is_claude_ai_auth: false,
|
||||
version: CLI_VERSION,
|
||||
arch: hostId.arch,
|
||||
is_claude_code_remote: false,
|
||||
deployment_environment: `unknown-${platformStr}`,
|
||||
is_conductor: false,
|
||||
version_base: CLI_VERSION,
|
||||
build_time: BUILD_TIME,
|
||||
is_local_agent_mode: false,
|
||||
vcs: 'git',
|
||||
platform_raw: platformStr,
|
||||
};
|
||||
}
|
||||
|
||||
function buildProcessMetrics(uptime) {
|
||||
// 模拟真实 CLI 的内存曲线:RSS 随 uptime 缓慢增长
|
||||
const baseRss = 180_000_000 + Math.min(uptime * 50_000, 200_000_000);
|
||||
const rss = Math.floor(baseRss + Math.random() * 80_000_000);
|
||||
const heapTotal = Math.floor(rss * 0.6 + Math.random() * 10_000_000);
|
||||
const heapUsed = Math.floor(heapTotal * 0.5 + Math.random() * heapTotal * 0.3);
|
||||
return Buffer.from(JSON.stringify({
|
||||
uptime,
|
||||
rss,
|
||||
heapTotal,
|
||||
heapUsed,
|
||||
external: 14_000_000 + Math.floor(Math.random() * 2_000_000),
|
||||
arrayBuffers: Math.floor(Math.random() * 200_000),
|
||||
constrainedMemory: 51539607552,
|
||||
cpuUsage: {
|
||||
user: Math.floor(uptime * 10_000 + Math.random() * 300_000),
|
||||
system: Math.floor(uptime * 2_000 + Math.random() * 80_000),
|
||||
},
|
||||
cpuPercent: Math.random() * 200,
|
||||
})).toString('base64');
|
||||
}
|
||||
|
||||
function buildEvent(eventName, session, model, betas, extraData, timestampOverride) {
|
||||
const uptime = (Date.now() - session.startTime) / 1000;
|
||||
const processMetrics = buildProcessMetrics(uptime);
|
||||
// 缓存最近一次的 process metrics,供 DataDog 日志复用(保持两边一致)
|
||||
session._lastProcessMetrics = { uptime, raw: processMetrics };
|
||||
const eventData = {
|
||||
event_name: eventName,
|
||||
client_timestamp: timestampOverride || new Date().toISOString(),
|
||||
model: model || 'claude-sonnet-4-6',
|
||||
session_id: session.sessionId,
|
||||
user_type: 'external',
|
||||
betas: betas || 'claude-code-20250219,interleaved-thinking-2025-05-14',
|
||||
env: buildEnvBlock(session.hostId),
|
||||
entrypoint: 'cli',
|
||||
is_interactive: true,
|
||||
client_type: 'cli',
|
||||
process: processMetrics,
|
||||
event_id: crypto.randomUUID(),
|
||||
device_id: session.deviceId,
|
||||
// 注意:不加 resource 字段 — event_logging/batch 是自定义端点,
|
||||
// OTEL resource attributes 由 CLI 通过单独的 OTLP exporter 发送,不在这里
|
||||
};
|
||||
// 合并额外字段(用于特定事件的附加数据)
|
||||
if (extraData) Object.assign(eventData, extraData);
|
||||
return {
|
||||
event_type: 'ClaudeCodeInternalEvent',
|
||||
event_data: eventData,
|
||||
};
|
||||
}
|
||||
|
||||
// 发送遥测到 api.anthropic.com/api/event_logging/batch
|
||||
function sendTelemetryEvents(events, session) {
|
||||
if (!TELEMETRY_ENABLED || events.length === 0) return;
|
||||
|
||||
const body = JSON.stringify({ events });
|
||||
const headers = {
|
||||
'Accept': 'application/json, text/plain, */*',
|
||||
'Content-Type': 'application/json',
|
||||
'User-Agent': `claude-code/${CLI_VERSION}`,
|
||||
'x-service-name': 'claude-code',
|
||||
'Content-Length': Buffer.byteLength(body),
|
||||
};
|
||||
// 注意:真实 CLI 2.1.84 的 event_logging/batch 不发 traceparent
|
||||
// traceparent 仅在 OTLP exporter(单独通道)中使用,不在这个端点
|
||||
|
||||
const opts = {
|
||||
hostname: 'api.anthropic.com',
|
||||
port: 443,
|
||||
path: '/api/event_logging/batch',
|
||||
method: 'POST',
|
||||
headers,
|
||||
timeout: 10000,
|
||||
};
|
||||
|
||||
const req = https.request(opts, (res) => {
|
||||
res.resume(); // drain
|
||||
log('debug', 'telemetry_sent', { status: res.statusCode, events: events.length });
|
||||
});
|
||||
req.on('error', (err) => {
|
||||
log('debug', 'telemetry_error', { error: err.message });
|
||||
});
|
||||
req.on('timeout', () => req.destroy());
|
||||
req.end(body);
|
||||
}
|
||||
|
||||
// 发送 DataDog 日志
|
||||
function sendDatadogLog(eventName, session, model) {
|
||||
if (!TELEMETRY_ENABLED) return;
|
||||
|
||||
const hostId = session.hostId;
|
||||
const uptime = (Date.now() - session.startTime) / 1000;
|
||||
|
||||
// 复用 Anthropic 事件侧缓存的 process metrics(保持两边数值一致)
|
||||
// 如果没有缓存(首次调用),现场生成
|
||||
let pm;
|
||||
if (session._lastProcessMetrics && Math.abs(session._lastProcessMetrics.uptime - uptime) < 2) {
|
||||
pm = JSON.parse(Buffer.from(session._lastProcessMetrics.raw, 'base64').toString());
|
||||
} else {
|
||||
const baseRss = 180_000_000 + Math.min(uptime * 50_000, 200_000_000);
|
||||
const rss = Math.floor(baseRss + Math.random() * 80_000_000);
|
||||
const heapTotal = Math.floor(rss * 0.6 + Math.random() * 10_000_000);
|
||||
const heapUsed = Math.floor(heapTotal * 0.5 + Math.random() * heapTotal * 0.3);
|
||||
pm = {
|
||||
uptime,
|
||||
rss,
|
||||
heapTotal,
|
||||
heapUsed,
|
||||
external: 14_000_000 + Math.floor(Math.random() * 2_000_000),
|
||||
arrayBuffers: Math.floor(Math.random() * 10_000),
|
||||
constrainedMemory: 0,
|
||||
cpuUsage: {
|
||||
user: Math.floor(uptime * 10_000 + Math.random() * 300_000),
|
||||
system: Math.floor(uptime * 2_000 + Math.random() * 80_000),
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
const entry = {
|
||||
ddsource: 'nodejs',
|
||||
ddtags: `event:${eventName},arch:${hostId.arch},client_type:cli,model:${model || 'claude-sonnet-4-6'},platform:darwin,user_type:external,version:${CLI_VERSION},version_base:${CLI_VERSION}`,
|
||||
message: eventName,
|
||||
service: 'claude-code',
|
||||
hostname: hostId.hostname,
|
||||
env: 'external',
|
||||
model: model || 'claude-sonnet-4-6',
|
||||
session_id: session.sessionId,
|
||||
user_type: 'external',
|
||||
entrypoint: 'cli',
|
||||
is_interactive: 'true',
|
||||
client_type: 'cli',
|
||||
process_metrics: pm,
|
||||
platform: 'darwin',
|
||||
platform_raw: 'darwin',
|
||||
arch: hostId.arch,
|
||||
node_version: FAKE_NODE_VERSION,
|
||||
version: CLI_VERSION,
|
||||
version_base: CLI_VERSION,
|
||||
build_time: BUILD_TIME,
|
||||
deployment_environment: 'unknown-darwin',
|
||||
vcs: 'git',
|
||||
};
|
||||
|
||||
const body = JSON.stringify([entry]);
|
||||
const opts = {
|
||||
hostname: 'http-intake.logs.us5.datadoghq.com',
|
||||
port: 443,
|
||||
path: '/api/v2/logs',
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Accept': 'application/json, text/plain, */*',
|
||||
'Content-Type': 'application/json',
|
||||
'User-Agent': 'axios/1.13.6',
|
||||
'dd-api-key': DD_API_KEY,
|
||||
'Content-Length': Buffer.byteLength(body),
|
||||
},
|
||||
timeout: 10000,
|
||||
};
|
||||
|
||||
const req = https.request(opts, (res) => { res.resume(); });
|
||||
req.on('error', () => {});
|
||||
req.on('timeout', () => req.destroy());
|
||||
req.end(body);
|
||||
}
|
||||
|
||||
// 请求前发遥测(模拟 CLI 启动 + 初始化事件)
|
||||
function emitPreRequestTelemetry(reqHeaders, body) {
|
||||
const accountSeed = reqHeaders['x-forwarded-host'] || 'default';
|
||||
const deviceId = generateDeviceId(accountSeed + ':' + (reqHeaders['authorization'] || '').slice(-16));
|
||||
const session = getOrCreateSession(deviceId);
|
||||
session.requestCount++;
|
||||
|
||||
// 从请求体解析真实 model
|
||||
let model = 'claude-sonnet-4-6';
|
||||
try {
|
||||
const parsed = JSON.parse(body.toString());
|
||||
if (parsed.model) model = parsed.model;
|
||||
} catch (_) {}
|
||||
|
||||
const betas = reqHeaders['anthropic-beta'] || 'claude-code-20250219,context-1m-2025-08-07,interleaved-thinking-2025-05-14,redact-thinking-2026-02-12,context-management-2025-06-27,prompt-caching-scope-2026-01-05,effort-2025-11-24,token-efficient-tools-2026-03-28,advisor-tool-2026-03-01';
|
||||
|
||||
// 首次请求:发完整启动事件序列(匹配真实 CLI 的时序)
|
||||
if (session.requestCount === 1) {
|
||||
const hostId = session.hostId;
|
||||
// 生成递增的时间戳,模拟真实 CLI 启动流程的时间差
|
||||
const baseTime = Date.now();
|
||||
const ts = (offsetMs) => new Date(baseTime + offsetMs).toISOString();
|
||||
|
||||
// 第一批:启动 + 工具检测 + MCP 连接事件
|
||||
const batch1 = [
|
||||
buildEvent('tengu_started', session, model, betas, null, ts(0)),
|
||||
buildEvent('tengu_init', session, model, betas, null, ts(80 + Math.floor(Math.random() * 120))),
|
||||
// tengu_ripgrep_availability — CLI 必发的工具检测事件,版本/路径按账号不同
|
||||
buildEvent('tengu_ripgrep_availability', session, model, betas, {
|
||||
ripgrep_available: true,
|
||||
ripgrep_version: hostId.ripgrepVersion,
|
||||
ripgrep_path: hostId.ripgrepPath,
|
||||
}, ts(200 + Math.floor(Math.random() * 150))),
|
||||
];
|
||||
// MCP 连接事件:数量按账号不同(真实用户配置的 MCP server 数量差异很大)
|
||||
let mcpOffset = 400;
|
||||
const mcpSuccessCount = hostId.mcpServerCount - hostId.mcpFailCount;
|
||||
for (let i = 0; i < hostId.mcpFailCount; i++) {
|
||||
mcpOffset += 100 + Math.floor(Math.random() * 300);
|
||||
batch1.push(buildEvent('tengu_mcp_server_connection_failed', session, model, betas, null, ts(mcpOffset)));
|
||||
}
|
||||
for (let i = 0; i < mcpSuccessCount; i++) {
|
||||
mcpOffset += 200 + Math.floor(Math.random() * 500);
|
||||
batch1.push(buildEvent('tengu_mcp_server_connection_succeeded', session, model, betas, null, ts(mcpOffset)));
|
||||
}
|
||||
|
||||
session.ripgrepReported = true;
|
||||
sendTelemetryEvents(batch1, session);
|
||||
sendDatadogLog('tengu_started', session, model);
|
||||
sendDatadogLog('tengu_init', session, model);
|
||||
|
||||
// 第二批延迟发送(真实 CLI 间隔约 30 秒)
|
||||
setTimeout(() => {
|
||||
const batch2 = [
|
||||
buildEvent('tengu_session_init', session, model, betas),
|
||||
buildEvent('tengu_context_loaded', session, model, betas),
|
||||
];
|
||||
sendTelemetryEvents(batch2, session);
|
||||
}, 25000 + Math.floor(Math.random() * 10000));
|
||||
}
|
||||
|
||||
// 每次请求:发 request_started
|
||||
const events = [
|
||||
buildEvent('tengu_api_request_started', session, model, betas),
|
||||
];
|
||||
sendTelemetryEvents(events, session);
|
||||
}
|
||||
|
||||
// 请求后发遥测
|
||||
function emitPostRequestTelemetry(reqHeaders, statusCode, body) {
|
||||
const accountSeed = reqHeaders['x-forwarded-host'] || 'default';
|
||||
const deviceId = generateDeviceId(accountSeed + ':' + (reqHeaders['authorization'] || '').slice(-16));
|
||||
const session = getOrCreateSession(deviceId);
|
||||
|
||||
let model = 'claude-sonnet-4-6';
|
||||
try {
|
||||
const parsed = JSON.parse(body.toString());
|
||||
if (parsed.model) model = parsed.model;
|
||||
} catch (_) {}
|
||||
|
||||
const betas = reqHeaders['anthropic-beta'] || 'claude-code-20250219,context-1m-2025-08-07,interleaved-thinking-2025-05-14,redact-thinking-2026-02-12,context-management-2025-06-27,prompt-caching-scope-2026-01-05,effort-2025-11-24,token-efficient-tools-2026-03-28,advisor-tool-2026-03-01';
|
||||
|
||||
// 请求完成事件
|
||||
const events = [
|
||||
buildEvent('tengu_api_request_completed', session, model, betas),
|
||||
buildEvent('tengu_conversation_turn_completed', session, model, betas),
|
||||
];
|
||||
sendTelemetryEvents(events, session);
|
||||
sendDatadogLog('tengu_api_request_completed', session, model);
|
||||
|
||||
// 模拟错误遥测(低概率,匹配 TelemetrySafeError)
|
||||
if (statusCode >= 400 && Math.random() < 0.5) {
|
||||
const errorEvent = buildEvent('tengu_api_request_error', session, model, betas, {
|
||||
error_type: 'TelemetrySafeError',
|
||||
error_code: statusCode,
|
||||
error_message: statusCode === 429 ? 'rate_limit_exceeded' :
|
||||
statusCode === 529 ? 'overloaded' :
|
||||
statusCode >= 500 ? 'server_error' : 'client_error',
|
||||
});
|
||||
sendTelemetryEvents([errorEvent], session);
|
||||
}
|
||||
|
||||
// 随机发额外事件(仅使用已知的真实 CLI 事件名)
|
||||
if (Math.random() < 0.3) {
|
||||
setTimeout(() => {
|
||||
const extra = [
|
||||
buildEvent('tengu_tool_use_completed', session, model, betas),
|
||||
];
|
||||
sendTelemetryEvents(extra, session);
|
||||
}, 2000 + Math.floor(Math.random() * 5000));
|
||||
}
|
||||
}
|
||||
|
||||
// ─── H2 session 管理 ────────────────────────────────────
|
||||
// h2Sessions key 改为 host+proxy 组合,避免不同代理的 session 混用
|
||||
function h2SessionKey(host, proxyUrl) {
|
||||
return proxyUrl ? `${host}|${proxyUrl}` : host;
|
||||
}
|
||||
|
||||
async function getOrCreateH2Session(host, proxyUrl) {
|
||||
const key = h2SessionKey(host, proxyUrl);
|
||||
const existing = h2Sessions.get(key);
|
||||
// 检查 session 是否仍然可用:connected 且未关闭
|
||||
// GOAWAY 后 session.connected 变为 false,必须重建
|
||||
if (existing && !existing.closed && !existing.destroyed && existing.connected) return existing;
|
||||
if (existing) {
|
||||
h2Sessions.delete(key);
|
||||
try { existing.close(); } catch (_) {}
|
||||
}
|
||||
|
||||
let session;
|
||||
if (proxyUrl) {
|
||||
// 通过 CONNECT 隧道建立 h2 session(支持 HTTP CONNECT / SOCKS5)
|
||||
const socket = await connectViaProxy(proxyUrl, host, 443);
|
||||
session = http2.connect(`https://${host}`, {
|
||||
createConnection: () => socket,
|
||||
});
|
||||
log('info', 'h2_session_via_proxy', { host, proxy: redactProxyURL(proxyUrl) });
|
||||
} else {
|
||||
session = http2.connect(`https://${host}`);
|
||||
}
|
||||
|
||||
session.on('error', (err) => {
|
||||
log('warn', 'h2_session_error', { host, error: err.message });
|
||||
h2Sessions.delete(key);
|
||||
try { session.close(); } catch (_) {}
|
||||
});
|
||||
session.on('close', () => h2Sessions.delete(key));
|
||||
session.on('goaway', (errorCode) => {
|
||||
log('info', 'h2_goaway', { host, errorCode });
|
||||
h2Sessions.delete(key);
|
||||
try { session.close(); } catch (_) {}
|
||||
});
|
||||
session.setTimeout(IDLE_TIMEOUT, () => { session.close(); h2Sessions.delete(key); });
|
||||
h2Sessions.set(key, session);
|
||||
return session;
|
||||
}
|
||||
|
||||
function waitForConnect(session) {
|
||||
if (session.connected) return Promise.resolve();
|
||||
// session 已断开(GOAWAY / 半关闭),不要等不会来的 connect 事件
|
||||
if (session.closed || session.destroyed) {
|
||||
return Promise.reject(new Error('h2 session already closed'));
|
||||
}
|
||||
return new Promise((resolve, reject) => {
|
||||
const onConnect = () => { clearTimeout(t); cleanup(); resolve(); };
|
||||
const onError = (err) => { clearTimeout(t); cleanup(); reject(err); };
|
||||
const onClose = () => { clearTimeout(t); cleanup(); reject(new Error('h2 session closed before connect')); };
|
||||
const cleanup = () => {
|
||||
session.removeListener('connect', onConnect);
|
||||
session.removeListener('error', onError);
|
||||
session.removeListener('close', onClose);
|
||||
};
|
||||
session.once('connect', onConnect);
|
||||
session.once('error', onError);
|
||||
session.once('close', onClose);
|
||||
const t = setTimeout(() => { cleanup(); reject(new Error('h2 connect timeout')); }, CONNECT_TIMEOUT);
|
||||
});
|
||||
}
|
||||
|
||||
// ─── CONNECT 隧道(HTTP CONNECT + SOCKS5)─────────────────
|
||||
function connectViaProxy(proxyUrl, targetHost, targetPort) {
|
||||
const proxy = new URL(proxyUrl);
|
||||
const scheme = proxy.protocol.replace(':', '').toLowerCase();
|
||||
|
||||
if (scheme === 'socks5' || scheme === 'socks5h') {
|
||||
return connectViaSocks5(proxy, targetHost, parseInt(targetPort, 10));
|
||||
}
|
||||
return connectViaHttpConnect(proxy, targetHost, targetPort);
|
||||
}
|
||||
|
||||
// HTTP CONNECT 隧道
|
||||
function connectViaHttpConnect(proxy, targetHost, targetPort) {
|
||||
return new Promise((resolve, reject) => {
|
||||
const conn = net.connect(parseInt(proxy.port || '80', 10), proxy.hostname, () => {
|
||||
const auth = proxy.username
|
||||
? `Proxy-Authorization: Basic ${Buffer.from(`${decodeURIComponent(proxy.username)}:${decodeURIComponent(proxy.password || '')}`).toString('base64')}\r\n`
|
||||
: '';
|
||||
conn.write(`CONNECT ${targetHost}:${targetPort} HTTP/1.1\r\nHost: ${targetHost}:${targetPort}\r\n${auth}\r\n`);
|
||||
});
|
||||
conn.once('error', reject);
|
||||
conn.setTimeout(CONNECT_TIMEOUT, () => conn.destroy(new Error('CONNECT timeout')));
|
||||
let buf = '';
|
||||
conn.on('data', function onData(chunk) {
|
||||
buf += chunk.toString();
|
||||
const idx = buf.indexOf('\r\n\r\n');
|
||||
if (idx === -1) return;
|
||||
conn.removeListener('data', onData);
|
||||
const code = parseInt(buf.split(' ')[1], 10);
|
||||
if (code === 200) { conn.setTimeout(0); resolve(conn); }
|
||||
else { conn.destroy(); reject(new Error(`CONNECT ${code}`)); }
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// SOCKS5 隧道 (RFC 1928 + RFC 1929 username/password auth)
|
||||
function connectViaSocks5(proxy, targetHost, targetPort) {
|
||||
return new Promise((resolve, reject) => {
|
||||
const conn = net.connect(parseInt(proxy.port || '1080', 10), proxy.hostname);
|
||||
conn.once('error', reject);
|
||||
conn.setTimeout(CONNECT_TIMEOUT, () => conn.destroy(new Error('SOCKS5 timeout')));
|
||||
|
||||
const username = proxy.username ? decodeURIComponent(proxy.username) : '';
|
||||
const password = proxy.password ? decodeURIComponent(proxy.password) : '';
|
||||
const useAuth = !!(username || password);
|
||||
|
||||
let step = 'greeting';
|
||||
|
||||
conn.once('connect', () => {
|
||||
// Step 1: 发送 greeting — 支持的认证方式
|
||||
// 0x00 = 无认证, 0x02 = 用户名/密码
|
||||
if (useAuth) {
|
||||
conn.write(Buffer.from([0x05, 0x02, 0x00, 0x02]));
|
||||
} else {
|
||||
conn.write(Buffer.from([0x05, 0x01, 0x00]));
|
||||
}
|
||||
});
|
||||
|
||||
let pending = Buffer.alloc(0);
|
||||
conn.on('data', function onData(chunk) {
|
||||
pending = Buffer.concat([pending, chunk]);
|
||||
|
||||
if (step === 'greeting') {
|
||||
if (pending.length < 2) return;
|
||||
const ver = pending[0], method = pending[1];
|
||||
if (ver !== 0x05) { conn.destroy(); return reject(new Error(`SOCKS5 bad version: ${ver}`)); }
|
||||
|
||||
if (method === 0x02 && useAuth) {
|
||||
// Step 2: 用户名/密码认证 (RFC 1929)
|
||||
step = 'auth';
|
||||
pending = pending.slice(2);
|
||||
const uBuf = Buffer.from(username, 'utf8');
|
||||
const pBuf = Buffer.from(password, 'utf8');
|
||||
const authBuf = Buffer.alloc(3 + uBuf.length + pBuf.length);
|
||||
authBuf[0] = 0x01; // auth version
|
||||
authBuf[1] = uBuf.length;
|
||||
uBuf.copy(authBuf, 2);
|
||||
authBuf[2 + uBuf.length] = pBuf.length;
|
||||
pBuf.copy(authBuf, 3 + uBuf.length);
|
||||
conn.write(authBuf);
|
||||
} else if (method === 0x00) {
|
||||
// 无需认证,直接发 CONNECT
|
||||
step = 'connect';
|
||||
pending = pending.slice(2);
|
||||
sendSocks5Connect(conn, targetHost, targetPort);
|
||||
} else {
|
||||
conn.destroy();
|
||||
reject(new Error(`SOCKS5 unsupported auth method: ${method}`));
|
||||
}
|
||||
} else if (step === 'auth') {
|
||||
if (pending.length < 2) return;
|
||||
const status = pending[1];
|
||||
if (status !== 0x00) { conn.destroy(); return reject(new Error(`SOCKS5 auth failed: ${status}`)); }
|
||||
step = 'connect';
|
||||
pending = pending.slice(2);
|
||||
sendSocks5Connect(conn, targetHost, targetPort);
|
||||
} else if (step === 'connect') {
|
||||
// 最小响应: VER(1) + REP(1) + RSV(1) + ATYP(1) + ADDR(variable) + PORT(2)
|
||||
if (pending.length < 4) return;
|
||||
const rep = pending[1];
|
||||
const atyp = pending[3];
|
||||
let minLen = 4 + 2; // base + port
|
||||
if (atyp === 0x01) minLen += 4; // IPv4
|
||||
else if (atyp === 0x04) minLen += 16; // IPv6
|
||||
else if (atyp === 0x03 && pending.length > 4) minLen += 1 + pending[4]; // domain
|
||||
else if (atyp === 0x03) return; // 等更多数据
|
||||
if (pending.length < minLen) return;
|
||||
|
||||
conn.removeListener('data', onData);
|
||||
if (rep !== 0x00) { conn.destroy(); return reject(new Error(`SOCKS5 connect failed: rep=${rep}`)); }
|
||||
conn.setTimeout(0);
|
||||
resolve(conn);
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
function sendSocks5Connect(conn, host, port) {
|
||||
// SOCKS5 CONNECT: VER(05) CMD(01=CONNECT) RSV(00) ATYP ADDR PORT
|
||||
const hostBuf = Buffer.from(host, 'utf8');
|
||||
const buf = Buffer.alloc(4 + 1 + hostBuf.length + 2);
|
||||
buf[0] = 0x05; // version
|
||||
buf[1] = 0x01; // CONNECT
|
||||
buf[2] = 0x00; // reserved
|
||||
buf[3] = 0x03; // domain name
|
||||
buf[4] = hostBuf.length;
|
||||
hostBuf.copy(buf, 5);
|
||||
buf.writeUInt16BE(port, 5 + hostBuf.length);
|
||||
conn.write(buf);
|
||||
}
|
||||
|
||||
// ─── 收集请求体 ──────────────────────────────────────────
|
||||
function collectBody(req) {
|
||||
return new Promise((resolve) => {
|
||||
const chunks = [];
|
||||
req.on('data', (c) => chunks.push(c));
|
||||
req.on('end', () => resolve(Buffer.concat(chunks)));
|
||||
req.on('error', () => resolve(Buffer.concat(chunks)));
|
||||
});
|
||||
}
|
||||
|
||||
// ─── H1 代理 ─────────────────────────────────────────────
|
||||
function sendViaH1(targetHost, method, path, reqHeaders, body, res, savedHeaders, explicitProxy) {
|
||||
return new Promise((resolve) => {
|
||||
const headers = { ...reqHeaders, host: targetHost };
|
||||
['x-forwarded-host', 'connection', 'keep-alive', 'proxy-connection', 'transfer-encoding'].forEach(h => delete headers[h]);
|
||||
delete headers['x-upstream-proxy'];
|
||||
if (body.length > 0) headers['content-length'] = String(body.length);
|
||||
|
||||
const opts = { hostname: targetHost, port: 443, path, method, headers, servername: targetHost, timeout: CONNECT_TIMEOUT };
|
||||
const startTime = Date.now();
|
||||
|
||||
const finish = (requestOpts) => {
|
||||
const proxyReq = https.request(requestOpts);
|
||||
proxyReq.on('response', (proxyRes) => {
|
||||
log('info', 'proxy_response', { host: targetHost, status: proxyRes.statusCode, path, proto: 'h1' });
|
||||
const rh = { ...proxyRes.headers };
|
||||
delete rh['connection']; delete rh['keep-alive'];
|
||||
res.writeHead(proxyRes.statusCode, rh);
|
||||
proxyRes.pipe(res, { end: true });
|
||||
// 请求完成后发遥测
|
||||
if (path.includes('/v1/messages') && savedHeaders) {
|
||||
emitPostRequestTelemetry(savedHeaders, proxyRes.statusCode, body);
|
||||
}
|
||||
resolve('ok');
|
||||
});
|
||||
proxyReq.on('error', (err) => {
|
||||
if (err.message === 'socket hang up' && (Date.now() - startTime) < 2000) {
|
||||
log('info', 'h1_rejected_switching_to_h2', { host: targetHost });
|
||||
h2Hosts.add(targetHost);
|
||||
sendViaH2(targetHost, method, path, reqHeaders, body, res, savedHeaders, false, explicitProxy).then(() => resolve('h2'));
|
||||
return;
|
||||
}
|
||||
log('error', 'h1_error', { error: err.message, host: targetHost, path });
|
||||
if (!res.headersSent) { res.writeHead(502); res.end(JSON.stringify({ error: 'upstream_connection_error' })); }
|
||||
resolve('error');
|
||||
});
|
||||
proxyReq.on('timeout', () => proxyReq.destroy(new Error('timeout')));
|
||||
proxyReq.end(body);
|
||||
};
|
||||
|
||||
// 动态上游代理:使用显式传入的代理地址
|
||||
const upstreamProxy = explicitProxy || '';
|
||||
|
||||
if (upstreamProxy) {
|
||||
connectViaProxy(upstreamProxy, targetHost, 443)
|
||||
.then((socket) => { opts.socket = socket; opts.agent = false; finish(opts); })
|
||||
.catch((err) => { log('error', 'tunnel_failed', { error: err.message, proxy: redactProxyURL(upstreamProxy) }); if (!res.headersSent) { res.writeHead(502); res.end(JSON.stringify({ error: 'upstream_connection_error' })); } resolve('error'); });
|
||||
} else {
|
||||
finish(opts);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// ─── H2 代理 ─────────────────────────────────────────────
|
||||
async function sendViaH2(targetHost, method, path, reqHeaders, body, res, savedHeaders, _retried, proxyUrl) {
|
||||
try {
|
||||
const session = await getOrCreateH2Session(targetHost, proxyUrl);
|
||||
await waitForConnect(session);
|
||||
|
||||
const headers = {};
|
||||
const skip = new Set(['host','connection','keep-alive','proxy-connection','transfer-encoding','upgrade','x-forwarded-host','http2-settings']);
|
||||
for (const [k, v] of Object.entries(reqHeaders)) {
|
||||
if (!skip.has(k.toLowerCase())) headers[k] = v;
|
||||
}
|
||||
headers[':method'] = method;
|
||||
headers[':path'] = path;
|
||||
headers[':authority'] = targetHost;
|
||||
headers[':scheme'] = 'https';
|
||||
if (body.length > 0) headers['content-length'] = String(body.length);
|
||||
|
||||
const stream = session.request(headers);
|
||||
let responded = false;
|
||||
|
||||
stream.on('response', (h2h) => {
|
||||
responded = true;
|
||||
const status = h2h[':status'] || 502;
|
||||
const rh = {};
|
||||
for (const [k, v] of Object.entries(h2h)) { if (!k.startsWith(':')) rh[k] = v; }
|
||||
log('info', 'proxy_response', { host: targetHost, status, path, proto: 'h2' });
|
||||
res.writeHead(status, rh);
|
||||
stream.on('data', (c) => res.write(c));
|
||||
stream.on('end', () => res.end());
|
||||
if (path.includes('/v1/messages') && savedHeaders) {
|
||||
emitPostRequestTelemetry(savedHeaders, status);
|
||||
}
|
||||
});
|
||||
|
||||
stream.on('error', (err) => {
|
||||
if (err.message && err.message.includes('NGHTTP2')) {
|
||||
h2Sessions.delete(targetHost);
|
||||
try { session.close(); } catch (_) {}
|
||||
}
|
||||
if (responded) { if (!res.writableEnded) res.end(); return; }
|
||||
log('error', 'h2_error', { error: err.message, host: targetHost, path });
|
||||
if (!res.headersSent) { res.writeHead(502); res.end(JSON.stringify({ error: 'upstream_connection_error' })); }
|
||||
});
|
||||
|
||||
stream.on('close', () => {
|
||||
if (!responded && !res.headersSent) {
|
||||
log('warn', 'h2_no_response', { host: targetHost, path });
|
||||
res.writeHead(502); res.end(JSON.stringify({ error: 'upstream_connection_error' }));
|
||||
} else if (!res.writableEnded) { res.end(); }
|
||||
});
|
||||
|
||||
stream.setTimeout(CONNECT_TIMEOUT, () => stream.close());
|
||||
stream.end(body);
|
||||
} catch (err) {
|
||||
log('error', 'h2_exception', { error: err.message, host: targetHost, retried: !!_retried });
|
||||
h2Sessions.delete(targetHost);
|
||||
// 首次失败时重试一次(用全新 session)
|
||||
if (!_retried && !res.headersSent) {
|
||||
log('info', 'h2_retry_with_fresh_session', { host: targetHost, path });
|
||||
return sendViaH2(targetHost, method, path, reqHeaders, body, res, savedHeaders, true, proxyUrl);
|
||||
}
|
||||
if (!res.headersSent) { res.writeHead(502); res.end(JSON.stringify({ error: err.message })); }
|
||||
}
|
||||
}
|
||||
|
||||
// ─── 请求入口 ─────────────────────────────────────────────
|
||||
async function proxyRequest(req, res) {
|
||||
const targetHost = req.headers['x-forwarded-host'] || UPSTREAM_HOST;
|
||||
log('info', 'proxy_request', { host: targetHost, method: req.method, path: req.url });
|
||||
|
||||
// 保存原始 headers 用于遥测
|
||||
const savedHeaders = { ...req.headers };
|
||||
|
||||
const body = await collectBody(req);
|
||||
|
||||
// 请求前发遥测(仅 /v1/messages 请求)
|
||||
if (req.url.includes('/v1/messages') && TELEMETRY_ENABLED) {
|
||||
emitPreRequestTelemetry(savedHeaders, body);
|
||||
}
|
||||
|
||||
// ── Jitter 注入 ──────────────────────────────────────────────────
|
||||
// 模拟人类编码间歇:80% 快速响应(80-300ms),20% 慢速思考(400-1200ms)
|
||||
// 使用 -log(rand) 指数衰减使延迟尾部更接近真实键盘输入节奏
|
||||
const jitterMs = (() => {
|
||||
if (Math.random() < 0.80) {
|
||||
return Math.floor(80 + (-Math.log(Math.random()) * 90)); // 快:~80-300ms
|
||||
}
|
||||
return Math.floor(400 + Math.random() * 800); // 慢:400-1200ms
|
||||
})();
|
||||
await new Promise(r => setTimeout(r, jitterMs));
|
||||
|
||||
// ── H2 / H1 路由策略 ──────────────────────────────────────────────
|
||||
// H2 现在支持通过 CONNECT 隧道代理,优先为 H2_PREFER_HOSTS 使用 h2。
|
||||
// 有代理时通过 connectViaProxy 建立隧道后再 h2 连接。
|
||||
const upstreamProxy = req.headers['x-upstream-proxy'] || UPSTREAM_PROXY;
|
||||
// 清除内部 header,不传给上游(h2 路径也需要清理)
|
||||
delete req.headers['x-upstream-proxy'];
|
||||
const H2_PREFER_HOSTS = new Set([
|
||||
'api.anthropic.com',
|
||||
'cloudaicompanion.googleapis.com',
|
||||
'generativelanguage.googleapis.com',
|
||||
'cloudcode-pa.googleapis.com',
|
||||
'daily-cloudcode-pa.googleapis.com',
|
||||
]);
|
||||
if (H2_PREFER_HOSTS.has(targetHost) || h2Hosts.has(targetHost)) {
|
||||
await sendViaH2(targetHost, req.method, req.url, req.headers, body, res, savedHeaders, false, upstreamProxy || undefined);
|
||||
} else {
|
||||
await sendViaH1(targetHost, req.method, req.url, req.headers, body, res, savedHeaders, upstreamProxy || undefined);
|
||||
}
|
||||
}
|
||||
|
||||
// ─── HTTP 服务器 ─────────────────────────────────────────
|
||||
const server = http.createServer((req, res) => {
|
||||
if (req.url === HEALTH_PATH) {
|
||||
res.writeHead(200, { 'content-type': 'application/json' });
|
||||
res.end(JSON.stringify({
|
||||
status: 'ok', node: process.version, openssl: process.versions.openssl,
|
||||
uptime: process.uptime(), h2Hosts: [...h2Hosts],
|
||||
telemetry: TELEMETRY_ENABLED, sessions: sessionStates.size,
|
||||
}));
|
||||
return;
|
||||
}
|
||||
proxyRequest(req, res).catch((err) => {
|
||||
log('error', 'unhandled', { error: err.message });
|
||||
if (!res.headersSent) { res.writeHead(500); res.end('internal error'); }
|
||||
});
|
||||
});
|
||||
|
||||
server.timeout = 0;
|
||||
server.keepAliveTimeout = IDLE_TIMEOUT;
|
||||
server.headersTimeout = 60000;
|
||||
server.listen(LISTEN_PORT, LISTEN_HOST, () => {
|
||||
log('info', 'node-tls-proxy started', {
|
||||
listen: `${LISTEN_HOST}:${LISTEN_PORT}`, node: process.version, openssl: process.versions.openssl,
|
||||
telemetry: TELEMETRY_ENABLED,
|
||||
});
|
||||
});
|
||||
|
||||
// 定期清理过期 session(1 小时无活动)
|
||||
setInterval(() => {
|
||||
const now = Date.now();
|
||||
for (const [id, state] of sessionStates) {
|
||||
if (now - state.startTime > 3600_000) sessionStates.delete(id);
|
||||
}
|
||||
}, 300_000);
|
||||
|
||||
let stopping = false;
|
||||
function shutdown(sig) {
|
||||
if (stopping) return; stopping = true;
|
||||
for (const s of h2Sessions.values()) try { s.close(); } catch (_) {}
|
||||
h2Sessions.clear();
|
||||
server.close(() => process.exit(0));
|
||||
setTimeout(() => process.exit(1), 5000);
|
||||
}
|
||||
process.on('SIGTERM', () => shutdown('SIGTERM'));
|
||||
process.on('SIGINT', () => shutdown('SIGINT'));
|
||||
process.on('uncaughtException', (e) => log('error', 'uncaught', { error: e.message }));
|
||||
process.on('unhandledRejection', (r) => log('error', 'rejection', { error: String(r) }));
|
||||
49
backend/cmd/lsworker/main.go
Normal file
49
backend/cmd/lsworker/main.go
Normal file
@ -0,0 +1,49 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/lspool"
|
||||
)
|
||||
|
||||
func main() {
|
||||
server, err := lspool.NewWorkerServerFromEnv()
|
||||
if err != nil {
|
||||
slog.Error("failed to initialize lsworker", "err", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer server.Close()
|
||||
|
||||
httpServer := &http.Server{
|
||||
Addr: envOrDefault("LSWORKER_LISTEN_ADDR", "0.0.0.0:18081"),
|
||||
Handler: server.Handler(),
|
||||
ReadHeaderTimeout: 10 * 1e9,
|
||||
}
|
||||
|
||||
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
|
||||
defer stop()
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
_ = httpServer.Shutdown(context.Background())
|
||||
}()
|
||||
|
||||
slog.Info("lsworker listening", "addr", httpServer.Addr)
|
||||
if err := httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
slog.Error("lsworker exited with error", "err", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func envOrDefault(key, fallback string) string {
|
||||
if value := os.Getenv(key); value != "" {
|
||||
return value
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
@ -79,6 +79,7 @@ func provideCleanup(
|
||||
soraMediaCleanup *service.SoraMediaCleanupService,
|
||||
schedulerSnapshot *service.SchedulerSnapshotService,
|
||||
tokenRefresh *service.TokenRefreshService,
|
||||
lsPoolBootstrap *service.LSPoolBootstrapService,
|
||||
accountExpiry *service.AccountExpiryService,
|
||||
subscriptionExpiry *service.SubscriptionExpiryService,
|
||||
usageCleanup *service.UsageCleanupService,
|
||||
@ -171,6 +172,12 @@ func provideCleanup(
|
||||
tokenRefresh.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"LSPoolBootstrapService", func() error {
|
||||
if lsPoolBootstrap != nil {
|
||||
lsPoolBootstrap.Stop()
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
{"AccountExpiryService", func() error {
|
||||
accountExpiry.Stop()
|
||||
return nil
|
||||
|
||||
@ -12,6 +12,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/admin"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/repository"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
@ -35,6 +36,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 应用实例级指纹覆盖(不同 sub2api 实例可设不同的默认版本号)
|
||||
fpd := configConfig.Gateway.FingerprintDefaults
|
||||
claude.ApplyFingerprintOverrides(fpd.ClaudeCLIVersion, fpd.StainlessPackageVersion, fpd.StainlessRuntimeVersion, fpd.StainlessOS, fpd.StainlessArch)
|
||||
service.ApplyDefaultFingerprintOverrides(fpd.ClaudeCLIVersion, fpd.StainlessPackageVersion, fpd.StainlessRuntimeVersion, fpd.StainlessOS, fpd.StainlessArch)
|
||||
client, err := repository.ProvideEnt(configConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -171,7 +176,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
return nil, err
|
||||
}
|
||||
billingService := service.NewBillingService(configConfig, pricingService)
|
||||
identityService := service.NewIdentityService(identityCache)
|
||||
identityService := service.NewIdentityServiceWithSalt(identityCache, configConfig.Gateway.InstanceSalt)
|
||||
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
||||
claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oauthRefreshAPI)
|
||||
digestSessionStore := service.NewDigestSessionStore()
|
||||
@ -241,10 +246,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
|
||||
soraMediaCleanupService := service.ProvideSoraMediaCleanupService(soraMediaStorage, configConfig)
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oauthRefreshAPI)
|
||||
lsPoolBootstrapService := service.ProvideLSPoolBootstrapService(accountRepository, configConfig)
|
||||
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
||||
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
||||
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
|
||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService)
|
||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, lsPoolBootstrapService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService)
|
||||
application := &Application{
|
||||
Server: httpServer,
|
||||
Cleanup: v,
|
||||
@ -282,6 +288,7 @@ func provideCleanup(
|
||||
soraMediaCleanup *service.SoraMediaCleanupService,
|
||||
schedulerSnapshot *service.SchedulerSnapshotService,
|
||||
tokenRefresh *service.TokenRefreshService,
|
||||
lsPoolBootstrap *service.LSPoolBootstrapService,
|
||||
accountExpiry *service.AccountExpiryService,
|
||||
subscriptionExpiry *service.SubscriptionExpiryService,
|
||||
usageCleanup *service.UsageCleanupService,
|
||||
@ -373,6 +380,12 @@ func provideCleanup(
|
||||
tokenRefresh.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"LSPoolBootstrapService", func() error {
|
||||
if lsPoolBootstrap != nil {
|
||||
lsPoolBootstrap.Stop()
|
||||
}
|
||||
return nil
|
||||
}},
|
||||
{"AccountExpiryService", func() error {
|
||||
accountExpiry.Stop()
|
||||
return nil
|
||||
|
||||
@ -47,6 +47,7 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
|
||||
idempotencyCleanupSvc := service.NewIdempotencyCleanupService(nil, cfg)
|
||||
schedulerSnapshotSvc := service.NewSchedulerSnapshotService(nil, nil, nil, nil, cfg)
|
||||
opsSystemLogSinkSvc := service.NewOpsSystemLogSink(nil)
|
||||
lsPoolBootstrapSvc := service.NewLSPoolBootstrapService(nil, nil, cfg)
|
||||
|
||||
cleanup := provideCleanup(
|
||||
nil, // entClient
|
||||
@ -60,6 +61,7 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
|
||||
&service.SoraMediaCleanupService{},
|
||||
schedulerSnapshotSvc,
|
||||
tokenRefreshSvc,
|
||||
lsPoolBootstrapSvc,
|
||||
accountExpirySvc,
|
||||
subscriptionExpirySvc,
|
||||
&service.UsageCleanupService{},
|
||||
|
||||
@ -379,6 +379,10 @@ type GatewayConfig struct {
|
||||
OpenAIPassthroughAllowTimeoutHeaders bool `mapstructure:"openai_passthrough_allow_timeout_headers"`
|
||||
// OpenAIWS: OpenAI Responses WebSocket 配置(默认开启,可按需回滚到 HTTP)
|
||||
OpenAIWS GatewayOpenAIWSConfig `mapstructure:"openai_ws"`
|
||||
// AntigravityLSWorker: LS worker 容器控制平面配置
|
||||
AntigravityLSWorker GatewayAntigravityLSWorkerConfig `mapstructure:"antigravity_ls_worker"`
|
||||
// NodeTLSProxy: Node.js TLS 代理配置
|
||||
NodeTLSProxy NodeTLSProxyConfig `mapstructure:"node_tls_proxy"`
|
||||
|
||||
// HTTP 上游连接池配置(性能优化:支持高并发场景调优)
|
||||
// MaxIdleConns: 所有主机的最大空闲连接总数
|
||||
@ -456,6 +460,18 @@ type GatewayConfig struct {
|
||||
// TLSFingerprint: TLS指纹伪装配置
|
||||
TLSFingerprint TLSFingerprintConfig `mapstructure:"tls_fingerprint"`
|
||||
|
||||
// InstanceSalt: 实例级隔离盐值
|
||||
// 用于 user_id 重写和 session hash 的种子混淆,
|
||||
// 不同 sub2api 实例设置不同的 salt,确保相同输入产生不同输出。
|
||||
// 为空时使用默认行为(无 salt),建议生产环境必须配置。
|
||||
// 生成方法: openssl rand -hex 32
|
||||
InstanceSalt string `mapstructure:"instance_salt"`
|
||||
|
||||
// FingerprintDefaults: 指纹默认值覆盖
|
||||
// 允许每个实例配置不同的 Claude CLI 版本号,与其他 sub2api 实例区分。
|
||||
// 为空时使用代码内置默认值。
|
||||
FingerprintDefaults FingerprintDefaultsConfig `mapstructure:"fingerprint_defaults"`
|
||||
|
||||
// UsageRecord: 使用量记录异步队列配置(有界队列 + 固定 worker)
|
||||
UsageRecord GatewayUsageRecordConfig `mapstructure:"usage_record"`
|
||||
|
||||
@ -469,6 +485,16 @@ type GatewayConfig struct {
|
||||
UserMessageQueue UserMessageQueueConfig `mapstructure:"user_message_queue"`
|
||||
}
|
||||
|
||||
type GatewayAntigravityLSWorkerConfig struct {
|
||||
Image string `mapstructure:"image"`
|
||||
Network string `mapstructure:"network"`
|
||||
DockerSocket string `mapstructure:"docker_socket"`
|
||||
IdleTTL time.Duration `mapstructure:"idle_ttl"`
|
||||
MaxActive int `mapstructure:"max_active"`
|
||||
StartupTimeout time.Duration `mapstructure:"startup_timeout"`
|
||||
RequestTimeout time.Duration `mapstructure:"request_timeout"`
|
||||
}
|
||||
|
||||
// UserMessageQueueConfig 用户消息串行队列配置
|
||||
// 用于 Anthropic OAuth/SetupToken 账号的用户消息串行化发送
|
||||
type UserMessageQueueConfig struct {
|
||||
@ -645,6 +671,23 @@ type SoraModelFiltersConfig struct {
|
||||
HidePromptEnhance bool `mapstructure:"hide_prompt_enhance"`
|
||||
}
|
||||
|
||||
// NodeTLSProxyConfig Node.js TLS 代理配置
|
||||
// 通过本地 Node.js 进程转发 HTTPS 请求,利用原生 TLS 栈产生真实 JA3 指纹
|
||||
type NodeTLSProxyConfig struct {
|
||||
// Enabled: 全局开关
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
// ListenPort: Node.js 代理监听端口
|
||||
ListenPort int `mapstructure:"listen_port"`
|
||||
// ListenHost: Node.js 代理监听地址(Docker 内用服务名,裸机用 127.0.0.1)
|
||||
ListenHost string `mapstructure:"listen_host"`
|
||||
// HealthPath: 健康检查路径
|
||||
HealthPath string `mapstructure:"health_path"`
|
||||
// UpstreamHost: 默认上游主机
|
||||
UpstreamHost string `mapstructure:"upstream_host"`
|
||||
// ProxyHosts: 允许走代理的主机白名单,为空时仅代理 api.anthropic.com
|
||||
ProxyHosts []string `mapstructure:"proxy_hosts"`
|
||||
}
|
||||
|
||||
// TLSFingerprintConfig TLS指纹伪装配置
|
||||
// 用于模拟 Claude CLI (Node.js) 的 TLS 握手特征,避免被识别为非官方客户端
|
||||
type TLSFingerprintConfig struct {
|
||||
@ -685,6 +728,23 @@ type TLSProfileConfig struct {
|
||||
Extensions []uint16 `mapstructure:"extensions"`
|
||||
}
|
||||
|
||||
// FingerprintDefaultsConfig 指纹默认值配置
|
||||
// 允许每个 sub2api 实例设置不同的默认指纹值,与其他实例区分。
|
||||
// 所有字段为空时使用代码内置默认值。
|
||||
type FingerprintDefaultsConfig struct {
|
||||
// ClaudeCLIVersion: Claude CLI 版本号(如 "2.1.81"),
|
||||
// 最终 User-Agent 为 "claude-cli/{version} (external, cli)"
|
||||
ClaudeCLIVersion string `mapstructure:"claude_cli_version"`
|
||||
// StainlessPackageVersion: @anthropic-ai/sdk 版本(如 "0.80.0")
|
||||
StainlessPackageVersion string `mapstructure:"stainless_package_version"`
|
||||
// StainlessRuntimeVersion: Node.js 版本(如 "v24.13.0")
|
||||
StainlessRuntimeVersion string `mapstructure:"stainless_runtime_version"`
|
||||
// StainlessOS: 操作系统(如 "Linux", "Darwin")
|
||||
StainlessOS string `mapstructure:"stainless_os"`
|
||||
// StainlessArch: 架构(如 "arm64", "x64")
|
||||
StainlessArch string `mapstructure:"stainless_arch"`
|
||||
}
|
||||
|
||||
// GatewaySchedulingConfig accounts scheduling configuration.
|
||||
type GatewaySchedulingConfig struct {
|
||||
// 粘性会话排队配置
|
||||
@ -1278,6 +1338,15 @@ func setDefaults() {
|
||||
|
||||
// RateLimit
|
||||
viper.SetDefault("rate_limit.overload_cooldown_minutes", 10)
|
||||
|
||||
// Gateway LS worker
|
||||
viper.SetDefault("gateway.antigravity_ls_worker.image", "weishaw/sub2api-lsworker:latest")
|
||||
viper.SetDefault("gateway.antigravity_ls_worker.network", "sub2api-network")
|
||||
viper.SetDefault("gateway.antigravity_ls_worker.docker_socket", "unix:///var/run/docker.sock")
|
||||
viper.SetDefault("gateway.antigravity_ls_worker.idle_ttl", 15*time.Minute)
|
||||
viper.SetDefault("gateway.antigravity_ls_worker.max_active", 50)
|
||||
viper.SetDefault("gateway.antigravity_ls_worker.startup_timeout", 45*time.Second)
|
||||
viper.SetDefault("gateway.antigravity_ls_worker.request_timeout", 60*time.Second)
|
||||
viper.SetDefault("rate_limit.oauth_401_cooldown_minutes", 10)
|
||||
|
||||
// Pricing - 从 model-price-repo 同步模型定价和上下文窗口数据(固定到 commit,避免分支漂移)
|
||||
@ -1463,6 +1532,21 @@ func setDefaults() {
|
||||
viper.SetDefault("gateway.user_message_queue.cleanup_interval_seconds", 60)
|
||||
|
||||
viper.SetDefault("gateway.tls_fingerprint.enabled", true)
|
||||
|
||||
// Node.js TLS Proxy 默认值
|
||||
// 注意:必须显式 BindEnv,因为 viper.Unmarshal 对嵌套 struct 的 AutomaticEnv
|
||||
// 支持有缺陷——仅 SetDefault 注册的 key 在 config.yaml 缺少对应 section 时,
|
||||
// 环境变量不会被合并到 Unmarshal 结果中。
|
||||
viper.SetDefault("gateway.node_tls_proxy.enabled", false)
|
||||
viper.SetDefault("gateway.node_tls_proxy.listen_port", 3456)
|
||||
viper.SetDefault("gateway.node_tls_proxy.listen_host", "127.0.0.1")
|
||||
viper.SetDefault("gateway.node_tls_proxy.health_path", "/__health")
|
||||
viper.SetDefault("gateway.node_tls_proxy.upstream_host", "api.anthropic.com")
|
||||
_ = viper.BindEnv("gateway.node_tls_proxy.enabled", "GATEWAY_NODE_TLS_PROXY_ENABLED")
|
||||
_ = viper.BindEnv("gateway.node_tls_proxy.listen_port", "GATEWAY_NODE_TLS_PROXY_LISTEN_PORT")
|
||||
_ = viper.BindEnv("gateway.node_tls_proxy.listen_host", "GATEWAY_NODE_TLS_PROXY_LISTEN_HOST")
|
||||
_ = viper.BindEnv("gateway.node_tls_proxy.health_path", "GATEWAY_NODE_TLS_PROXY_HEALTH_PATH")
|
||||
_ = viper.BindEnv("gateway.node_tls_proxy.upstream_host", "GATEWAY_NODE_TLS_PROXY_UPSTREAM_HOST")
|
||||
viper.SetDefault("concurrency.ping_interval", 10)
|
||||
|
||||
// Sora 直连配置
|
||||
|
||||
@ -515,7 +515,7 @@ func validateDataProxy(item DataProxy) error {
|
||||
return errors.New("proxy port is invalid")
|
||||
}
|
||||
switch item.Protocol {
|
||||
case "http", "https", "socks5", "socks5h":
|
||||
case "http", "socks5", "socks5h":
|
||||
default:
|
||||
return fmt.Errorf("proxy protocol is invalid: %s", item.Protocol)
|
||||
}
|
||||
|
||||
@ -1453,6 +1453,12 @@ func (h *OAuthHandler) GenerateSetupTokenURL(c *gin.Context) {
|
||||
req = GenerateAuthURLRequest{}
|
||||
}
|
||||
|
||||
if req.ProxyID != nil {
|
||||
slog.Info("generate_setup_token_url", "proxy_id", *req.ProxyID)
|
||||
} else {
|
||||
slog.Info("generate_setup_token_url", "proxy_id", nil)
|
||||
}
|
||||
|
||||
result, err := h.oauthService.GenerateSetupTokenURL(c.Request.Context(), req.ProxyID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
@ -1500,6 +1506,12 @@ func (h *OAuthHandler) ExchangeSetupTokenCode(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if req.ProxyID != nil {
|
||||
slog.Info("exchange_setup_token_code", "session_id", req.SessionID, "proxy_id", *req.ProxyID)
|
||||
} else {
|
||||
slog.Info("exchange_setup_token_code", "session_id", req.SessionID, "proxy_id", nil)
|
||||
}
|
||||
|
||||
tokenInfo, err := h.oauthService.ExchangeCode(c.Request.Context(), &service.ExchangeCodeInput{
|
||||
SessionID: req.SessionID,
|
||||
Code: req.Code,
|
||||
|
||||
138
backend/internal/handler/admin/debug_log_handler.go
Normal file
138
backend/internal/handler/admin/debug_log_handler.go
Normal file
@ -0,0 +1,138 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// DebugLogHandler provides admin endpoints to control gateway debug logging.
|
||||
type DebugLogHandler struct {
|
||||
debugLogger *service.GatewayDebugLogger
|
||||
}
|
||||
|
||||
func NewDebugLogHandler(debugLogger *service.GatewayDebugLogger) *DebugLogHandler {
|
||||
return &DebugLogHandler{debugLogger: debugLogger}
|
||||
}
|
||||
|
||||
// GetStatus returns whether debug logging is enabled.
|
||||
func (h *DebugLogHandler) GetStatus(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"enabled": h.debugLogger.IsEnabled(),
|
||||
})
|
||||
}
|
||||
|
||||
// Enable turns on debug logging.
|
||||
func (h *DebugLogHandler) Enable(c *gin.Context) {
|
||||
h.debugLogger.Enable()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"enabled": true,
|
||||
"message": "gateway debug logging enabled",
|
||||
})
|
||||
}
|
||||
|
||||
// Disable turns off debug logging.
|
||||
func (h *DebugLogHandler) Disable(c *gin.Context) {
|
||||
h.debugLogger.Disable()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"enabled": false,
|
||||
"message": "gateway debug logging disabled",
|
||||
})
|
||||
}
|
||||
|
||||
// ListLogs returns recent debug logs with pagination.
|
||||
func (h *DebugLogHandler) ListLogs(c *gin.Context) {
|
||||
db := h.debugLogger.DB()
|
||||
if db == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "database not available"})
|
||||
return
|
||||
}
|
||||
|
||||
limit := 50
|
||||
if v := c.Query("limit"); v != "" {
|
||||
if n, err := strconv.Atoi(v); err == nil && n > 0 && n <= 200 {
|
||||
limit = n
|
||||
}
|
||||
}
|
||||
|
||||
accountID := c.Query("account_id")
|
||||
eventType := c.Query("event_type")
|
||||
|
||||
query := `SELECT id, upstream_request_id, account_id, account_email, account_platform,
|
||||
event_type, method, full_url, request_headers, request_body, request_size,
|
||||
response_status, response_headers, response_body_preview, response_size,
|
||||
model_requested, model_upstream, is_stream, duration_ms, tls_profile,
|
||||
error_message, created_at
|
||||
FROM gateway_debug_logs WHERE 1=1`
|
||||
args := []interface{}{}
|
||||
argIdx := 1
|
||||
|
||||
if accountID != "" {
|
||||
query += " AND account_id = $" + strconv.Itoa(argIdx)
|
||||
args = append(args, accountID)
|
||||
argIdx++
|
||||
}
|
||||
if eventType != "" {
|
||||
query += " AND event_type = $" + strconv.Itoa(argIdx)
|
||||
args = append(args, eventType)
|
||||
argIdx++
|
||||
}
|
||||
|
||||
query += " ORDER BY created_at DESC LIMIT $" + strconv.Itoa(argIdx)
|
||||
args = append(args, limit)
|
||||
|
||||
rows, err := db.QueryContext(c.Request.Context(), query, args...)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
type logRow struct {
|
||||
ID int64 `json:"id"`
|
||||
UpstreamRequestID *string `json:"upstream_request_id"`
|
||||
AccountID int64 `json:"account_id"`
|
||||
AccountEmail *string `json:"account_email"`
|
||||
AccountPlatform *string `json:"account_platform"`
|
||||
EventType string `json:"event_type"`
|
||||
Method *string `json:"method"`
|
||||
FullURL *string `json:"full_url"`
|
||||
RequestHeaders *string `json:"request_headers"`
|
||||
RequestBody *string `json:"request_body"`
|
||||
RequestSize *int `json:"request_size"`
|
||||
ResponseStatus *int `json:"response_status"`
|
||||
ResponseHeaders *string `json:"response_headers"`
|
||||
ResponseBodyPreview *string `json:"response_body_preview"`
|
||||
ResponseSize *int `json:"response_size"`
|
||||
ModelRequested *string `json:"model_requested"`
|
||||
ModelUpstream *string `json:"model_upstream"`
|
||||
IsStream bool `json:"is_stream"`
|
||||
DurationMs *int `json:"duration_ms"`
|
||||
TLSProfile *string `json:"tls_profile"`
|
||||
ErrorMessage *string `json:"error_message"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
}
|
||||
|
||||
var results []logRow
|
||||
for rows.Next() {
|
||||
var r logRow
|
||||
if err := rows.Scan(
|
||||
&r.ID, &r.UpstreamRequestID, &r.AccountID, &r.AccountEmail, &r.AccountPlatform,
|
||||
&r.EventType, &r.Method, &r.FullURL, &r.RequestHeaders, &r.RequestBody, &r.RequestSize,
|
||||
&r.ResponseStatus, &r.ResponseHeaders, &r.ResponseBodyPreview, &r.ResponseSize,
|
||||
&r.ModelRequested, &r.ModelUpstream, &r.IsStream, &r.DurationMs, &r.TLSProfile,
|
||||
&r.ErrorMessage, &r.CreatedAt,
|
||||
); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
results = append(results, r)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"items": results,
|
||||
"count": len(results),
|
||||
})
|
||||
}
|
||||
@ -27,7 +27,7 @@ func NewProxyHandler(adminService service.AdminService) *ProxyHandler {
|
||||
// CreateProxyRequest represents create proxy request
|
||||
type CreateProxyRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Protocol string `json:"protocol" binding:"required,oneof=http https socks5 socks5h"`
|
||||
Protocol string `json:"protocol" binding:"required,oneof=http socks5 socks5h"`
|
||||
Host string `json:"host" binding:"required"`
|
||||
Port int `json:"port" binding:"required,min=1,max=65535"`
|
||||
Username string `json:"username"`
|
||||
@ -37,7 +37,7 @@ type CreateProxyRequest struct {
|
||||
// UpdateProxyRequest represents update proxy request
|
||||
type UpdateProxyRequest struct {
|
||||
Name string `json:"name"`
|
||||
Protocol string `json:"protocol" binding:"omitempty,oneof=http https socks5 socks5h"`
|
||||
Protocol string `json:"protocol" binding:"omitempty,oneof=http socks5 socks5h"`
|
||||
Host string `json:"host"`
|
||||
Port int `json:"port" binding:"omitempty,min=1,max=65535"`
|
||||
Username string `json:"username"`
|
||||
@ -299,7 +299,7 @@ func (h *ProxyHandler) GetProxyAccounts(c *gin.Context) {
|
||||
|
||||
// BatchCreateProxyItem represents a single proxy in batch create request
|
||||
type BatchCreateProxyItem struct {
|
||||
Protocol string `json:"protocol" binding:"required,oneof=http https socks5 socks5h"`
|
||||
Protocol string `json:"protocol" binding:"required,oneof=http socks5 socks5h"`
|
||||
Host string `json:"host" binding:"required"`
|
||||
Port int `json:"port" binding:"required,min=1,max=65535"`
|
||||
Username string `json:"username"`
|
||||
|
||||
@ -440,7 +440,7 @@ func (c *Client) GetUserInfo(ctx context.Context, accessToken string) (*UserInfo
|
||||
func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadCodeAssistResponse, map[string]any, error) {
|
||||
reqBody := LoadCodeAssistRequest{}
|
||||
reqBody.Metadata.IDEType = "ANTIGRAVITY"
|
||||
reqBody.Metadata.IDEVersion = "1.20.6"
|
||||
reqBody.Metadata.IDEVersion = "1.107.0"
|
||||
reqBody.Metadata.IDEName = "antigravity"
|
||||
|
||||
bodyBytes, err := json.Marshal(reqBody)
|
||||
|
||||
@ -9,6 +9,7 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@ -49,11 +50,11 @@ const (
|
||||
antigravityDailyBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com"
|
||||
)
|
||||
|
||||
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.20.5
|
||||
var defaultUserAgentVersion = "1.20.5"
|
||||
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.107.0
|
||||
var defaultUserAgentVersion = "1.107.0"
|
||||
|
||||
// defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置
|
||||
var defaultClientSecret = "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"
|
||||
// defaultClientSecret 必须通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置
|
||||
var defaultClientSecret string
|
||||
|
||||
func init() {
|
||||
// 从环境变量读取版本号,未设置则使用默认值
|
||||
@ -66,12 +67,16 @@ func init() {
|
||||
}
|
||||
}
|
||||
|
||||
// GetUserAgent 返回当前配置的 User-Agent
|
||||
// GetUserAgent 返回当前配置的 User-Agent(自动检测平台,匹配真实 IDE 行为)
|
||||
func GetUserAgent() string {
|
||||
return fmt.Sprintf("antigravity/%s windows/amd64", defaultUserAgentVersion)
|
||||
return fmt.Sprintf("antigravity/%s %s/%s", defaultUserAgentVersion, runtime.GOOS, runtime.GOARCH)
|
||||
}
|
||||
|
||||
func getClientSecret() (string, error) {
|
||||
if secret := strings.TrimSpace(os.Getenv(AntigravityOAuthClientSecretEnv)); secret != "" {
|
||||
defaultClientSecret = secret
|
||||
return secret, nil
|
||||
}
|
||||
if v := strings.TrimSpace(defaultClientSecret); v != "" {
|
||||
return v, nil
|
||||
}
|
||||
|
||||
19
backend/internal/pkg/antigravity/oauth_runtime_env_test.go
Normal file
19
backend/internal/pkg/antigravity/oauth_runtime_env_test.go
Normal file
@ -0,0 +1,19 @@
|
||||
package antigravity
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestGetClientSecret_ReadsRuntimeEnvironment(t *testing.T) {
|
||||
old := defaultClientSecret
|
||||
defaultClientSecret = ""
|
||||
t.Cleanup(func() { defaultClientSecret = old })
|
||||
|
||||
t.Setenv(AntigravityOAuthClientSecretEnv, "runtime-secret")
|
||||
|
||||
secret, err := getClientSecret()
|
||||
if err != nil {
|
||||
t.Fatalf("getClientSecret returned error: %v", err)
|
||||
}
|
||||
if secret != "runtime-secret" {
|
||||
t.Fatalf("unexpected secret: got %q want %q", secret, "runtime-secret")
|
||||
}
|
||||
}
|
||||
@ -39,6 +39,34 @@ func generateStableSessionID(contents []GeminiContent) string {
|
||||
return "-" + strconv.FormatInt(n, 10)
|
||||
}
|
||||
|
||||
// EnsureGeminiRequestSessionID fills request.sessionId when the caller omitted it.
|
||||
// preferredSessionID wins; otherwise we derive a stable value from the first user turn.
|
||||
func EnsureGeminiRequestSessionID(body []byte, preferredSessionID string) ([]byte, error) {
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if raw, ok := payload["sessionId"].(string); ok && strings.TrimSpace(raw) != "" {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
sessionID := strings.TrimSpace(preferredSessionID)
|
||||
if sessionID == "" {
|
||||
var req GeminiRequest
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sessionID = generateStableSessionID(req.Contents)
|
||||
}
|
||||
if sessionID == "" {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
payload["sessionId"] = sessionID
|
||||
return json.Marshal(payload)
|
||||
}
|
||||
|
||||
type TransformOptions struct {
|
||||
EnableIdentityPatch bool
|
||||
// IdentityPatch 可选:自定义注入到 systemInstruction 开头的身份防护提示词;
|
||||
|
||||
@ -8,6 +8,43 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestEnsureGeminiRequestSessionID(t *testing.T) {
|
||||
t.Run("prefers provided session id", func(t *testing.T) {
|
||||
body := []byte(`{"contents":[{"role":"user","parts":[{"text":"hello"}]}]}`)
|
||||
updated, err := EnsureGeminiRequestSessionID(body, "session-from-header")
|
||||
require.NoError(t, err)
|
||||
|
||||
var payload map[string]any
|
||||
require.NoError(t, json.Unmarshal(updated, &payload))
|
||||
require.Equal(t, "session-from-header", payload["sessionId"])
|
||||
})
|
||||
|
||||
t.Run("keeps existing session id", func(t *testing.T) {
|
||||
body := []byte(`{"sessionId":"session-in-body","contents":[{"role":"user","parts":[{"text":"hello"}]}]}`)
|
||||
updated, err := EnsureGeminiRequestSessionID(body, "session-from-header")
|
||||
require.NoError(t, err)
|
||||
|
||||
var payload map[string]any
|
||||
require.NoError(t, json.Unmarshal(updated, &payload))
|
||||
require.Equal(t, "session-in-body", payload["sessionId"])
|
||||
})
|
||||
|
||||
t.Run("derives stable fallback from contents", func(t *testing.T) {
|
||||
body := []byte(`{"contents":[{"role":"user","parts":[{"text":"hello"}]}]}`)
|
||||
first, err := EnsureGeminiRequestSessionID(body, "")
|
||||
require.NoError(t, err)
|
||||
second, err := EnsureGeminiRequestSessionID(body, "")
|
||||
require.NoError(t, err)
|
||||
|
||||
var firstPayload map[string]any
|
||||
var secondPayload map[string]any
|
||||
require.NoError(t, json.Unmarshal(first, &firstPayload))
|
||||
require.NoError(t, json.Unmarshal(second, &secondPayload))
|
||||
require.NotEmpty(t, firstPayload["sessionId"])
|
||||
require.Equal(t, firstPayload["sessionId"], secondPayload["sessionId"])
|
||||
})
|
||||
}
|
||||
|
||||
// TestBuildParts_ThinkingBlockWithoutSignature 测试thinking block无signature时的处理
|
||||
func TestBuildParts_ThinkingBlockWithoutSignature(t *testing.T) {
|
||||
tests := []struct {
|
||||
|
||||
@ -3,15 +3,27 @@ package claude
|
||||
|
||||
// Claude Code 客户端相关常量
|
||||
|
||||
// DefaultCLIVersion 是当前模拟的 Claude CLI 版本
|
||||
const DefaultCLIVersion = "2.1.88"
|
||||
|
||||
// Beta header 常量
|
||||
const (
|
||||
BetaOAuth = "oauth-2025-04-20"
|
||||
BetaClaudeCode = "claude-code-20250219"
|
||||
BetaInterleavedThinking = "interleaved-thinking-2025-05-14"
|
||||
BetaFineGrainedToolStreaming = "fine-grained-tool-streaming-2025-05-14"
|
||||
BetaFineGrainedToolStreaming = "fine-grained-tool-streaming-2025-05-14"
|
||||
BetaTokenCounting = "token-counting-2024-11-01"
|
||||
BetaContext1M = "context-1m-2025-08-07"
|
||||
BetaFastMode = "fast-mode-2026-02-01"
|
||||
BetaRedactThinking = "redact-thinking-2026-02-12"
|
||||
BetaContextManagement = "context-management-2025-06-27"
|
||||
BetaPromptCachingScope = "prompt-caching-scope-2026-01-05"
|
||||
BetaEffort = "effort-2025-11-24"
|
||||
BetaTaskBudgets = "task-budgets-2026-03-13"
|
||||
BetaTokenEfficientTools = "token-efficient-tools-2026-03-28"
|
||||
BetaStructuredOutputs = "structured-outputs-2025-12-15"
|
||||
BetaAdvisor = "advisor-tool-2026-03-01"
|
||||
BetaWebSearch = "web-search-2025-03-05"
|
||||
)
|
||||
|
||||
// DroppedBetas 是转发时需要从 anthropic-beta header 中移除的 beta token 列表。
|
||||
@ -19,7 +31,7 @@ const (
|
||||
var DroppedBetas = []string{}
|
||||
|
||||
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
|
||||
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
|
||||
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming + "," + BetaContext1M + "," + BetaRedactThinking + "," + BetaContextManagement + "," + BetaPromptCachingScope + "," + BetaEffort
|
||||
|
||||
// MessageBetaHeaderNoTools /v1/messages 在无工具时的 beta header
|
||||
//
|
||||
@ -27,40 +39,64 @@ const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleav
|
||||
// Claude Code for non-Claude-Code clients, we must include the claude-code beta
|
||||
// even if the request doesn't use tools, otherwise upstream may reject the
|
||||
// request as a non-Claude-Code API request.
|
||||
const MessageBetaHeaderNoTools = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking
|
||||
const MessageBetaHeaderNoTools = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaContext1M + "," + BetaRedactThinking + "," + BetaContextManagement + "," + BetaPromptCachingScope + "," + BetaEffort
|
||||
|
||||
// MessageBetaHeaderWithTools /v1/messages 在有工具时的 beta header
|
||||
const MessageBetaHeaderWithTools = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking
|
||||
const MessageBetaHeaderWithTools = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaContext1M + "," + BetaRedactThinking + "," + BetaContextManagement + "," + BetaPromptCachingScope + "," + BetaEffort
|
||||
|
||||
// CountTokensBetaHeader count_tokens 请求使用的 anthropic-beta header
|
||||
const CountTokensBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaTokenCounting
|
||||
const CountTokensBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaTokenCounting + "," + BetaContextManagement
|
||||
|
||||
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header(不需要 claude-code beta)
|
||||
const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking
|
||||
const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking + "," + BetaEffort
|
||||
|
||||
// APIKeyBetaHeader API-key 账号建议使用的 anthropic-beta header(不包含 oauth)
|
||||
const APIKeyBetaHeader = BetaClaudeCode + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
|
||||
const APIKeyBetaHeader = BetaClaudeCode + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming + "," + BetaContext1M + "," + BetaEffort + "," + BetaPromptCachingScope
|
||||
|
||||
// APIKeyHaikuBetaHeader Haiku 模型在 API-key 账号下使用的 anthropic-beta header(不包含 oauth / claude-code)
|
||||
const APIKeyHaikuBetaHeader = BetaInterleavedThinking
|
||||
const APIKeyHaikuBetaHeader = BetaInterleavedThinking + "," + BetaEffort
|
||||
|
||||
// DefaultHeaders 是 Claude Code 客户端默认请求头。
|
||||
var DefaultHeaders = map[string]string{
|
||||
// Keep these in sync with recent Claude CLI traffic to reduce the chance
|
||||
// that Claude Code-scoped OAuth credentials are rejected as "non-CLI" usage.
|
||||
"User-Agent": "claude-cli/2.1.22 (external, cli)",
|
||||
"User-Agent": "claude-cli/" + DefaultCLIVersion + " (external, cli)",
|
||||
"X-Stainless-Lang": "js",
|
||||
"X-Stainless-Package-Version": "0.70.0",
|
||||
"X-Stainless-OS": "Linux",
|
||||
"X-Stainless-Package-Version": "0.74.0",
|
||||
"X-Stainless-OS": "MacOS",
|
||||
"X-Stainless-Arch": "arm64",
|
||||
"X-Stainless-Runtime": "node",
|
||||
"X-Stainless-Runtime-Version": "v24.13.0",
|
||||
"X-Stainless-Runtime-Version": "v24.3.0",
|
||||
"X-Stainless-Retry-Count": "0",
|
||||
"X-Stainless-Timeout": "600",
|
||||
"X-App": "cli",
|
||||
"Anthropic-Dangerous-Direct-Browser-Access": "true",
|
||||
}
|
||||
|
||||
// ApplyFingerprintOverrides 用配置覆盖默认指纹值(每个实例可设不同值)
|
||||
// cliVersion: Claude CLI 版本(如 "2.1.81")
|
||||
// pkgVersion: SDK 版本(如 "0.80.0")
|
||||
// runtimeVersion: Node.js 版本(如 "v24.13.0")
|
||||
// os_: 操作系统(如 "Linux")
|
||||
// arch: 架构(如 "arm64")
|
||||
func ApplyFingerprintOverrides(cliVersion, pkgVersion, runtimeVersion, os_, arch string) {
|
||||
if cliVersion != "" {
|
||||
DefaultHeaders["User-Agent"] = "claude-cli/" + cliVersion + " (external, cli)"
|
||||
}
|
||||
if pkgVersion != "" {
|
||||
DefaultHeaders["X-Stainless-Package-Version"] = pkgVersion
|
||||
}
|
||||
if runtimeVersion != "" {
|
||||
DefaultHeaders["X-Stainless-Runtime-Version"] = runtimeVersion
|
||||
}
|
||||
if os_ != "" {
|
||||
DefaultHeaders["X-Stainless-OS"] = os_
|
||||
}
|
||||
if arch != "" {
|
||||
DefaultHeaders["X-Stainless-Arch"] = arch
|
||||
}
|
||||
}
|
||||
|
||||
// Model 表示一个 Claude 模型
|
||||
type Model struct {
|
||||
ID string `json:"id"`
|
||||
|
||||
@ -35,13 +35,11 @@ const (
|
||||
// GeminiCLIRedirectURI is the redirect URI used by Gemini CLI for Code Assist OAuth.
|
||||
GeminiCLIRedirectURI = "https://codeassist.google.com/authcode"
|
||||
|
||||
// GeminiCLIOAuthClientID/Secret are the public OAuth client credentials used by Google Gemini CLI.
|
||||
// They enable the "login without creating your own OAuth client" experience, but Google may
|
||||
// restrict which scopes are allowed for this client.
|
||||
GeminiCLIOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
|
||||
GeminiCLIOAuthClientSecret = "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"
|
||||
// GeminiCLIOAuthClientID is the public OAuth client ID used by Google Gemini CLI.
|
||||
GeminiCLIOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
|
||||
|
||||
// GeminiCLIOAuthClientSecretEnv is the environment variable name for the built-in client secret.
|
||||
// The secret MUST be provided via this env var — no hardcoded fallback.
|
||||
GeminiCLIOAuthClientSecretEnv = "GEMINI_CLI_OAUTH_CLIENT_SECRET"
|
||||
|
||||
SessionTTL = 30 * time.Minute
|
||||
|
||||
@ -170,11 +170,9 @@ func EffectiveOAuthConfig(cfg OAuthConfig, oauthType string) (OAuthConfig, error
|
||||
// Fall back to built-in Gemini CLI OAuth client when not configured.
|
||||
// SECURITY: This repo does not embed the built-in client secret; it must be provided via env.
|
||||
if effective.ClientID == "" && effective.ClientSecret == "" {
|
||||
secret := strings.TrimSpace(GeminiCLIOAuthClientSecret)
|
||||
if secret == "" {
|
||||
if v, ok := os.LookupEnv(GeminiCLIOAuthClientSecretEnv); ok {
|
||||
secret = strings.TrimSpace(v)
|
||||
}
|
||||
var secret string
|
||||
if v, ok := os.LookupEnv(GeminiCLIOAuthClientSecretEnv); ok {
|
||||
secret = strings.TrimSpace(v)
|
||||
}
|
||||
if secret == "" {
|
||||
return OAuthConfig{}, infraerrors.Newf(http.StatusBadRequest, "GEMINI_CLI_OAUTH_CLIENT_SECRET_MISSING", "built-in Gemini CLI OAuth client_secret is not configured; set %s or provide a custom OAuth client", GeminiCLIOAuthClientSecretEnv)
|
||||
|
||||
@ -408,10 +408,10 @@ func TestBuildAuthorizationURL_WithProjectID(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAuthorizationURL_UsesBuiltinSecretFallback(t *testing.T) {
|
||||
func TestBuildAuthorizationURL_RequiresBuiltinSecretEnv(t *testing.T) {
|
||||
t.Setenv(GeminiCLIOAuthClientSecretEnv, "")
|
||||
|
||||
authURL, err := BuildAuthorizationURL(
|
||||
_, err := BuildAuthorizationURL(
|
||||
OAuthConfig{},
|
||||
"test-state",
|
||||
"test-challenge",
|
||||
@ -419,11 +419,11 @@ func TestBuildAuthorizationURL_UsesBuiltinSecretFallback(t *testing.T) {
|
||||
"",
|
||||
"code_assist",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("BuildAuthorizationURL() 不应报错: %v", err)
|
||||
if err == nil {
|
||||
t.Fatal("BuildAuthorizationURL() 应在未配置内置 secret 环境变量时报错")
|
||||
}
|
||||
if !strings.Contains(authURL, "client_id="+GeminiCLIOAuthClientID) {
|
||||
t.Errorf("应使用内置 Gemini CLI client_id,实际 URL: %s", authURL)
|
||||
if !strings.Contains(err.Error(), GeminiCLIOAuthClientSecretEnv) {
|
||||
t.Fatalf("错误消息应提示缺少 %s: %v", GeminiCLIOAuthClientSecretEnv, err)
|
||||
}
|
||||
}
|
||||
|
||||
@ -686,18 +686,15 @@ func TestEffectiveOAuthConfig_WhitespaceTriming(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestEffectiveOAuthConfig_NoEnvSecret(t *testing.T) {
|
||||
func TestEffectiveOAuthConfig_RequiresEnvSecret(t *testing.T) {
|
||||
t.Setenv(GeminiCLIOAuthClientSecretEnv, "")
|
||||
|
||||
cfg, err := EffectiveOAuthConfig(OAuthConfig{}, "code_assist")
|
||||
if err != nil {
|
||||
t.Fatalf("不设置环境变量时应回退到内置 secret,实际报错: %v", err)
|
||||
_, err := EffectiveOAuthConfig(OAuthConfig{}, "code_assist")
|
||||
if err == nil {
|
||||
t.Fatal("未配置环境变量时应报错,而不是回退到仓库内置 secret")
|
||||
}
|
||||
if strings.TrimSpace(cfg.ClientSecret) == "" {
|
||||
t.Error("ClientSecret 不应为空")
|
||||
}
|
||||
if cfg.ClientID != GeminiCLIOAuthClientID {
|
||||
t.Errorf("ClientID 应回退为内置客户端 ID,实际: %q", cfg.ClientID)
|
||||
if !strings.Contains(err.Error(), GeminiCLIOAuthClientSecretEnv) {
|
||||
t.Fatalf("错误消息应提示缺少 %s: %v", GeminiCLIOAuthClientSecretEnv, err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
13
backend/internal/pkg/lspool/backend.go
Normal file
13
backend/internal/pkg/lspool/backend.go
Normal file
@ -0,0 +1,13 @@
|
||||
package lspool
|
||||
|
||||
import "time"
|
||||
|
||||
// Backend is the control-plane abstraction used by the HTTP upstream wrapper.
|
||||
// It may be backed by a local in-process Pool or by remote LS workers.
|
||||
type Backend interface {
|
||||
GetOrCreate(accountID, routingKey string, proxyURL ...string) (*Instance, error)
|
||||
SetAccountToken(accountID, accessToken, refreshToken string, expiresAt time.Time)
|
||||
SetAccountModelCredits(accountID string, useAICredits bool, availableCredits, minimumCreditAmountForUsage *int32)
|
||||
Stats() map[string]any
|
||||
Close()
|
||||
}
|
||||
94
backend/internal/pkg/lspool/global.go
Normal file
94
backend/internal/pkg/lspool/global.go
Normal file
@ -0,0 +1,94 @@
|
||||
// Package lspool provides LS-mode integration for the antigravity gateway.
|
||||
//
|
||||
// When LS mode is enabled (via ANTIGRAVITY_LS_MODE=true), requests to
|
||||
// streamGenerateContent are routed through a real Language Server instance
|
||||
// instead of directly to cloudcode-pa. This provides:
|
||||
//
|
||||
// - Authentic TLS fingerprint (Google's own Go binary)
|
||||
// - Real session management and Heartbeat
|
||||
// - Indistinguishable from a real IDE instance
|
||||
//
|
||||
// To enable: set environment variable ANTIGRAVITY_LS_MODE=true
|
||||
// To configure: set ANTIGRAVITY_APP_ROOT to the AntiGravity.app path
|
||||
package lspool
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
)
|
||||
|
||||
var (
|
||||
globalBackend Backend
|
||||
globalPoolOnce sync.Once
|
||||
lsModeEnabled bool
|
||||
)
|
||||
|
||||
func init() {
|
||||
lsModeEnabled = os.Getenv("ANTIGRAVITY_LS_MODE") == "true"
|
||||
}
|
||||
|
||||
// IsLSModeEnabled returns whether LS mode is active
|
||||
func IsLSModeEnabled() bool {
|
||||
return lsModeEnabled
|
||||
}
|
||||
|
||||
const (
|
||||
LSStrategyDirect = "direct"
|
||||
LSStrategyJSParity = "js-parity"
|
||||
)
|
||||
|
||||
// CurrentLSStrategy returns the active LS routing strategy.
|
||||
// Unknown values are treated as "direct" for safety.
|
||||
func CurrentLSStrategy() string {
|
||||
switch strings.ToLower(strings.TrimSpace(os.Getenv("ANTIGRAVITY_LS_STRATEGY"))) {
|
||||
case "", LSStrategyDirect:
|
||||
return LSStrategyDirect
|
||||
case LSStrategyJSParity:
|
||||
return LSStrategyJSParity
|
||||
default:
|
||||
return LSStrategyDirect
|
||||
}
|
||||
}
|
||||
|
||||
// GlobalPool returns the singleton LS pool instance
|
||||
// Creates it on first call if LS mode is enabled
|
||||
func GlobalPool(cfg *config.Config) Backend {
|
||||
if !lsModeEnabled {
|
||||
return nil
|
||||
}
|
||||
globalPoolOnce.Do(func() {
|
||||
manager, err := NewWorkerManagerFromConfig(cfg)
|
||||
if err != nil {
|
||||
slog.Default().Error("failed to initialize LS worker manager", "err", err)
|
||||
return
|
||||
}
|
||||
globalBackend = manager
|
||||
})
|
||||
return globalBackend
|
||||
}
|
||||
|
||||
// Shutdown closes the global pool
|
||||
func Shutdown() {
|
||||
if globalBackend != nil {
|
||||
globalBackend.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// StatusInfo returns the current LS pool status for diagnostics
|
||||
func StatusInfo() map[string]any {
|
||||
info := map[string]any{
|
||||
"ls_mode_enabled": lsModeEnabled,
|
||||
"build": "enhanced",
|
||||
"user_agent": "antigravity/1.107.0",
|
||||
}
|
||||
if lsModeEnabled && globalBackend != nil {
|
||||
stats := globalBackend.Stats()
|
||||
info["pool_total"] = stats["total"]
|
||||
info["pool_active"] = stats["active"]
|
||||
}
|
||||
return info
|
||||
}
|
||||
864
backend/internal/pkg/lspool/integration_test.go
Normal file
864
backend/internal/pkg/lspool/integration_test.go
Normal file
@ -0,0 +1,864 @@
|
||||
package lspool
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func readConnectFrame(r io.Reader) ([]byte, error) {
|
||||
header := make([]byte, 5)
|
||||
if _, err := io.ReadFull(r, header); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payloadLen := binary.BigEndian.Uint32(header[1:5])
|
||||
payload := make([]byte, payloadLen)
|
||||
if _, err := io.ReadFull(r, payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
func decodeProtoBytesField(data []byte, targetField int) []byte {
|
||||
i := 0
|
||||
for i < len(data) {
|
||||
tag, n := binary.Uvarint(data[i:])
|
||||
if n <= 0 {
|
||||
return nil
|
||||
}
|
||||
i += n
|
||||
fieldNum := int(tag >> 3)
|
||||
wireType := tag & 0x7
|
||||
switch wireType {
|
||||
case 0:
|
||||
_, n = binary.Uvarint(data[i:])
|
||||
if n <= 0 {
|
||||
return nil
|
||||
}
|
||||
i += n
|
||||
case 2:
|
||||
length, n := binary.Uvarint(data[i:])
|
||||
if n <= 0 {
|
||||
return nil
|
||||
}
|
||||
i += n
|
||||
if i+int(length) > len(data) {
|
||||
return nil
|
||||
}
|
||||
if fieldNum == targetField {
|
||||
return data[i : i+int(length)]
|
||||
}
|
||||
i += int(length)
|
||||
case 1:
|
||||
i += 8
|
||||
case 5:
|
||||
i += 4
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeProtoBytesFields(data []byte, targetField int) [][]byte {
|
||||
var values [][]byte
|
||||
i := 0
|
||||
for i < len(data) {
|
||||
tag, n := binary.Uvarint(data[i:])
|
||||
if n <= 0 {
|
||||
return values
|
||||
}
|
||||
i += n
|
||||
fieldNum := int(tag >> 3)
|
||||
wireType := tag & 0x7
|
||||
switch wireType {
|
||||
case 0:
|
||||
_, n = binary.Uvarint(data[i:])
|
||||
if n <= 0 {
|
||||
return values
|
||||
}
|
||||
i += n
|
||||
case 2:
|
||||
length, n := binary.Uvarint(data[i:])
|
||||
if n <= 0 {
|
||||
return values
|
||||
}
|
||||
i += n
|
||||
if i+int(length) > len(data) {
|
||||
return values
|
||||
}
|
||||
if fieldNum == targetField {
|
||||
values = append(values, append([]byte(nil), data[i:i+int(length)]...))
|
||||
}
|
||||
i += int(length)
|
||||
case 1:
|
||||
i += 8
|
||||
case 5:
|
||||
i += 4
|
||||
default:
|
||||
return values
|
||||
}
|
||||
}
|
||||
return values
|
||||
}
|
||||
|
||||
func decodeTopicRows(topic []byte) map[string]string {
|
||||
rows := make(map[string]string)
|
||||
for _, entry := range decodeProtoBytesFields(topic, 1) {
|
||||
key := decodeProtoString(entry, 1)
|
||||
row := decodeProtoBytesField(entry, 2)
|
||||
rows[key] = decodeProtoString(row, 1)
|
||||
}
|
||||
return rows
|
||||
}
|
||||
|
||||
func requireBase64PrimitiveValue(t *testing.T, got string, want []byte) {
|
||||
t.Helper()
|
||||
decoded, err := base64.StdEncoding.DecodeString(got)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, want, decoded)
|
||||
}
|
||||
|
||||
// TestMockExtensionServerTokenInjection verifies the token injection flow:
|
||||
// Extension → MockExtensionServer → LS subscribes uss-oauth → gets OAuthTokenInfo
|
||||
func TestMockExtensionServerTokenInjection(t *testing.T) {
|
||||
csrf := "test-csrf-token"
|
||||
srv, err := NewMockExtensionServer(csrf)
|
||||
require.NoError(t, err)
|
||||
defer srv.Close()
|
||||
|
||||
// 1. Set token for an account
|
||||
srv.SetToken("account-1", &TokenInfo{
|
||||
AccessToken: "ya29.test-access-token",
|
||||
RefreshToken: "1//test-refresh-token",
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
})
|
||||
|
||||
// 2. Verify token is stored
|
||||
srv.mu.RLock()
|
||||
info, ok := srv.tokens["account-1"]
|
||||
srv.mu.RUnlock()
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "ya29.test-access-token", info.AccessToken)
|
||||
require.Equal(t, "1//test-refresh-token", info.RefreshToken)
|
||||
require.False(t, info.ExpiresAt.IsZero())
|
||||
|
||||
// 3. Simulate LS subscribing to uss-oauth (HTTP request to mock server)
|
||||
req, _ := http.NewRequest("POST",
|
||||
fmt.Sprintf("http://127.0.0.1:%d/exa.extension_server_pb.ExtensionServerService/SubscribeToUnifiedStateSyncTopic", srv.Port()),
|
||||
bytes.NewReader(frameConnectMessage(encodeProtoString(1, "uss-oauth"))))
|
||||
req.Header.Set("x-codeium-csrf-token", csrf)
|
||||
req.Header.Set("Content-Type", "application/connect+proto")
|
||||
|
||||
// The stream handler will block, so run in background and cancel after we confirm connection
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err == nil {
|
||||
defer resp.Body.Close()
|
||||
require.Equal(t, 200, resp.StatusCode)
|
||||
require.Equal(t, "application/connect+proto", resp.Header.Get("Content-Type"))
|
||||
|
||||
// Read the first envelope frame (initial state)
|
||||
header := make([]byte, 5)
|
||||
n, readErr := resp.Body.Read(header)
|
||||
if readErr == nil && n == 5 {
|
||||
require.Equal(t, byte(0x00), header[0], "first byte should be 0x00 (data frame)")
|
||||
t.Logf("Received initial state frame: flags=%d, payload_len=%d", header[0], header[1:5])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestMockExtensionServerCSRF verifies CSRF token validation
|
||||
func TestMockExtensionServerCSRF(t *testing.T) {
|
||||
csrf := "correct-csrf"
|
||||
srv, err := NewMockExtensionServer(csrf)
|
||||
require.NoError(t, err)
|
||||
defer srv.Close()
|
||||
|
||||
base := fmt.Sprintf("http://127.0.0.1:%d/exa.extension_server_pb.ExtensionServerService/Heartbeat", srv.Port())
|
||||
|
||||
// Wrong CSRF → 403
|
||||
req, _ := http.NewRequest("POST", base, nil)
|
||||
req.Header.Set("x-codeium-csrf-token", "wrong-csrf")
|
||||
req.Header.Set("Content-Type", "application/proto")
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
require.Equal(t, 403, resp.StatusCode)
|
||||
|
||||
// Correct CSRF → 200
|
||||
req2, _ := http.NewRequest("POST", base, nil)
|
||||
req2.Header.Set("x-codeium-csrf-token", csrf)
|
||||
req2.Header.Set("Content-Type", "application/proto")
|
||||
resp2, err := http.DefaultClient.Do(req2)
|
||||
require.NoError(t, err)
|
||||
defer resp2.Body.Close()
|
||||
require.Equal(t, 200, resp2.StatusCode)
|
||||
}
|
||||
|
||||
// TestMockExtensionServerGetSecretValue verifies the fallback token path
|
||||
func TestMockExtensionServerGetSecretValue(t *testing.T) {
|
||||
csrf := "test-csrf"
|
||||
srv, err := NewMockExtensionServer(csrf)
|
||||
require.NoError(t, err)
|
||||
defer srv.Close()
|
||||
|
||||
srv.SetToken("acc", &TokenInfo{AccessToken: "ya29.secret-token"})
|
||||
|
||||
// GetSecretValue should return the token
|
||||
req, _ := http.NewRequest("POST",
|
||||
fmt.Sprintf("http://127.0.0.1:%d/exa.extension_server_pb.ExtensionServerService/GetSecretValue", srv.Port()),
|
||||
nil)
|
||||
req.Header.Set("x-codeium-csrf-token", csrf)
|
||||
req.Header.Set("Content-Type", "application/proto")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
require.Equal(t, 200, resp.StatusCode)
|
||||
}
|
||||
|
||||
// TestOAuthTokenInfoProto verifies the proto encoding matches real IDE format
|
||||
func TestOAuthTokenInfoProto(t *testing.T) {
|
||||
expiry := time.Date(2026, 3, 29, 19, 0, 0, 0, time.UTC)
|
||||
bin := buildOAuthTokenInfoBinary("ya29.test", "1//refresh", expiry)
|
||||
|
||||
// Verify fields are present by checking proto wire format
|
||||
require.True(t, len(bin) > 0, "proto should not be empty")
|
||||
|
||||
// Field 1 (access_token): tag=0x0a, value="ya29.test"
|
||||
require.Contains(t, string(bin), "ya29.test")
|
||||
// Field 2 (token_type): tag=0x12, value="Bearer"
|
||||
require.Contains(t, string(bin), "Bearer")
|
||||
// Field 3 (refresh_token): tag=0x1a, value="1//refresh"
|
||||
require.Contains(t, string(bin), "1//refresh")
|
||||
|
||||
// Without refresh_token
|
||||
binNoRefresh := buildOAuthTokenInfoBinary("ya29.test", "", expiry)
|
||||
require.NotContains(t, string(binNoRefresh), "1//refresh")
|
||||
}
|
||||
|
||||
// TestOAuthTokenInfoWithRealExpiry verifies expiry uses real time, not hardcoded
|
||||
func TestOAuthTokenInfoWithRealExpiry(t *testing.T) {
|
||||
future := time.Now().Add(2 * time.Hour)
|
||||
bin := buildOAuthTokenInfoBinary("token", "refresh", future)
|
||||
|
||||
// Zero expiry should default to ~1h
|
||||
binZero := buildOAuthTokenInfoBinary("token", "refresh", time.Time{})
|
||||
|
||||
// They should be different lengths or content (different expiry timestamps)
|
||||
// Both should be valid (non-empty)
|
||||
require.True(t, len(bin) > 0)
|
||||
require.True(t, len(binZero) > 0)
|
||||
}
|
||||
|
||||
// TestUSSTopicWithOAuth verifies the full USS topic proto structure
|
||||
func TestUSSTopicWithOAuth(t *testing.T) {
|
||||
expiry := time.Now().Add(1 * time.Hour)
|
||||
topic := buildUSSTopicWithOAuth("ya29.access", "1//refresh", expiry)
|
||||
|
||||
require.True(t, len(topic) > 0)
|
||||
// The topic should contain the sentinel key
|
||||
require.Contains(t, string(topic), "oauthTokenInfoSentinelKey")
|
||||
}
|
||||
|
||||
func TestUSSTopicWithModelCredits(t *testing.T) {
|
||||
available := int32(123)
|
||||
minimum := int32(50)
|
||||
topic := buildUSSTopicWithModelCredits(&ModelCreditsInfo{
|
||||
UseAICredits: true,
|
||||
AvailableCredits: &available,
|
||||
MinimumCreditAmountForUsage: &minimum,
|
||||
})
|
||||
|
||||
require.True(t, len(topic) > 0)
|
||||
require.Contains(t, string(topic), useAICreditsSentinelKey)
|
||||
require.Contains(t, string(topic), availableCreditsSentinelKey)
|
||||
require.Contains(t, string(topic), minimumCreditAmountForUsageKey)
|
||||
|
||||
rows := decodeTopicRows(topic)
|
||||
requireBase64PrimitiveValue(t, rows[useAICreditsSentinelKey], buildPrimitiveBoolBinary(true))
|
||||
requireBase64PrimitiveValue(t, rows[availableCreditsSentinelKey], buildPrimitiveInt32Binary(available))
|
||||
requireBase64PrimitiveValue(t, rows[minimumCreditAmountForUsageKey], buildPrimitiveInt32Binary(minimum))
|
||||
}
|
||||
|
||||
func TestMockExtensionServerModelCreditsDynamicUpdate(t *testing.T) {
|
||||
csrf := "test-csrf-token"
|
||||
srv, err := NewMockExtensionServer(csrf)
|
||||
require.NoError(t, err)
|
||||
defer srv.Close()
|
||||
|
||||
srv.SetModelCredits("account-1", &ModelCreditsInfo{})
|
||||
|
||||
req, _ := http.NewRequest("POST",
|
||||
fmt.Sprintf("http://127.0.0.1:%d/exa.extension_server_pb.ExtensionServerService/SubscribeToUnifiedStateSyncTopic", srv.Port()),
|
||||
bytes.NewReader(frameConnectMessage(encodeProtoString(1, "uss-modelCredits"))))
|
||||
req.Header.Set("x-codeium-csrf-token", csrf)
|
||||
req.Header.Set("Content-Type", "application/connect+proto")
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
require.Equal(t, 200, resp.StatusCode)
|
||||
|
||||
// Drain the initial_state frame first.
|
||||
_, err = readConnectFrame(resp.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
available := int32(77)
|
||||
minimum := int32(25)
|
||||
srv.SetModelCredits("account-1", &ModelCreditsInfo{
|
||||
UseAICredits: true,
|
||||
AvailableCredits: &available,
|
||||
MinimumCreditAmountForUsage: &minimum,
|
||||
})
|
||||
|
||||
values := make(map[string]string, 3)
|
||||
for len(values) < 3 {
|
||||
frame, readErr := readConnectFrame(resp.Body)
|
||||
require.NoError(t, readErr)
|
||||
applied := decodeProtoBytesField(frame, 2)
|
||||
require.NotEmpty(t, applied)
|
||||
key := decodeProtoString(applied, 1)
|
||||
row := decodeProtoBytesField(applied, 2)
|
||||
values[key] = decodeProtoString(row, 1)
|
||||
}
|
||||
|
||||
require.Contains(t, values, useAICreditsSentinelKey)
|
||||
require.Contains(t, values, availableCreditsSentinelKey)
|
||||
require.Contains(t, values, minimumCreditAmountForUsageKey)
|
||||
requireBase64PrimitiveValue(t, values[useAICreditsSentinelKey], buildPrimitiveBoolBinary(true))
|
||||
requireBase64PrimitiveValue(t, values[availableCreditsSentinelKey], buildPrimitiveInt32Binary(available))
|
||||
requireBase64PrimitiveValue(t, values[minimumCreditAmountForUsageKey], buildPrimitiveInt32Binary(minimum))
|
||||
}
|
||||
|
||||
// TestBuildInitialStateUpdate verifies the USS update wrapper
|
||||
func TestBuildInitialStateUpdate(t *testing.T) {
|
||||
topicData := buildEmptyTopic()
|
||||
update := buildInitialStateUpdate(topicData)
|
||||
// Should be a valid proto bytes field (field 1 = initial_state)
|
||||
require.True(t, len(update) >= 0) // empty topic is valid
|
||||
|
||||
topicData2 := buildUSSTopicWithOAuth("token", "refresh", time.Now().Add(1*time.Hour))
|
||||
update2 := buildInitialStateUpdate(topicData2)
|
||||
require.True(t, len(update2) > len(update), "non-empty topic should produce larger update")
|
||||
}
|
||||
|
||||
// TestPoolSetAccountTokenComplete verifies pool accepts full credential set
|
||||
func TestPoolSetAccountTokenComplete(t *testing.T) {
|
||||
csrf := "pool-csrf"
|
||||
srv, err := NewMockExtensionServer(csrf)
|
||||
require.NoError(t, err)
|
||||
defer srv.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
pool := &Pool{
|
||||
config: DefaultConfig(),
|
||||
instances: make(map[string][]*Instance),
|
||||
extServer: srv,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
expiry := time.Now().Add(1 * time.Hour)
|
||||
pool.SetAccountToken("acc-1", "ya29.full-token", "1//full-refresh", expiry)
|
||||
|
||||
srv.mu.RLock()
|
||||
info := srv.tokens["acc-1"]
|
||||
srv.mu.RUnlock()
|
||||
|
||||
require.NotNil(t, info)
|
||||
require.Equal(t, "ya29.full-token", info.AccessToken)
|
||||
require.Equal(t, "1//full-refresh", info.RefreshToken)
|
||||
require.False(t, info.ExpiresAt.IsZero())
|
||||
require.WithinDuration(t, expiry, info.ExpiresAt, time.Second)
|
||||
}
|
||||
|
||||
func TestPoolSetAccountModelCreditsComplete(t *testing.T) {
|
||||
csrf := "pool-csrf"
|
||||
srv, err := NewMockExtensionServer(csrf)
|
||||
require.NoError(t, err)
|
||||
defer srv.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
pool := &Pool{
|
||||
config: DefaultConfig(),
|
||||
instances: make(map[string][]*Instance),
|
||||
extServer: srv,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
available := int32(77)
|
||||
minimum := int32(25)
|
||||
pool.SetAccountModelCredits("acc-1", true, &available, &minimum)
|
||||
|
||||
srv.mu.RLock()
|
||||
info := srv.credits["acc-1"]
|
||||
srv.mu.RUnlock()
|
||||
|
||||
require.NotNil(t, info)
|
||||
require.True(t, info.UseAICredits)
|
||||
require.NotNil(t, info.AvailableCredits)
|
||||
require.Equal(t, available, *info.AvailableCredits)
|
||||
require.NotNil(t, info.MinimumCreditAmountForUsage)
|
||||
require.Equal(t, minimum, *info.MinimumCreditAmountForUsage)
|
||||
}
|
||||
|
||||
// TestUpstreamAdapterExtractsCredentials verifies internal LS headers are extracted and stripped.
|
||||
func TestUpstreamAdapterExtractsCredentials(t *testing.T) {
|
||||
// Create a mock upstream that records what it receives
|
||||
var receivedHeaders http.Header
|
||||
var mu sync.Mutex
|
||||
fallback := &recordingUpstreamWithCallback{}
|
||||
fallback.onDo = func(req *http.Request) {
|
||||
mu.Lock()
|
||||
receivedHeaders = req.Header.Clone()
|
||||
mu.Unlock()
|
||||
}
|
||||
|
||||
csrf := "test-csrf"
|
||||
srv, err := NewMockExtensionServer(csrf)
|
||||
require.NoError(t, err)
|
||||
defer srv.Close()
|
||||
|
||||
pool := &Pool{
|
||||
config: DefaultConfig(),
|
||||
instances: make(map[string][]*Instance),
|
||||
extServer: srv,
|
||||
}
|
||||
|
||||
upstream := NewLSPoolUpstream(pool, fallback)
|
||||
|
||||
// Non-streamGenerateContent request → should pass through to fallback
|
||||
req, _ := http.NewRequest("POST", "https://example.com/v1beta/models/gemini:generateContent", nil)
|
||||
req.Header.Set("Authorization", "Bearer ya29.test")
|
||||
req.Header.Set("X-Antigravity-Refresh-Token", "1//secret-refresh")
|
||||
req.Header.Set("X-Antigravity-Token-Expiry", "2026-03-29T19:00:00Z")
|
||||
req.Header.Set(useAICreditsHeader, "true")
|
||||
req.Header.Set(availableCreditsHeader, "42")
|
||||
req.Header.Set(minimumCreditAmountHeader, "50")
|
||||
|
||||
resp, err := upstream.Do(req, "", 1, 1)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
|
||||
// Internal headers should never leak to the direct upstream.
|
||||
mu.Lock()
|
||||
require.Empty(t, receivedHeaders.Get("X-Antigravity-Refresh-Token"))
|
||||
require.Empty(t, receivedHeaders.Get("X-Antigravity-Token-Expiry"))
|
||||
require.Empty(t, receivedHeaders.Get(useAICreditsHeader))
|
||||
require.Empty(t, receivedHeaders.Get(availableCreditsHeader))
|
||||
require.Empty(t, receivedHeaders.Get(minimumCreditAmountHeader))
|
||||
mu.Unlock()
|
||||
|
||||
srv.mu.RLock()
|
||||
tokenInfo := srv.tokens["1"]
|
||||
creditsInfo := srv.credits["1"]
|
||||
srv.mu.RUnlock()
|
||||
|
||||
require.NotNil(t, tokenInfo)
|
||||
require.Equal(t, "ya29.test", tokenInfo.AccessToken)
|
||||
require.NotNil(t, creditsInfo)
|
||||
require.True(t, creditsInfo.UseAICredits)
|
||||
require.NotNil(t, creditsInfo.AvailableCredits)
|
||||
require.Equal(t, int32(42), *creditsInfo.AvailableCredits)
|
||||
require.NotNil(t, creditsInfo.MinimumCreditAmountForUsage)
|
||||
require.Equal(t, int32(50), *creditsInfo.MinimumCreditAmountForUsage)
|
||||
}
|
||||
|
||||
// TestExtractPromptAndModelMultiTurn verifies multi-turn prompt extraction
|
||||
func TestExtractPromptAndModelMultiTurn(t *testing.T) {
|
||||
body := `{
|
||||
"model": "claude-sonnet-4-6",
|
||||
"request": {
|
||||
"systemInstruction": {"parts": [{"text": "You are helpful"}]},
|
||||
"contents": [
|
||||
{"role": "user", "parts": [{"text": "Hello"}]},
|
||||
{"role": "model", "parts": [{"text": "Hi there!"}]},
|
||||
{"role": "user", "parts": [{"text": "How are you?"}]}
|
||||
]
|
||||
}
|
||||
}`
|
||||
prompt, model := extractPromptAndModel([]byte(body))
|
||||
require.Equal(t, "claude-sonnet-4-6", model)
|
||||
require.Contains(t, prompt, "You are helpful")
|
||||
require.Contains(t, prompt, "Hello")
|
||||
require.Contains(t, prompt, "Hi there!")
|
||||
require.Contains(t, prompt, "How are you?")
|
||||
}
|
||||
|
||||
// TestExtractUsageFromTrajectory verifies token usage extraction
|
||||
func TestExtractUsageFromTrajectory(t *testing.T) {
|
||||
resp := `{
|
||||
"trajectory": {
|
||||
"steps": [{
|
||||
"type": "CORTEX_STEP_TYPE_PLANNER_RESPONSE",
|
||||
"status": "CORTEX_STEP_STATUS_DONE",
|
||||
"plannerResponse": {"response": "OK"},
|
||||
"metadata": {
|
||||
"modelUsage": {
|
||||
"inputTokens": "150",
|
||||
"outputTokens": "5"
|
||||
}
|
||||
}
|
||||
}]
|
||||
}
|
||||
}`
|
||||
usage := extractUsageFromTrajectory([]byte(resp))
|
||||
require.NotNil(t, usage)
|
||||
require.Equal(t, 150, usage["promptTokenCount"])
|
||||
require.Equal(t, 5, usage["candidatesTokenCount"])
|
||||
require.Equal(t, 155, usage["totalTokenCount"])
|
||||
}
|
||||
|
||||
// TestSSEChunkFormat verifies the Gemini SSE output format
|
||||
func TestSSEChunkFormat(t *testing.T) {
|
||||
chunk := buildGeminiSSEChunk("Hello world")
|
||||
require.True(t, len(chunk) > 0)
|
||||
require.Contains(t, chunk, "data: ")
|
||||
require.Contains(t, chunk, `"text":"Hello world"`)
|
||||
require.Contains(t, chunk, `"role":"model"`)
|
||||
require.True(t, chunk[len(chunk)-2:] == "\n\n")
|
||||
|
||||
// Verify it's valid JSON after stripping "data: " prefix
|
||||
jsonStr := chunk[len("data: ") : len(chunk)-2]
|
||||
var parsed map[string]any
|
||||
err := json.Unmarshal([]byte(jsonStr), &parsed)
|
||||
require.NoError(t, err)
|
||||
response := parsed["response"].(map[string]any)
|
||||
candidates := response["candidates"].([]any)
|
||||
require.Len(t, candidates, 1)
|
||||
}
|
||||
|
||||
// TestSSEFinalChunkFormat verifies the final SSE chunk with usage
|
||||
func TestSSEFinalChunkFormat(t *testing.T) {
|
||||
usage := map[string]any{
|
||||
"promptTokenCount": 100,
|
||||
"candidatesTokenCount": 50,
|
||||
"totalTokenCount": 150,
|
||||
}
|
||||
chunk := buildGeminiSSEFinalChunk(usage)
|
||||
require.Contains(t, chunk, "data: ")
|
||||
require.Contains(t, chunk, `"finishReason":"STOP"`)
|
||||
require.Contains(t, chunk, `"usageMetadata"`)
|
||||
}
|
||||
|
||||
func TestStreamCascadeResponsePollsImmediately(t *testing.T) {
|
||||
var getCalls atomic.Int32
|
||||
|
||||
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "test-csrf", r.Header.Get("x-codeium-csrf-token"))
|
||||
|
||||
if strings.HasSuffix(r.URL.Path, "/GetCascadeTrajectory") {
|
||||
getCalls.Add(1)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"status":"CASCADE_RUN_STATUS_IDLE","trajectory":{"steps":[{"type":"CORTEX_STEP_TYPE_PLANNER_RESPONSE","status":"CORTEX_STEP_STATUS_DONE","plannerResponse":{"response":"hello from ls"}}]}}`))
|
||||
return
|
||||
}
|
||||
|
||||
http.NotFound(w, r)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
inst := &Instance{
|
||||
AccountID: "42",
|
||||
CSRF: "test-csrf",
|
||||
Address: strings.TrimPrefix(server.URL, "https://"),
|
||||
client: server.Client(),
|
||||
healthy: true,
|
||||
lastUsed: time.Now(),
|
||||
}
|
||||
upstream := NewLSPoolUpstream(&Pool{}, &recordingUpstream{})
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
upstream.streamCascadeResponse(ctx, inst, "cid-1", pw, nil, nil)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
body, err := io.ReadAll(pr)
|
||||
require.NoError(t, err)
|
||||
<-done
|
||||
|
||||
require.GreaterOrEqual(t, getCalls.Load(), int32(1))
|
||||
require.Contains(t, string(body), "hello from ls")
|
||||
}
|
||||
|
||||
// TestRequestHasToolsEdgeCases verifies tool detection edge cases
|
||||
func TestRequestHasToolsEdgeCases(t *testing.T) {
|
||||
// null tools
|
||||
require.False(t, requestHasTools([]byte(`{"contents":[],"tools":null}`)))
|
||||
// tools with empty function declarations
|
||||
require.True(t, requestHasTools([]byte(`{"contents":[],"tools":[{"functionDeclarations":[]}]}`)))
|
||||
// deeply nested wrapped format
|
||||
require.True(t, requestHasTools([]byte(`{"model":"m","project":"p","request":{"contents":[],"tools":[{"codeExecution":{}}]}}`)))
|
||||
}
|
||||
|
||||
func TestJSParityRouteReusesCascadeSession(t *testing.T) {
|
||||
t.Setenv("ANTIGRAVITY_LS_STRATEGY", LSStrategyJSParity)
|
||||
|
||||
var startCalls atomic.Int32
|
||||
var sendCalls atomic.Int32
|
||||
var getCalls atomic.Int32
|
||||
var sendBodiesMu sync.Mutex
|
||||
var sendBodies []map[string]any
|
||||
|
||||
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "test-csrf", r.Header.Get("x-codeium-csrf-token"))
|
||||
|
||||
switch {
|
||||
case strings.HasSuffix(r.URL.Path, "/StartCascade"):
|
||||
startCalls.Add(1)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"cascadeId":"cid-1"}`))
|
||||
case strings.HasSuffix(r.URL.Path, "/SendUserCascadeMessage"):
|
||||
sendCalls.Add(1)
|
||||
var payload map[string]any
|
||||
err := json.NewDecoder(r.Body).Decode(&payload)
|
||||
require.NoError(t, err)
|
||||
sendBodiesMu.Lock()
|
||||
sendBodies = append(sendBodies, payload)
|
||||
sendBodiesMu.Unlock()
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"queued":false}`))
|
||||
case strings.HasSuffix(r.URL.Path, "/GetCascadeTrajectory"):
|
||||
call := getCalls.Add(1)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
text := "hello from ls"
|
||||
if call > 1 {
|
||||
text = "follow up from ls"
|
||||
}
|
||||
_, _ = w.Write([]byte(fmt.Sprintf(`{"status":"CASCADE_RUN_STATUS_IDLE","trajectory":{"steps":[{"type":"CORTEX_STEP_TYPE_PLANNER_RESPONSE","status":"CORTEX_STEP_STATUS_DONE","plannerResponse":{"response":"%s"}}]}}`, text)))
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
inst := &Instance{
|
||||
AccountID: "42",
|
||||
CSRF: "test-csrf",
|
||||
Address: strings.TrimPrefix(server.URL, "https://"),
|
||||
client: server.Client(),
|
||||
healthy: true,
|
||||
lastUsed: time.Now(),
|
||||
}
|
||||
inst.SetModelMappingReady(true)
|
||||
pool := &Pool{
|
||||
config: Config{ReplicasPerAccount: 1},
|
||||
instances: map[string][]*Instance{"42": []*Instance{inst}},
|
||||
}
|
||||
upstream := NewLSPoolUpstream(pool, &recordingUpstream{})
|
||||
|
||||
req1Body := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"session-a","contents":[{"role":"user","parts":[{"text":"hello"}]}]}}`)
|
||||
req1, err := http.NewRequest(http.MethodPost, "https://example.com/v1internal:streamGenerateContent?alt=sse", bytes.NewReader(req1Body))
|
||||
require.NoError(t, err)
|
||||
req1.Header.Set("Authorization", "Bearer downstream-a")
|
||||
|
||||
resp1, err := upstream.Do(req1, "", 42, 1)
|
||||
require.NoError(t, err)
|
||||
body1, err := io.ReadAll(resp1.Body)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, string(body1), `"text":"hello from ls"`)
|
||||
|
||||
req2Body := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"session-a","contents":[{"role":"user","parts":[{"text":"hello"}]},{"role":"model","parts":[{"text":"hello from ls"}]},{"role":"user","parts":[{"text":"follow up"}]}]}}`)
|
||||
req2, err := http.NewRequest(http.MethodPost, "https://example.com/v1internal:streamGenerateContent?alt=sse", bytes.NewReader(req2Body))
|
||||
require.NoError(t, err)
|
||||
req2.Header.Set("Authorization", "Bearer downstream-a")
|
||||
|
||||
resp2, err := upstream.Do(req2, "", 42, 1)
|
||||
require.NoError(t, err)
|
||||
body2, err := io.ReadAll(resp2.Body)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, string(body2), `"text":"follow up from ls"`)
|
||||
|
||||
require.Equal(t, int32(1), startCalls.Load(), "cascade should be reused for append-only transcript")
|
||||
require.Equal(t, int32(2), sendCalls.Load())
|
||||
|
||||
sendBodiesMu.Lock()
|
||||
require.Len(t, sendBodies, 2)
|
||||
firstSend := sendBodies[0]
|
||||
sendBodiesMu.Unlock()
|
||||
|
||||
require.Equal(t, "cid-1", firstSend["cascadeId"])
|
||||
require.Equal(t, false, firstSend["blocking"])
|
||||
metadata, ok := firstSend["metadata"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "antigravity", metadata["ideName"])
|
||||
require.Equal(t, "1.107.0", metadata["ideVersion"])
|
||||
require.NotContains(t, firstSend, "clientType")
|
||||
require.NotContains(t, firstSend, "messageOrigin")
|
||||
cascadeConfig, ok := firstSend["cascadeConfig"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
plannerConfig, ok := cascadeConfig["plannerConfig"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
requestedModel, ok := plannerConfig["requestedModel"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.NotEmpty(t, requestedModel["model"])
|
||||
require.Len(t, plannerConfig, 1)
|
||||
require.Len(t, cascadeConfig, 1)
|
||||
}
|
||||
|
||||
func TestJSParityRouteFallsBackOnSystemInstructionDrift(t *testing.T) {
|
||||
t.Setenv("ANTIGRAVITY_LS_STRATEGY", LSStrategyJSParity)
|
||||
|
||||
var startCalls atomic.Int32
|
||||
var sendCalls atomic.Int32
|
||||
|
||||
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "test-csrf", r.Header.Get("x-codeium-csrf-token"))
|
||||
|
||||
switch {
|
||||
case strings.HasSuffix(r.URL.Path, "/StartCascade"):
|
||||
startCalls.Add(1)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"cascadeId":"cid-1"}`))
|
||||
case strings.HasSuffix(r.URL.Path, "/SendUserCascadeMessage"):
|
||||
sendCalls.Add(1)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"queued":false}`))
|
||||
case strings.HasSuffix(r.URL.Path, "/GetCascadeTrajectory"):
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"status":"CASCADE_RUN_STATUS_IDLE","trajectory":{"steps":[{"type":"CORTEX_STEP_TYPE_PLANNER_RESPONSE","status":"CORTEX_STEP_STATUS_DONE","plannerResponse":{"response":"hello from ls"}}]}}`))
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
inst := &Instance{
|
||||
AccountID: "42",
|
||||
CSRF: "test-csrf",
|
||||
Address: strings.TrimPrefix(server.URL, "https://"),
|
||||
client: server.Client(),
|
||||
healthy: true,
|
||||
lastUsed: time.Now(),
|
||||
}
|
||||
inst.SetModelMappingReady(true)
|
||||
fallback := &recordingUpstream{}
|
||||
pool := &Pool{
|
||||
config: Config{ReplicasPerAccount: 1},
|
||||
instances: map[string][]*Instance{"42": []*Instance{inst}},
|
||||
}
|
||||
upstream := NewLSPoolUpstream(pool, fallback)
|
||||
|
||||
req1Body := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"session-a","systemInstruction":{"parts":[{"text":"You are helpful"}]},"contents":[{"role":"user","parts":[{"text":"hello"}]}]}}`)
|
||||
req1, err := http.NewRequest(http.MethodPost, "https://example.com/v1internal:streamGenerateContent?alt=sse", bytes.NewReader(req1Body))
|
||||
require.NoError(t, err)
|
||||
req1.Header.Set("Authorization", "Bearer downstream-a")
|
||||
|
||||
resp1, err := upstream.Do(req1, "", 42, 1)
|
||||
require.NoError(t, err)
|
||||
body1, err := io.ReadAll(resp1.Body)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, string(body1), `"text":"hello from ls"`)
|
||||
|
||||
req2Body := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"session-a","systemInstruction":{"parts":[{"text":"You are different"}]},"contents":[{"role":"user","parts":[{"text":"hello"}]},{"role":"model","parts":[{"text":"hello from ls"}]},{"role":"user","parts":[{"text":"follow up"}]}]}}`)
|
||||
req2, err := http.NewRequest(http.MethodPost, "https://example.com/v1internal:streamGenerateContent?alt=sse", bytes.NewReader(req2Body))
|
||||
require.NoError(t, err)
|
||||
req2.Header.Set("Authorization", "Bearer downstream-a")
|
||||
|
||||
resp2, err := upstream.Do(req2, "", 42, 1)
|
||||
require.NoError(t, err)
|
||||
body2, err := io.ReadAll(resp2.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "ok", string(body2))
|
||||
require.Equal(t, 1, fallback.doCalls)
|
||||
require.Equal(t, int32(1), startCalls.Load())
|
||||
require.Equal(t, int32(1), sendCalls.Load())
|
||||
}
|
||||
|
||||
func TestJSParityRouteErrorsWhenModelMappingPending(t *testing.T) {
|
||||
t.Setenv("ANTIGRAVITY_LS_STRATEGY", LSStrategyJSParity)
|
||||
|
||||
var startCalls atomic.Int32
|
||||
var sendCalls atomic.Int32
|
||||
|
||||
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case strings.HasSuffix(r.URL.Path, "/StartCascade"):
|
||||
startCalls.Add(1)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"cascadeId":"cid-1"}`))
|
||||
case strings.HasSuffix(r.URL.Path, "/SendUserCascadeMessage"):
|
||||
sendCalls.Add(1)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"queued":false}`))
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
inst := &Instance{
|
||||
AccountID: "42",
|
||||
CSRF: "test-csrf",
|
||||
Address: strings.TrimPrefix(server.URL, "https://"),
|
||||
client: server.Client(),
|
||||
healthy: true,
|
||||
lastUsed: time.Now(),
|
||||
}
|
||||
|
||||
fallback := &recordingUpstream{}
|
||||
pool := &Pool{
|
||||
config: Config{ReplicasPerAccount: 1},
|
||||
instances: map[string][]*Instance{"42": []*Instance{inst}},
|
||||
}
|
||||
upstream := NewLSPoolUpstream(pool, fallback)
|
||||
|
||||
reqBody := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"session-a","contents":[{"role":"user","parts":[{"text":"hello"}]}]}}`)
|
||||
req, err := http.NewRequest(http.MethodPost, "https://example.com/v1internal:streamGenerateContent?alt=sse", bytes.NewReader(reqBody))
|
||||
require.NoError(t, err)
|
||||
req.Header.Set("Authorization", "Bearer downstream-a")
|
||||
|
||||
resp, err := upstream.Do(req, "", 42, 1)
|
||||
require.Nil(t, resp)
|
||||
require.ErrorIs(t, err, errLSModelMapPending)
|
||||
require.Equal(t, int32(0), startCalls.Load())
|
||||
require.Equal(t, int32(0), sendCalls.Load())
|
||||
require.Equal(t, 0, fallback.doCalls)
|
||||
}
|
||||
|
||||
// recordingUpstreamWithCallback extends the base recordingUpstream with a callback
|
||||
type recordingUpstreamWithCallback struct {
|
||||
recordingUpstream
|
||||
onDo func(req *http.Request)
|
||||
}
|
||||
|
||||
func (r *recordingUpstreamWithCallback) Do(req *http.Request, proxyURL string, accountID int64, c int) (*http.Response, error) {
|
||||
if r.onDo != nil {
|
||||
r.onDo(req)
|
||||
}
|
||||
return r.recordingUpstream.Do(req, proxyURL, accountID, c)
|
||||
}
|
||||
920
backend/internal/pkg/lspool/mock_extension_server.go
Normal file
920
backend/internal/pkg/lspool/mock_extension_server.go
Normal file
@ -0,0 +1,920 @@
|
||||
// Package lspool provides a mock Extension Server that the LS binary connects
|
||||
// to at startup. The real IDE's extension.js runs a ConnectRPC HTTP/1.1 server
|
||||
// using connectNodeAdapter. We replicate that protocol here.
|
||||
//
|
||||
// Protocol details (from extension.js source):
|
||||
// - Transport: HTTP/1.1 on 127.0.0.1 (no TLS)
|
||||
// - Auth: x-codeium-csrf-token header on every request
|
||||
// - Unary request Content-Type: application/proto (binary protobuf, no envelope)
|
||||
// OR application/connect+proto (with 5-byte envelope)
|
||||
// - Unary response Content-Type: application/proto (raw binary protobuf, no envelope)
|
||||
// - Stream request Content-Type: application/connect+proto (with 5-byte envelope)
|
||||
// - Stream response Content-Type: application/connect+proto (envelope-framed messages)
|
||||
//
|
||||
// The LS sends requests with content-type "application/connect+proto" for BOTH
|
||||
// unary and streaming RPCs. ConnectRPC's content-type regex:
|
||||
//
|
||||
// /^application\/(connect\+)?(?:(json)(?:; ?charset=utf-?8)?|(proto))$/i
|
||||
//
|
||||
// If "connect+" prefix is present → stream mode; otherwise → unary mode.
|
||||
// However the LS Go client uses the Connect protocol client which always sends
|
||||
// "application/proto" for unary and "application/connect+proto" for streaming.
|
||||
//
|
||||
// We detect the RPC kind from the URL path and respond accordingly.
|
||||
package lspool
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
// ============================================================
|
||||
// Proto helpers — hand-encode minimal proto messages so we don't
|
||||
// need to import the full generated proto package.
|
||||
// ============================================================
|
||||
|
||||
// encodeProtoString writes a proto string field (wire type 2) to a byte slice.
|
||||
func encodeProtoString(fieldNum int, val string) []byte {
|
||||
tag := encodeVarint(uint64(fieldNum<<3 | 2))
|
||||
length := encodeVarint(uint64(len(val)))
|
||||
out := make([]byte, 0, len(tag)+len(length)+len(val))
|
||||
out = append(out, tag...)
|
||||
out = append(out, length...)
|
||||
out = append(out, []byte(val)...)
|
||||
return out
|
||||
}
|
||||
|
||||
// encodeProtoBytes writes a proto bytes/message field (wire type 2).
|
||||
func encodeProtoBytes(fieldNum int, val []byte) []byte {
|
||||
tag := encodeVarint(uint64(fieldNum<<3 | 2))
|
||||
length := encodeVarint(uint64(len(val)))
|
||||
out := make([]byte, 0, len(tag)+len(length)+len(val))
|
||||
out = append(out, tag...)
|
||||
out = append(out, length...)
|
||||
out = append(out, val...)
|
||||
return out
|
||||
}
|
||||
|
||||
// encodeProtoVarint writes a proto varint field (wire type 0).
|
||||
func encodeProtoVarint(fieldNum int, val uint64) []byte {
|
||||
tag := encodeVarint(uint64(fieldNum<<3 | 0))
|
||||
v := encodeVarint(val)
|
||||
out := make([]byte, 0, len(tag)+len(v))
|
||||
out = append(out, tag...)
|
||||
out = append(out, v...)
|
||||
return out
|
||||
}
|
||||
|
||||
// encodeProtoBool writes a proto bool field.
|
||||
func encodeProtoBool(fieldNum int, val bool) []byte {
|
||||
v := uint64(0)
|
||||
if val {
|
||||
v = 1
|
||||
}
|
||||
return encodeProtoVarint(fieldNum, v)
|
||||
}
|
||||
|
||||
func encodeVarint(v uint64) []byte {
|
||||
buf := make([]byte, binary.MaxVarintLen64)
|
||||
n := binary.PutUvarint(buf, v)
|
||||
return buf[:n]
|
||||
}
|
||||
|
||||
// decodeProtoString extracts a string field from raw proto bytes.
|
||||
func decodeProtoString(data []byte, targetField int) string {
|
||||
i := 0
|
||||
for i < len(data) {
|
||||
if i >= len(data) {
|
||||
break
|
||||
}
|
||||
tag, n := binary.Uvarint(data[i:])
|
||||
if n <= 0 {
|
||||
break
|
||||
}
|
||||
i += n
|
||||
fieldNum := int(tag >> 3)
|
||||
wireType := tag & 0x7
|
||||
|
||||
switch wireType {
|
||||
case 0: // varint
|
||||
_, n = binary.Uvarint(data[i:])
|
||||
if n <= 0 {
|
||||
return ""
|
||||
}
|
||||
i += n
|
||||
case 2: // length-delimited
|
||||
length, n := binary.Uvarint(data[i:])
|
||||
if n <= 0 {
|
||||
return ""
|
||||
}
|
||||
i += n
|
||||
if fieldNum == targetField {
|
||||
end := i + int(length)
|
||||
if end > len(data) {
|
||||
return ""
|
||||
}
|
||||
return string(data[i:end])
|
||||
}
|
||||
i += int(length)
|
||||
case 1: // 64-bit
|
||||
i += 8
|
||||
case 5: // 32-bit
|
||||
i += 4
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// ConnectRPC envelope helpers
|
||||
// ============================================================
|
||||
|
||||
// connectEnvelope wraps a proto payload in a ConnectRPC streaming envelope:
|
||||
// 1 byte flags + 4 byte big-endian length + payload
|
||||
func connectEnvelope(flags byte, payload []byte) []byte {
|
||||
frame := make([]byte, 5+len(payload))
|
||||
frame[0] = flags
|
||||
binary.BigEndian.PutUint32(frame[1:5], uint32(len(payload)))
|
||||
copy(frame[5:], payload)
|
||||
return frame
|
||||
}
|
||||
|
||||
// connectEndOfStream returns the end-of-stream trailer frame for ConnectRPC.
|
||||
// flags=0x02 signals end of stream. The payload is a JSON object with empty metadata.
|
||||
func connectEndOfStream() []byte {
|
||||
trailer := []byte("{}")
|
||||
return connectEnvelope(0x02, trailer)
|
||||
}
|
||||
|
||||
// unwrapConnectEnvelope strips the 5-byte envelope header from a ConnectRPC message.
|
||||
// Returns the raw proto payload. If the input is shorter than 5 bytes, returns as-is.
|
||||
func unwrapConnectEnvelope(body []byte) []byte {
|
||||
if len(body) < 5 {
|
||||
return body
|
||||
}
|
||||
// Check if it looks like an envelope: first byte should be 0x00 or 0x01
|
||||
if body[0] > 0x02 {
|
||||
return body // Not envelope-framed, return raw
|
||||
}
|
||||
plen := binary.BigEndian.Uint32(body[1:5])
|
||||
if int(plen)+5 > len(body) {
|
||||
return body // Length mismatch, return raw
|
||||
}
|
||||
return body[5 : 5+plen]
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// OAuthTokenInfo proto builder
|
||||
// ============================================================
|
||||
|
||||
// buildOAuthTokenInfoBinary creates binary-encoded OAuthTokenInfo proto.
|
||||
//
|
||||
// message OAuthTokenInfo {
|
||||
// string access_token = 1;
|
||||
// string token_type = 2;
|
||||
// string refresh_token = 3;
|
||||
// google.protobuf.Timestamp expiry = 4;
|
||||
// bool is_gcp_tos = 6;
|
||||
// }
|
||||
func buildOAuthTokenInfoBinary(accessToken, refreshToken string, expiresAt time.Time) []byte {
|
||||
var buf []byte
|
||||
buf = append(buf, encodeProtoString(1, accessToken)...)
|
||||
buf = append(buf, encodeProtoString(2, "Bearer")...)
|
||||
if refreshToken != "" {
|
||||
buf = append(buf, encodeProtoString(3, refreshToken)...)
|
||||
}
|
||||
// Use real expiry if provided, otherwise default to 1 hour from now
|
||||
expiry := expiresAt
|
||||
if expiry.IsZero() {
|
||||
expiry = time.Now().Add(1 * time.Hour)
|
||||
}
|
||||
ts := ×tamppb.Timestamp{
|
||||
Seconds: expiry.Unix(),
|
||||
}
|
||||
tsBytes, _ := proto.Marshal(ts)
|
||||
buf = append(buf, encodeProtoBytes(4, tsBytes)...)
|
||||
buf = append(buf, encodeProtoBool(6, true)...)
|
||||
return buf
|
||||
}
|
||||
|
||||
// buildUSSTopicWithOAuth creates a USS Topic proto with the OAuth token.
|
||||
//
|
||||
// message Topic { map<string, Row> data = 1; }
|
||||
// message Row { string value = 1; int64 e_tag = 2; }
|
||||
//
|
||||
// The key in the map is "oauthTokenInfoSentinelKey" and the Row.value is
|
||||
// base64(toBinary(OAuthTokenInfo)).
|
||||
func buildUSSTopicWithOAuth(accessToken, refreshToken string, expiresAt time.Time) []byte {
|
||||
tokenBin := buildOAuthTokenInfoBinary(accessToken, refreshToken, expiresAt)
|
||||
tokenB64 := base64.StdEncoding.EncodeToString(tokenBin)
|
||||
|
||||
// Row: value=tokenB64 (field 1), e_tag=1 (field 2)
|
||||
var row []byte
|
||||
row = append(row, encodeProtoString(1, tokenB64)...)
|
||||
row = append(row, encodeProtoVarint(2, 1)...)
|
||||
|
||||
// Map entry: key="oauthTokenInfoSentinelKey" (field 1), value=row (field 2)
|
||||
var entry []byte
|
||||
entry = append(entry, encodeProtoString(1, "oauthTokenInfoSentinelKey")...)
|
||||
entry = append(entry, encodeProtoBytes(2, row)...)
|
||||
|
||||
// Topic: data map entries use field 1
|
||||
var topic []byte
|
||||
topic = append(topic, encodeProtoBytes(1, entry)...)
|
||||
|
||||
return topic
|
||||
}
|
||||
|
||||
func buildPrimitiveBoolBinary(val bool) []byte {
|
||||
// Primitive.bool_value is field 13 in the proto definition
|
||||
return encodeProtoBool(13, val)
|
||||
}
|
||||
|
||||
func buildPrimitiveInt32Binary(val int32) []byte {
|
||||
// Primitive.int32_value is field 3 in the proto definition
|
||||
return encodeProtoVarint(3, uint64(uint32(val)))
|
||||
}
|
||||
|
||||
func encodeUSSBinaryValue(value []byte) string {
|
||||
return base64.StdEncoding.EncodeToString(value)
|
||||
}
|
||||
|
||||
func encodeUSSPrimitiveBoolValue(val bool) string {
|
||||
return encodeUSSBinaryValue(buildPrimitiveBoolBinary(val))
|
||||
}
|
||||
|
||||
func encodeUSSPrimitiveInt32Value(val int32) string {
|
||||
return encodeUSSBinaryValue(buildPrimitiveInt32Binary(val))
|
||||
}
|
||||
|
||||
func buildUSSTopicRow(key string, value string) []byte {
|
||||
row := buildUSSRowBinary(value)
|
||||
|
||||
var entry []byte
|
||||
entry = append(entry, encodeProtoString(1, key)...)
|
||||
entry = append(entry, encodeProtoBytes(2, row)...)
|
||||
return entry
|
||||
}
|
||||
|
||||
func buildUSSRowBinary(value string) []byte {
|
||||
var row []byte
|
||||
row = append(row, encodeProtoString(1, value)...)
|
||||
row = append(row, encodeProtoVarint(2, 1)...)
|
||||
return row
|
||||
}
|
||||
|
||||
func buildUSSTopicWithModelCredits(info *ModelCreditsInfo) []byte {
|
||||
if info == nil {
|
||||
info = &ModelCreditsInfo{}
|
||||
}
|
||||
|
||||
minimum := defaultMinimumCreditAmountForUsage
|
||||
if info.MinimumCreditAmountForUsage != nil {
|
||||
minimum = *info.MinimumCreditAmountForUsage
|
||||
}
|
||||
|
||||
entries := make([][]byte, 0, 3)
|
||||
entries = append(entries, buildUSSTopicRow(
|
||||
useAICreditsSentinelKey,
|
||||
encodeUSSPrimitiveBoolValue(info.UseAICredits),
|
||||
))
|
||||
// JS protocol: useAICreditsSentinelKey carries the toggle state.
|
||||
// availableCreditsSentinelKey is only present when credits are enabled.
|
||||
if info.UseAICredits {
|
||||
credits := int32(9999)
|
||||
if info.AvailableCredits != nil {
|
||||
credits = *info.AvailableCredits
|
||||
}
|
||||
entries = append(entries, buildUSSTopicRow(availableCreditsSentinelKey, encodeUSSPrimitiveInt32Value(credits)))
|
||||
}
|
||||
entries = append(entries, buildUSSTopicRow(minimumCreditAmountForUsageKey, encodeUSSPrimitiveInt32Value(minimum)))
|
||||
|
||||
var topic []byte
|
||||
for _, entry := range entries {
|
||||
topic = append(topic, encodeProtoBytes(1, entry)...)
|
||||
}
|
||||
return topic
|
||||
}
|
||||
|
||||
// buildEmptyTopic returns an empty USS Topic proto (for non-oauth topics).
|
||||
func buildEmptyTopic() []byte {
|
||||
return []byte{} // Empty message = no map entries
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// UnifiedStateSyncUpdate builder
|
||||
// ============================================================
|
||||
|
||||
// buildInitialStateUpdate creates a UnifiedStateSyncUpdate with initial_state set.
|
||||
//
|
||||
// message UnifiedStateSyncUpdate {
|
||||
// oneof update_type {
|
||||
// Topic initial_state = 1;
|
||||
// AppliedUpdate applied_update = 2;
|
||||
// }
|
||||
// }
|
||||
func buildInitialStateUpdate(topicData []byte) []byte {
|
||||
return encodeProtoBytes(1, topicData)
|
||||
}
|
||||
|
||||
func buildAppliedUpdate(key string, row []byte) []byte {
|
||||
var applied []byte
|
||||
applied = append(applied, encodeProtoString(1, key)...)
|
||||
if len(row) > 0 {
|
||||
applied = append(applied, encodeProtoBytes(2, row)...)
|
||||
}
|
||||
return encodeProtoBytes(2, applied)
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// MockExtensionServer
|
||||
// ============================================================
|
||||
|
||||
// MockExtensionServer provides a ConnectRPC-compatible HTTP server that the
|
||||
// Language Server binary connects to. It implements just enough of the
|
||||
// ExtensionServerService to keep the LS operational.
|
||||
type MockExtensionServer struct {
|
||||
listener net.Listener
|
||||
server *http.Server
|
||||
port int
|
||||
csrf string
|
||||
mu sync.RWMutex
|
||||
tokens map[string]*TokenInfo // account_id -> token info
|
||||
credits map[string]*ModelCreditsInfo // account_id -> model credits info
|
||||
subscribers map[string]map[int]*stateSubscriber
|
||||
nextSubID int
|
||||
lastAccountID string
|
||||
logger *slog.Logger
|
||||
|
||||
// Trajectory callback — when LS pushes trajectory updates, we forward them
|
||||
onTrajectoryUpdate func(topic, key string, data []byte)
|
||||
}
|
||||
|
||||
// TokenInfo holds OAuth token details for an account.
|
||||
type TokenInfo struct {
|
||||
AccessToken string
|
||||
RefreshToken string
|
||||
ExpiresAt time.Time // zero value means unknown; defaults to now+1h
|
||||
}
|
||||
|
||||
// ModelCreditsInfo mirrors the JS uss-modelCredits topic state.
|
||||
type ModelCreditsInfo struct {
|
||||
UseAICredits bool
|
||||
AvailableCredits *int32
|
||||
MinimumCreditAmountForUsage *int32
|
||||
}
|
||||
|
||||
type stateSubscriber struct {
|
||||
id int
|
||||
accountID string
|
||||
topic string
|
||||
updates chan []byte
|
||||
}
|
||||
|
||||
const (
|
||||
useAICreditsSentinelKey = "useAICreditsSentinelKey"
|
||||
availableCreditsSentinelKey = "availableCreditsSentinelKey"
|
||||
minimumCreditAmountForUsageKey = "minimumCreditAmountForUsageKey"
|
||||
defaultMinimumCreditAmountForUsage = int32(50)
|
||||
)
|
||||
|
||||
// NewMockExtensionServer creates a mock extension server with proper ConnectRPC handling.
|
||||
func NewMockExtensionServer(csrf string) (*MockExtensionServer, error) {
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("listen: %w", err)
|
||||
}
|
||||
|
||||
m := &MockExtensionServer{
|
||||
listener: listener,
|
||||
port: listener.Addr().(*net.TCPAddr).Port,
|
||||
csrf: csrf,
|
||||
tokens: make(map[string]*TokenInfo),
|
||||
credits: make(map[string]*ModelCreditsInfo),
|
||||
subscribers: make(map[string]map[int]*stateSubscriber),
|
||||
logger: slog.Default().With("component", "mock-ext-server"),
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
extService := "/exa.extension_server_pb.ExtensionServerService/"
|
||||
|
||||
// Register all RPCs the LS calls on the Extension Server.
|
||||
// Unary RPCs — return application/proto
|
||||
mux.HandleFunc(extService+"LanguageServerStarted", m.handleUnary(m.onLanguageServerStarted))
|
||||
mux.HandleFunc(extService+"Heartbeat", m.handleUnary(m.onHeartbeat))
|
||||
mux.HandleFunc(extService+"GetSecretValue", m.handleUnary(m.onGetSecretValue))
|
||||
mux.HandleFunc(extService+"StoreSecretValue", m.handleUnary(m.onStoreSecretValue))
|
||||
mux.HandleFunc(extService+"IsAgentManagerEnabled", m.handleUnary(m.onIsAgentManagerEnabled))
|
||||
mux.HandleFunc(extService+"PushUnifiedStateSyncUpdate", m.handleUnary(m.onPushUnifiedStateSyncUpdate))
|
||||
mux.HandleFunc(extService+"RecordError", m.handleUnary(m.onRecordError))
|
||||
mux.HandleFunc(extService+"LogEvent", m.handleUnary(m.onLogEvent))
|
||||
mux.HandleFunc(extService+"UpdateCascadeTrajectorySummaries", m.handleUnary(m.onUpdateTrajectorySummaries))
|
||||
mux.HandleFunc(extService+"BroadcastConversationDeletion", m.handleUnary(m.onDefault))
|
||||
mux.HandleFunc(extService+"WriteCascadeEdit", m.handleUnary(m.onDefault))
|
||||
mux.HandleFunc(extService+"OpenDiffZones", m.handleUnary(m.onDefault))
|
||||
mux.HandleFunc(extService+"HandleAsyncPostMessage", m.handleUnary(m.onDefault))
|
||||
mux.HandleFunc(extService+"OpenFilePointer", m.handleUnary(m.onDefault))
|
||||
mux.HandleFunc(extService+"OpenVirtualFile", m.handleUnary(m.onDefault))
|
||||
mux.HandleFunc(extService+"SaveDocument", m.handleUnary(m.onDefault))
|
||||
mux.HandleFunc(extService+"RestartUserStatusUpdater", m.handleUnary(m.onDefault))
|
||||
mux.HandleFunc(extService+"FocusIDEWindow", m.handleUnary(m.onDefault))
|
||||
mux.HandleFunc(extService+"SmartFocusConversation", m.handleUnary(m.onDefault))
|
||||
mux.HandleFunc(extService+"RunExtensionCode", m.handleUnary(m.onDefault))
|
||||
mux.HandleFunc(extService+"UpdateDetailedViewWithCascadeInput", m.handleUnary(m.onDefault))
|
||||
mux.HandleFunc(extService+"FindAllReferences", m.handleUnary(m.onDefault))
|
||||
mux.HandleFunc(extService+"GetDefinition", m.handleUnary(m.onDefault))
|
||||
mux.HandleFunc(extService+"GetLintErrors", m.handleUnary(m.onDefault))
|
||||
|
||||
// Server-streaming RPCs — return application/connect+proto
|
||||
mux.HandleFunc(extService+"SubscribeToUnifiedStateSyncTopic", m.handleStream(m.onSubscribeStateSyncTopic))
|
||||
mux.HandleFunc(extService+"ExecuteCommand", m.handleStream(m.onExecuteCommand))
|
||||
|
||||
// Catch-all for any unregistered RPCs
|
||||
mux.HandleFunc("/", m.handleCatchAll)
|
||||
|
||||
m.server = &http.Server{Handler: mux}
|
||||
|
||||
go func() {
|
||||
if err := m.server.Serve(listener); err != http.ErrServerClosed {
|
||||
m.logger.Error("extension server error", "err", err)
|
||||
}
|
||||
}()
|
||||
|
||||
m.logger.Info("mock extension server started", "port", m.port, "csrf_len", len(csrf))
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Port returns the listening port.
|
||||
func (m *MockExtensionServer) Port() int {
|
||||
return m.port
|
||||
}
|
||||
|
||||
// SetToken sets the OAuth token for an account.
|
||||
func (m *MockExtensionServer) SetToken(accountID string, info *TokenInfo) {
|
||||
m.mu.Lock()
|
||||
m.tokens[accountID] = info
|
||||
m.lastAccountID = accountID
|
||||
subscribers := m.snapshotSubscribersLocked("uss-oauth", accountID)
|
||||
m.mu.Unlock()
|
||||
|
||||
if info == nil {
|
||||
return
|
||||
}
|
||||
tokenBin := buildOAuthTokenInfoBinary(info.AccessToken, info.RefreshToken, info.ExpiresAt)
|
||||
tokenB64 := base64.StdEncoding.EncodeToString(tokenBin)
|
||||
m.publishTopicUpdate(subscribers, buildAppliedUpdate("oauthTokenInfoSentinelKey", buildUSSRowBinary(tokenB64)))
|
||||
}
|
||||
|
||||
// SetModelCredits sets the uss-modelCredits state for an account.
|
||||
func (m *MockExtensionServer) SetModelCredits(accountID string, info *ModelCreditsInfo) {
|
||||
if info == nil {
|
||||
info = &ModelCreditsInfo{}
|
||||
}
|
||||
copyInfo := *info
|
||||
m.mu.Lock()
|
||||
m.credits[accountID] = ©Info
|
||||
m.lastAccountID = accountID
|
||||
subscribers := m.snapshotSubscribersLocked("uss-modelCredits", accountID)
|
||||
m.mu.Unlock()
|
||||
|
||||
m.publishTopicUpdate(subscribers, buildModelCreditsAppliedUpdates(©Info)...)
|
||||
}
|
||||
|
||||
// SetTrajectoryCallback registers a callback for when the LS pushes trajectory data.
|
||||
func (m *MockExtensionServer) SetTrajectoryCallback(fn func(topic, key string, data []byte)) {
|
||||
m.onTrajectoryUpdate = fn
|
||||
}
|
||||
|
||||
func (m *MockExtensionServer) currentTokenLocked() *TokenInfo {
|
||||
if m.lastAccountID != "" {
|
||||
if info := m.tokens[m.lastAccountID]; info != nil {
|
||||
return info
|
||||
}
|
||||
}
|
||||
for _, info := range m.tokens {
|
||||
return info
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockExtensionServer) currentModelCreditsLocked() *ModelCreditsInfo {
|
||||
if m.lastAccountID != "" {
|
||||
if info := m.credits[m.lastAccountID]; info != nil {
|
||||
return info
|
||||
}
|
||||
}
|
||||
for _, info := range m.credits {
|
||||
return info
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockExtensionServer) tokenForAccountLocked(accountID string) *TokenInfo {
|
||||
if accountID != "" {
|
||||
if info := m.tokens[accountID]; info != nil {
|
||||
return info
|
||||
}
|
||||
}
|
||||
return m.currentTokenLocked()
|
||||
}
|
||||
|
||||
func (m *MockExtensionServer) creditsForAccountLocked(accountID string) *ModelCreditsInfo {
|
||||
if accountID != "" {
|
||||
if info := m.credits[accountID]; info != nil {
|
||||
return info
|
||||
}
|
||||
}
|
||||
return m.currentModelCreditsLocked()
|
||||
}
|
||||
|
||||
func (m *MockExtensionServer) snapshotSubscribersLocked(topic, accountID string) []*stateSubscriber {
|
||||
topicSubs := m.subscribers[topic]
|
||||
if len(topicSubs) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]*stateSubscriber, 0, len(topicSubs))
|
||||
for _, sub := range topicSubs {
|
||||
if sub == nil {
|
||||
continue
|
||||
}
|
||||
if accountID != "" && sub.accountID != "" && sub.accountID != accountID {
|
||||
continue
|
||||
}
|
||||
out = append(out, sub)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (m *MockExtensionServer) publishTopicUpdate(subscribers []*stateSubscriber, updates ...[]byte) {
|
||||
for _, sub := range subscribers {
|
||||
if sub == nil {
|
||||
continue
|
||||
}
|
||||
for _, update := range updates {
|
||||
if len(update) == 0 {
|
||||
continue
|
||||
}
|
||||
payload := append([]byte(nil), update...)
|
||||
select {
|
||||
case sub.updates <- payload:
|
||||
default:
|
||||
m.logger.Warn("dropping USS update", "topic", sub.topic, "account", sub.accountID)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func buildModelCreditsAppliedUpdates(info *ModelCreditsInfo) [][]byte {
|
||||
if info == nil {
|
||||
info = &ModelCreditsInfo{}
|
||||
}
|
||||
minimum := defaultMinimumCreditAmountForUsage
|
||||
if info.MinimumCreditAmountForUsage != nil {
|
||||
minimum = *info.MinimumCreditAmountForUsage
|
||||
}
|
||||
|
||||
updates := make([][]byte, 0, 3)
|
||||
updates = append(updates, buildAppliedUpdate(
|
||||
useAICreditsSentinelKey,
|
||||
buildUSSRowBinary(encodeUSSPrimitiveBoolValue(info.UseAICredits)),
|
||||
))
|
||||
|
||||
if info.UseAICredits {
|
||||
credits := int32(9999)
|
||||
if info.AvailableCredits != nil {
|
||||
credits = *info.AvailableCredits
|
||||
}
|
||||
updates = append(updates, buildAppliedUpdate(
|
||||
availableCreditsSentinelKey,
|
||||
buildUSSRowBinary(encodeUSSPrimitiveInt32Value(credits)),
|
||||
))
|
||||
} else {
|
||||
updates = append(updates, buildAppliedUpdate(availableCreditsSentinelKey, nil))
|
||||
}
|
||||
updates = append(updates, buildAppliedUpdate(
|
||||
minimumCreditAmountForUsageKey,
|
||||
buildUSSRowBinary(encodeUSSPrimitiveInt32Value(minimum)),
|
||||
))
|
||||
|
||||
return updates
|
||||
}
|
||||
|
||||
// Close shuts down the server.
|
||||
func (m *MockExtensionServer) Close() error {
|
||||
return m.server.Close()
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Middleware
|
||||
// ============================================================
|
||||
|
||||
type unaryHandler func(body []byte) []byte
|
||||
type streamHandler func(body []byte, w http.ResponseWriter, r *http.Request)
|
||||
|
||||
// handleUnary wraps a unary RPC handler with CSRF check and proper content-type.
|
||||
func (m *MockExtensionServer) handleUnary(handler unaryHandler) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// CSRF check
|
||||
if !m.checkCSRF(w, r) {
|
||||
return
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
m.logger.Error("read body", "err", err, "path", r.URL.Path)
|
||||
w.Header().Set("Content-Type", "application/proto")
|
||||
w.WriteHeader(200)
|
||||
return
|
||||
}
|
||||
|
||||
// The LS might send with envelope framing (application/connect+proto)
|
||||
// or without (application/proto). Detect and unwrap.
|
||||
ct := r.Header.Get("Content-Type")
|
||||
protoBody := body
|
||||
if strings.Contains(ct, "connect+proto") && len(body) >= 5 {
|
||||
protoBody = unwrapConnectEnvelope(body)
|
||||
}
|
||||
|
||||
m.logger.Debug("unary RPC", "path", r.URL.Path, "body_len", len(protoBody), "content_type", ct)
|
||||
|
||||
responseProto := handler(protoBody)
|
||||
|
||||
// Respond with proper unary ConnectRPC content-type.
|
||||
// If the request used "connect+proto", the response should be "application/proto"
|
||||
// for unary RPCs (ConnectRPC spec: unary uses application/proto, not connect+proto).
|
||||
w.Header().Set("Content-Type", "application/proto")
|
||||
w.WriteHeader(200)
|
||||
if len(responseProto) > 0 {
|
||||
w.Write(responseProto)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleStream wraps a server-streaming RPC handler with CSRF and content-type.
|
||||
func (m *MockExtensionServer) handleStream(handler streamHandler) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if !m.checkCSRF(w, r) {
|
||||
return
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
m.logger.Error("read body", "err", err, "path", r.URL.Path)
|
||||
return
|
||||
}
|
||||
|
||||
// Unwrap envelope from request
|
||||
ct := r.Header.Get("Content-Type")
|
||||
if strings.Contains(ct, "connect+proto") || strings.Contains(ct, "connect+json") {
|
||||
body = unwrapConnectEnvelope(body)
|
||||
}
|
||||
|
||||
m.logger.Debug("stream RPC", "path", r.URL.Path, "body_len", len(body))
|
||||
|
||||
// Set streaming response content-type
|
||||
w.Header().Set("Content-Type", "application/connect+proto")
|
||||
w.WriteHeader(200)
|
||||
|
||||
handler(body, w, r)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockExtensionServer) checkCSRF(w http.ResponseWriter, r *http.Request) bool {
|
||||
token := r.Header.Get("x-codeium-csrf-token")
|
||||
if m.csrf != "" && token != m.csrf {
|
||||
m.logger.Warn("CSRF mismatch", "path", r.URL.Path, "got", token[:min(8, len(token))])
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(403)
|
||||
w.Write([]byte("Invalid CSRF token"))
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func min(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Unary RPC Handlers — each receives raw proto request body,
|
||||
// returns raw proto response body.
|
||||
// ============================================================
|
||||
|
||||
func (m *MockExtensionServer) onLanguageServerStarted(body []byte) []byte {
|
||||
// LanguageServerStartedRequest has: https_port(1), http_port(2), lsp_port(3), csrf_token(4)
|
||||
// We just log the ports — they're informational.
|
||||
m.logger.Info("LanguageServerStarted",
|
||||
"body_len", len(body))
|
||||
// Return empty LanguageServerStartedResponse
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockExtensionServer) onHeartbeat(body []byte) []byte {
|
||||
// Return empty HeartbeatResponse
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockExtensionServer) onGetSecretValue(body []byte) []byte {
|
||||
// GetSecretValueRequest: key = field 1
|
||||
key := decodeProtoString(body, 1)
|
||||
m.logger.Debug("GetSecretValue", "key", key)
|
||||
|
||||
m.mu.RLock()
|
||||
var token string
|
||||
if info := m.currentTokenLocked(); info != nil {
|
||||
token = info.AccessToken
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
|
||||
// GetSecretValueResponse: value = field 1
|
||||
if token != "" {
|
||||
return encodeProtoString(1, token)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockExtensionServer) onStoreSecretValue(body []byte) []byte {
|
||||
key := decodeProtoString(body, 1)
|
||||
m.logger.Debug("StoreSecretValue", "key", key)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockExtensionServer) onIsAgentManagerEnabled(body []byte) []byte {
|
||||
// IsAgentManagerEnabledResponse: enabled = field 1 (bool)
|
||||
return encodeProtoBool(1, false)
|
||||
}
|
||||
|
||||
func (m *MockExtensionServer) onPushUnifiedStateSyncUpdate(body []byte) []byte {
|
||||
// PushUnifiedStateSyncUpdateRequest: update = field 1 (UpdateRequest message)
|
||||
// UpdateRequest: topic_name = field 1, applied_update = field 5, key = field 2
|
||||
m.logger.Debug("PushUnifiedStateSyncUpdate", "body_len", len(body))
|
||||
|
||||
// Extract topic name from the embedded UpdateRequest
|
||||
// The body is PushUnifiedStateSyncUpdateRequest, field 1 is the UpdateRequest
|
||||
// We need to dig into the nested message to get topic_name
|
||||
if m.onTrajectoryUpdate != nil {
|
||||
// For now, just notify that an update was pushed
|
||||
m.onTrajectoryUpdate("", "", body)
|
||||
}
|
||||
|
||||
// Return empty PushUnifiedStateSyncUpdateResponse
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockExtensionServer) onRecordError(body []byte) []byte {
|
||||
m.logger.Debug("RecordError", "body_len", len(body))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockExtensionServer) onLogEvent(body []byte) []byte {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockExtensionServer) onUpdateTrajectorySummaries(body []byte) []byte {
|
||||
m.logger.Debug("UpdateCascadeTrajectorySummaries", "body_len", len(body))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockExtensionServer) onDefault(body []byte) []byte {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Streaming RPC Handlers
|
||||
// ============================================================
|
||||
|
||||
func (m *MockExtensionServer) onSubscribeStateSyncTopic(body []byte, w http.ResponseWriter, r *http.Request) {
|
||||
// SubscribeToUnifiedStateSyncTopicRequest: topic = field 1
|
||||
topic := decodeProtoString(body, 1)
|
||||
m.logger.Info("SubscribeToUnifiedStateSyncTopic", "topic", topic)
|
||||
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
m.logger.Error("ResponseWriter does not support Flush")
|
||||
return
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
accountID := m.lastAccountID
|
||||
subID := m.nextSubID
|
||||
m.nextSubID++
|
||||
sub := &stateSubscriber{
|
||||
id: subID,
|
||||
accountID: accountID,
|
||||
topic: topic,
|
||||
updates: make(chan []byte, 16),
|
||||
}
|
||||
if m.subscribers[topic] == nil {
|
||||
m.subscribers[topic] = make(map[int]*stateSubscriber)
|
||||
}
|
||||
m.subscribers[topic][subID] = sub
|
||||
|
||||
// Build initial state based on topic
|
||||
var topicData []byte
|
||||
switch topic {
|
||||
case "uss-oauth":
|
||||
tokenInfo := m.tokenForAccountLocked(accountID)
|
||||
if tokenInfo != nil {
|
||||
topicData = buildUSSTopicWithOAuth(tokenInfo.AccessToken, tokenInfo.RefreshToken, tokenInfo.ExpiresAt)
|
||||
} else {
|
||||
topicData = buildEmptyTopic()
|
||||
}
|
||||
case "uss-modelCredits":
|
||||
creditsInfo := m.creditsForAccountLocked(accountID)
|
||||
if creditsInfo != nil {
|
||||
topicData = buildUSSTopicWithModelCredits(creditsInfo)
|
||||
} else {
|
||||
topicData = buildEmptyTopic()
|
||||
}
|
||||
default:
|
||||
// For all other topics (browserPreferences, enterprisePreferences, etc.),
|
||||
// return empty topic data.
|
||||
topicData = buildEmptyTopic()
|
||||
}
|
||||
m.mu.Unlock()
|
||||
defer func() {
|
||||
m.mu.Lock()
|
||||
if topicSubs := m.subscribers[topic]; topicSubs != nil {
|
||||
delete(topicSubs, subID)
|
||||
if len(topicSubs) == 0 {
|
||||
delete(m.subscribers, topic)
|
||||
}
|
||||
}
|
||||
m.mu.Unlock()
|
||||
}()
|
||||
|
||||
// Send initial state as envelope-framed message
|
||||
initialUpdate := buildInitialStateUpdate(topicData)
|
||||
frame := connectEnvelope(0x00, initialUpdate)
|
||||
w.Write(frame)
|
||||
flusher.Flush()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-r.Context().Done():
|
||||
m.logger.Debug("SubscribeToUnifiedStateSyncTopic stream closed", "topic", topic)
|
||||
return
|
||||
case update := <-sub.updates:
|
||||
if len(update) == 0 {
|
||||
continue
|
||||
}
|
||||
if _, err := w.Write(connectEnvelope(0x00, update)); err != nil {
|
||||
m.logger.Debug("SubscribeToUnifiedStateSyncTopic write failed", "topic", topic, "err", err)
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockExtensionServer) onExecuteCommand(body []byte, w http.ResponseWriter, r *http.Request) {
|
||||
m.logger.Debug("ExecuteCommand (mock)", "body_len", len(body))
|
||||
// Send end-of-stream immediately — we don't execute commands
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
w.Write(connectEndOfStream())
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Catch-all handler
|
||||
// ============================================================
|
||||
|
||||
func (m *MockExtensionServer) handleCatchAll(w http.ResponseWriter, r *http.Request) {
|
||||
if !m.checkCSRF(w, r) {
|
||||
return
|
||||
}
|
||||
m.logger.Debug("unhandled RPC (returning empty proto)", "path", r.URL.Path, "method", r.Method)
|
||||
|
||||
// Drain request body
|
||||
io.ReadAll(r.Body)
|
||||
|
||||
// Determine if this is likely a unary or streaming request based on content-type.
|
||||
ct := r.Header.Get("Content-Type")
|
||||
if strings.Contains(ct, "connect+") {
|
||||
// Could be streaming — respond with unary proto to be safe
|
||||
// (unary Connect requests can also use connect+ prefix in some client impls)
|
||||
w.Header().Set("Content-Type", "application/proto")
|
||||
} else {
|
||||
w.Header().Set("Content-Type", "application/proto")
|
||||
}
|
||||
w.WriteHeader(200)
|
||||
}
|
||||
1186
backend/internal/pkg/lspool/pool.go
Normal file
1186
backend/internal/pkg/lspool/pool.go
Normal file
File diff suppressed because it is too large
Load Diff
376
backend/internal/pkg/lspool/pool_test.go
Normal file
376
backend/internal/pkg/lspool/pool_test.go
Normal file
@ -0,0 +1,376 @@
|
||||
package lspool
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBuildLSEnvKeepsExistingSSLValues(t *testing.T) {
|
||||
env := buildLSEnv([]string{
|
||||
"SSL_CERT_FILE=/custom/ca.pem",
|
||||
"SSL_CERT_DIR=/custom/certs",
|
||||
}, "/opt/antigravity", "")
|
||||
require.Contains(t, env, "ANTIGRAVITY_EDITOR_APP_ROOT=/opt/antigravity")
|
||||
require.Contains(t, env, "SSL_CERT_FILE=/custom/ca.pem")
|
||||
require.Contains(t, env, "SSL_CERT_DIR=/custom/certs")
|
||||
}
|
||||
|
||||
func TestBuildLSEnvClearsInheritedProxyWhenUnset(t *testing.T) {
|
||||
env := buildLSEnv([]string{
|
||||
"HTTPS_PROXY=http://old-proxy:8080",
|
||||
"HTTP_PROXY=http://old-proxy:8080",
|
||||
"ALL_PROXY=socks5://old-proxy:1080",
|
||||
"https_proxy=http://old-proxy:8080",
|
||||
"http_proxy=http://old-proxy:8080",
|
||||
"all_proxy=socks5://old-proxy:1080",
|
||||
}, "/opt/antigravity", "")
|
||||
|
||||
require.Contains(t, env, "HTTPS_PROXY=")
|
||||
require.Contains(t, env, "HTTP_PROXY=")
|
||||
require.Contains(t, env, "ALL_PROXY=")
|
||||
require.Contains(t, env, "https_proxy=")
|
||||
require.Contains(t, env, "http_proxy=")
|
||||
require.Contains(t, env, "all_proxy=")
|
||||
}
|
||||
|
||||
func TestShortAccountID(t *testing.T) {
|
||||
require.Equal(t, "9", shortAccountID("9"))
|
||||
require.Equal(t, "12345678", shortAccountID("12345678"))
|
||||
require.Equal(t, "12345678", shortAccountID("123456789"))
|
||||
}
|
||||
|
||||
func TestFrameConnectMessage(t *testing.T) {
|
||||
framed := frameConnectMessage([]byte(`{"x":1}`))
|
||||
require.Len(t, framed, 5+len(`{"x":1}`))
|
||||
require.Equal(t, byte(0), framed[0])
|
||||
require.Equal(t, uint32(len(`{"x":1}`)), binary.BigEndian.Uint32(framed[1:5]))
|
||||
require.Equal(t, `{"x":1}`, string(framed[5:]))
|
||||
}
|
||||
|
||||
func TestConnectEnvelope(t *testing.T) {
|
||||
payload := []byte("hello")
|
||||
env := connectEnvelope(0x00, payload)
|
||||
require.Len(t, env, 5+len(payload))
|
||||
require.Equal(t, byte(0x00), env[0])
|
||||
require.Equal(t, uint32(5), binary.BigEndian.Uint32(env[1:5]))
|
||||
require.Equal(t, "hello", string(env[5:]))
|
||||
}
|
||||
|
||||
func TestUnwrapConnectEnvelope(t *testing.T) {
|
||||
payload := []byte("test data")
|
||||
env := connectEnvelope(0x00, payload)
|
||||
unwrapped := unwrapConnectEnvelope(env)
|
||||
require.Equal(t, payload, unwrapped)
|
||||
short := []byte{1, 2}
|
||||
require.Equal(t, short, unwrapConnectEnvelope(short))
|
||||
}
|
||||
|
||||
func TestExtractPromptAndModel(t *testing.T) {
|
||||
body := `{"model":"gemini-2.5-pro","project":"p","request":{"contents":[{"role":"user","parts":[{"text":"hello world"}]}]}}`
|
||||
prompt, model := extractPromptAndModel([]byte(body))
|
||||
require.Equal(t, "hello world", prompt)
|
||||
require.Equal(t, "gemini-2.5-pro", model)
|
||||
|
||||
body2 := `{"contents":[{"role":"user","parts":[{"text":"test prompt"}]}]}`
|
||||
prompt2, _ := extractPromptAndModel([]byte(body2))
|
||||
require.Equal(t, "test prompt", prompt2)
|
||||
}
|
||||
|
||||
func TestResolveModelEnum(t *testing.T) {
|
||||
// Without dynamic mapping loaded, should return fallback (312 = gemini-2.5-flash)
|
||||
require.True(t, resolveModelEnum("gemini-2.5-flash") > 0)
|
||||
require.True(t, resolveModelEnum("models/gemini-2.5-flash") > 0)
|
||||
require.True(t, resolveModelEnum("claude-sonnet-4-6") > 0)
|
||||
require.True(t, resolveModelEnum("unknown-model") > 0)
|
||||
}
|
||||
|
||||
func TestBuildCascadeConfigIncludesRequestedModel(t *testing.T) {
|
||||
cfg := buildCascadeConfig("models/gemini-2.5-flash")
|
||||
require.NotNil(t, cfg)
|
||||
|
||||
plannerConfig, ok := cfg["plannerConfig"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
requestedModel, ok := plannerConfig["requestedModel"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.NotEmpty(t, requestedModel["model"])
|
||||
require.Len(t, plannerConfig, 1)
|
||||
}
|
||||
|
||||
func TestBuildCascadeConfigClaudeIncludesRequestedModel(t *testing.T) {
|
||||
cfg := buildCascadeConfig("claude-sonnet-4-6")
|
||||
require.NotNil(t, cfg)
|
||||
|
||||
plannerConfig, ok := cfg["plannerConfig"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
requestedModel, ok := plannerConfig["requestedModel"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.NotEmpty(t, requestedModel["model"])
|
||||
require.Len(t, plannerConfig, 1)
|
||||
}
|
||||
|
||||
func TestDoNonStreamGeneratePassesThrough(t *testing.T) {
|
||||
fallback := &recordingUpstream{}
|
||||
upstream := NewLSPoolUpstream(&Pool{}, fallback)
|
||||
req, _ := http.NewRequest("POST", "https://example.com/v1beta/models/gemini:generateContent", bytes.NewReader([]byte(`{}`)))
|
||||
resp, err := upstream.Do(req, "", 1, 1)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, 1, fallback.doCalls)
|
||||
}
|
||||
|
||||
func TestExtractPlannerResponseText(t *testing.T) {
|
||||
resp := `{"status":"CASCADE_RUN_STATUS_IDLE","trajectory":{"steps":[
|
||||
{"type":"CORTEX_STEP_TYPE_USER_INPUT","status":"CORTEX_STEP_STATUS_DONE"},
|
||||
{"type":"CORTEX_STEP_TYPE_PLANNER_RESPONSE","status":"CORTEX_STEP_STATUS_DONE",
|
||||
"plannerResponse":{"response":"Hello world"}}
|
||||
]}}`
|
||||
text, generating, status := extractPlannerResponseText([]byte(resp))
|
||||
require.Equal(t, "Hello world", text)
|
||||
require.False(t, generating)
|
||||
require.Equal(t, "CASCADE_RUN_STATUS_IDLE", status)
|
||||
}
|
||||
|
||||
func TestExtractPlannerResponseState_ErrorDetails(t *testing.T) {
|
||||
resp := `{
|
||||
"status":"CASCADE_RUN_STATUS_IDLE",
|
||||
"trajectory":{
|
||||
"steps":[
|
||||
{"type":"CORTEX_STEP_TYPE_USER_INPUT","status":"CORTEX_STEP_STATUS_DONE"}
|
||||
],
|
||||
"executorMetadata":{
|
||||
"terminationReason":"ERROR",
|
||||
"errorDetails":{
|
||||
"errorCode":429,
|
||||
"shortError":"Model quota reached",
|
||||
"details":"You have exhausted your capacity on this model. Your quota will reset after 1h59m40s."
|
||||
}
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
state := extractPlannerResponseState([]byte(resp))
|
||||
require.Equal(t, "CASCADE_RUN_STATUS_IDLE", state.Status)
|
||||
require.False(t, state.Generating)
|
||||
require.Empty(t, state.Text)
|
||||
require.Contains(t, state.ErrorMessage, "Model quota reached")
|
||||
require.Contains(t, state.ErrorMessage, "quota will reset after")
|
||||
}
|
||||
|
||||
func TestBuildGeminiSSEChunk(t *testing.T) {
|
||||
sse := buildGeminiSSEChunk("hello")
|
||||
require.Contains(t, sse, "data: ")
|
||||
require.Contains(t, sse, `"text":"hello"`)
|
||||
require.Contains(t, sse, `"role":"model"`)
|
||||
require.True(t, strings.HasSuffix(sse, "\n\n"))
|
||||
}
|
||||
|
||||
func TestRequestHasTools(t *testing.T) {
|
||||
// Wrapped format with tools
|
||||
require.True(t, requestHasTools([]byte(`{"model":"m","project":"p","request":{"contents":[],"tools":[{"functionDeclarations":[{"name":"get_weather"}]}]}}`)))
|
||||
|
||||
// Direct format with tools
|
||||
require.True(t, requestHasTools([]byte(`{"contents":[],"tools":[{"functionDeclarations":[{"name":"f"}]}]}`)))
|
||||
|
||||
// No tools
|
||||
require.False(t, requestHasTools([]byte(`{"model":"m","project":"p","request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`)))
|
||||
|
||||
// Empty tools array
|
||||
require.False(t, requestHasTools([]byte(`{"contents":[],"tools":[]}`)))
|
||||
}
|
||||
|
||||
func TestCurrentLSStrategy(t *testing.T) {
|
||||
t.Setenv("ANTIGRAVITY_LS_STRATEGY", "js-parity")
|
||||
require.Equal(t, LSStrategyJSParity, CurrentLSStrategy())
|
||||
|
||||
t.Setenv("ANTIGRAVITY_LS_STRATEGY", "unknown")
|
||||
require.Equal(t, LSStrategyDirect, CurrentLSStrategy())
|
||||
}
|
||||
|
||||
func TestIsPermanentModelMappingError(t *testing.T) {
|
||||
require.True(t, isPermanentModelMappingError(errors.New(`oauth2: "unauthorized_client" "Unauthorized"`)))
|
||||
require.False(t, isPermanentModelMappingError(errors.New("context deadline exceeded")))
|
||||
}
|
||||
|
||||
func TestPoolSetAccountTokenClearsModelMappingUnavailable(t *testing.T) {
|
||||
pool := &Pool{
|
||||
instances: map[string][]*Instance{
|
||||
"9": {
|
||||
{AccountID: "9", Replica: 0},
|
||||
},
|
||||
},
|
||||
}
|
||||
inst := pool.instances["9"][0]
|
||||
inst.SetModelMappingReady(true)
|
||||
inst.SetModelMappingUnavailable(`oauth2: "unauthorized_client" "Unauthorized"`)
|
||||
|
||||
pool.SetAccountToken("9", "ya29.new", "refresh", time.Now().Add(time.Hour))
|
||||
|
||||
require.False(t, inst.HasModelMappingReady())
|
||||
require.False(t, inst.HasModelMappingUnavailable())
|
||||
require.Empty(t, inst.ModelMappingUnavailableReason())
|
||||
}
|
||||
|
||||
func TestShouldFallbackDirectForModelMappingUnavailable(t *testing.T) {
|
||||
require.True(t, shouldFallbackDirect(fmt.Errorf("%w: oauth2 unauthorized_client", errLSModelMapDenied)))
|
||||
require.False(t, shouldFallbackDirect(errLSModelMapPending))
|
||||
}
|
||||
|
||||
func TestParseLSReplicaCountDefaultAndEnv(t *testing.T) {
|
||||
t.Setenv("ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT", "")
|
||||
require.Equal(t, 5, parseLSReplicaCount())
|
||||
|
||||
t.Setenv("ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT", "3")
|
||||
require.Equal(t, 3, parseLSReplicaCount())
|
||||
|
||||
t.Setenv("ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT", "0")
|
||||
require.Equal(t, 5, parseLSReplicaCount())
|
||||
}
|
||||
|
||||
func TestPoolGetUsesStickyReplicaSlot(t *testing.T) {
|
||||
pool := &Pool{
|
||||
config: Config{ReplicasPerAccount: 5},
|
||||
instances: map[string][]*Instance{
|
||||
"acc-1": {
|
||||
{AccountID: "acc-1", Replica: 0, healthy: true},
|
||||
{AccountID: "acc-1", Replica: 1, healthy: true},
|
||||
{AccountID: "acc-1", Replica: 2, healthy: true},
|
||||
{AccountID: "acc-1", Replica: 3, healthy: true},
|
||||
{AccountID: "acc-1", Replica: 4, healthy: true},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
routingKey := "acc-1:user-a:session-1"
|
||||
slot := replicaSlotIndex(routingKey, pool.replicaCount())
|
||||
inst := pool.Get("acc-1", routingKey)
|
||||
require.NotNil(t, inst)
|
||||
require.Equal(t, slot, inst.Replica)
|
||||
}
|
||||
|
||||
func TestPoolGetWithoutRoutingKeyPrefersLeastBusyReplica(t *testing.T) {
|
||||
busy := &Instance{AccountID: "acc-1", Replica: 0, healthy: true}
|
||||
atomic.StoreInt64(&busy.inflight, 4)
|
||||
idle := &Instance{AccountID: "acc-1", Replica: 1, healthy: true}
|
||||
atomic.StoreInt64(&idle.inflight, 1)
|
||||
|
||||
pool := &Pool{
|
||||
config: Config{ReplicasPerAccount: 5},
|
||||
instances: map[string][]*Instance{
|
||||
"acc-1": {busy, idle},
|
||||
},
|
||||
}
|
||||
|
||||
inst := pool.Get("acc-1", "")
|
||||
require.NotNil(t, inst)
|
||||
require.Equal(t, 1, inst.Replica)
|
||||
}
|
||||
|
||||
func TestWaitForInstanceReadyProbesImmediately(t *testing.T) {
|
||||
startedAt := time.Now()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
attempts, err := waitForInstanceReady(ctx, 200*time.Millisecond, func(context.Context) error {
|
||||
return nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, attempts)
|
||||
require.Less(t, time.Since(startedAt), 100*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestWaitForInstanceReadyRetriesUntilSuccess(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
calls := 0
|
||||
attempts, err := waitForInstanceReady(ctx, 10*time.Millisecond, func(context.Context) error {
|
||||
calls++
|
||||
if calls < 3 {
|
||||
return errors.New("not ready")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 3, attempts)
|
||||
require.Equal(t, 3, calls)
|
||||
}
|
||||
|
||||
func TestDecideJSParityRoute(t *testing.T) {
|
||||
body := []byte(`{"model":"gemini-2.5-flash","request":{"sessionId":"s1","contents":[{"role":"user","parts":[{"text":"hello"}]}]}}`)
|
||||
parsed, err := parseGeminiRequest(body)
|
||||
require.NoError(t, err)
|
||||
decision := decideJSParityRoute(parsed, body)
|
||||
require.True(t, decision.UseLS)
|
||||
|
||||
imageBody := []byte(`{"model":"gemini-2.5-flash-image","request":{"sessionId":"s1","contents":[{"role":"user","parts":[{"text":"draw"}]}],"generationConfig":{"responseModalities":["TEXT","IMAGE"]}}}`)
|
||||
parsedImage, err := parseGeminiRequest(imageBody)
|
||||
require.NoError(t, err)
|
||||
decisionImage := decideJSParityRoute(parsedImage, imageBody)
|
||||
require.False(t, decisionImage.UseLS)
|
||||
require.Contains(t, strings.ToLower(decisionImage.Reason), "image")
|
||||
|
||||
noSessionBody := []byte(`{"model":"gemini-2.5-flash","request":{"contents":[{"role":"user","parts":[{"text":"hello"}]}]}}`)
|
||||
parsedNoSession, err := parseGeminiRequest(noSessionBody)
|
||||
require.NoError(t, err)
|
||||
require.False(t, decideJSParityRoute(parsedNoSession, noSessionBody).UseLS)
|
||||
}
|
||||
|
||||
func TestUserNamespacePrefersExplicitHeader(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set(userNamespaceHeader, "tenant-a")
|
||||
req.Header.Set("Authorization", "Bearer oauth-token")
|
||||
|
||||
nsWithExplicit := userNamespace(req)
|
||||
require.NotEqual(t, "anon", nsWithExplicit)
|
||||
|
||||
req.Header.Del(userNamespaceHeader)
|
||||
nsWithAuth := userNamespace(req)
|
||||
require.NotEqual(t, "anon", nsWithAuth)
|
||||
require.NotEqual(t, nsWithExplicit, nsWithAuth)
|
||||
}
|
||||
|
||||
func TestConversationPrefixEqual(t *testing.T) {
|
||||
prefix := []geminiConversationTurn{
|
||||
{Role: "user", Parts: []geminiConversationPart{{Kind: "text", Text: "hello"}}},
|
||||
{Role: "model", Parts: []geminiConversationPart{{Kind: "text", Text: "world"}}},
|
||||
}
|
||||
full := append(cloneConversationTurns(prefix), geminiConversationTurn{
|
||||
Role: "user",
|
||||
Parts: []geminiConversationPart{{Kind: "text", Text: "follow up"}},
|
||||
})
|
||||
require.True(t, conversationPrefixEqual(full, prefix))
|
||||
require.False(t, conversationPrefixEqual(prefix, full))
|
||||
}
|
||||
|
||||
func TestSystemTextCompatible(t *testing.T) {
|
||||
require.True(t, systemTextCompatible("You are helpful", ""))
|
||||
require.True(t, systemTextCompatible("You are helpful", "You are helpful"))
|
||||
require.False(t, systemTextCompatible("", "You are helpful"))
|
||||
require.False(t, systemTextCompatible("You are helpful", "You are different"))
|
||||
}
|
||||
|
||||
type recordingUpstream struct {
|
||||
doCalls int
|
||||
}
|
||||
|
||||
func (r *recordingUpstream) Do(req *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
|
||||
r.doCalls++
|
||||
return &http.Response{StatusCode: 200, Body: io.NopCloser(bytes.NewBufferString("ok")), Header: make(http.Header), Request: req}, nil
|
||||
}
|
||||
|
||||
func (r *recordingUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, c int, _ *tlsfingerprint.Profile) (*http.Response, error) {
|
||||
return r.Do(req, proxyURL, accountID, c)
|
||||
}
|
||||
268
backend/internal/pkg/lspool/proxy_bridge.go
Normal file
268
backend/internal/pkg/lspool/proxy_bridge.go
Normal file
@ -0,0 +1,268 @@
|
||||
package lspool
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
type lsProxyBridge struct {
|
||||
listener net.Listener
|
||||
server *http.Server
|
||||
url string
|
||||
upstream string
|
||||
}
|
||||
|
||||
type lsProxyBridgeManager struct {
|
||||
mu sync.Mutex
|
||||
bridges map[string]*lsProxyBridge
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
var globalLSProxyBridgeManager = &lsProxyBridgeManager{
|
||||
bridges: make(map[string]*lsProxyBridge),
|
||||
logger: slog.Default().With("component", "lspool-proxy-bridge"),
|
||||
}
|
||||
|
||||
var (
|
||||
lsProxyBridgeDialTimeout = 10 * time.Second
|
||||
lsProxyBridgeProbeTargets = []string{
|
||||
"cloudcode-pa.googleapis.com:443",
|
||||
"oauthaccountmanager.googleapis.com:443",
|
||||
}
|
||||
)
|
||||
|
||||
func prepareLSProxyURL(raw string) (string, error) {
|
||||
normalized, parsed, err := proxyurl.Parse(raw)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if parsed == nil {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
switch strings.ToLower(parsed.Scheme) {
|
||||
case "http", "https":
|
||||
return normalized, nil
|
||||
case "socks5", "socks5h":
|
||||
return globalLSProxyBridgeManager.ensure(normalized, parsed)
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported LS proxy scheme: %s", parsed.Scheme)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *lsProxyBridgeManager) ensure(key string, upstream *url.URL) (string, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if bridge := m.bridges[key]; bridge != nil {
|
||||
return bridge.url, nil
|
||||
}
|
||||
|
||||
bridge, err := newLSProxyBridge(upstream, m.logger)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
m.bridges[key] = bridge
|
||||
return bridge.url, nil
|
||||
}
|
||||
|
||||
func (m *lsProxyBridgeManager) closeAll() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
for key, bridge := range m.bridges {
|
||||
if bridge != nil {
|
||||
_ = bridge.server.Close()
|
||||
_ = bridge.listener.Close()
|
||||
}
|
||||
delete(m.bridges, key)
|
||||
}
|
||||
}
|
||||
|
||||
func closeAllLSProxyBridgesForTest() {
|
||||
globalLSProxyBridgeManager.closeAll()
|
||||
}
|
||||
|
||||
func newLSProxyBridge(upstream *url.URL, logger *slog.Logger) (*lsProxyBridge, error) {
|
||||
dialer, err := proxy.FromURL(upstream, proxy.Direct)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create SOCKS dialer: %w", err)
|
||||
}
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("listen LS proxy bridge: %w", err)
|
||||
}
|
||||
|
||||
bridge := &lsProxyBridge{
|
||||
listener: listener,
|
||||
url: "http://" + listener.Addr().String(),
|
||||
upstream: upstream.Redacted(),
|
||||
}
|
||||
|
||||
server := &http.Server{
|
||||
Handler: http.HandlerFunc(bridge.connectHandler(dialer, logger)),
|
||||
ReadHeaderTimeout: 10 * time.Second,
|
||||
IdleTimeout: 2 * time.Minute,
|
||||
}
|
||||
bridge.server = server
|
||||
|
||||
go func() {
|
||||
if err := server.Serve(listener); err != nil && err != http.ErrServerClosed {
|
||||
logger.Error("LS proxy bridge serve failed", "upstream", bridge.upstream, "err", err)
|
||||
}
|
||||
}()
|
||||
|
||||
logger.Info("LS proxy bridge started", "upstream", bridge.upstream, "listen", bridge.url)
|
||||
go bridge.probeConnectivity(dialer, logger)
|
||||
return bridge, nil
|
||||
}
|
||||
|
||||
func (b *lsProxyBridge) connectHandler(dialer proxy.Dialer, logger *slog.Logger) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodConnect {
|
||||
http.Error(w, "CONNECT only", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
targetAddr := strings.TrimSpace(r.Host)
|
||||
if targetAddr == "" {
|
||||
targetAddr = strings.TrimSpace(r.URL.Host)
|
||||
}
|
||||
if targetAddr == "" {
|
||||
http.Error(w, "missing target host", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if _, _, err := net.SplitHostPort(targetAddr); err != nil {
|
||||
targetAddr = net.JoinHostPort(targetAddr, "443")
|
||||
}
|
||||
|
||||
startedAt := time.Now()
|
||||
logger.Info("LS proxy bridge CONNECT", "upstream", b.upstream, "target", targetAddr)
|
||||
|
||||
dialCtx, cancel := context.WithTimeout(r.Context(), lsProxyBridgeDialTimeout)
|
||||
defer cancel()
|
||||
|
||||
targetConn, err := dialViaProxy(dialCtx, dialer, targetAddr)
|
||||
if err != nil {
|
||||
logger.Warn("LS proxy bridge dial failed",
|
||||
"upstream", b.upstream,
|
||||
"target", targetAddr,
|
||||
"elapsed", time.Since(startedAt).Truncate(time.Millisecond),
|
||||
"err", err)
|
||||
http.Error(w, "proxy dial failed", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
logger.Info("LS proxy bridge CONNECT established",
|
||||
"upstream", b.upstream,
|
||||
"target", targetAddr,
|
||||
"elapsed", time.Since(startedAt).Truncate(time.Millisecond))
|
||||
|
||||
hijacker, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
_ = targetConn.Close()
|
||||
http.Error(w, "hijack unsupported", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
clientConn, rw, err := hijacker.Hijack()
|
||||
if err != nil {
|
||||
_ = targetConn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := clientConn.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n")); err != nil {
|
||||
_ = targetConn.Close()
|
||||
_ = clientConn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
if rw != nil && rw.Reader.Buffered() > 0 {
|
||||
if _, err := io.CopyN(targetConn, rw, int64(rw.Reader.Buffered())); err != nil {
|
||||
_ = targetConn.Close()
|
||||
_ = clientConn.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
tunnelConns(clientConn, targetConn)
|
||||
}
|
||||
}
|
||||
|
||||
func dialViaProxy(ctx context.Context, dialer proxy.Dialer, targetAddr string) (net.Conn, error) {
|
||||
if contextDialer, ok := dialer.(proxy.ContextDialer); ok {
|
||||
return contextDialer.DialContext(ctx, "tcp", targetAddr)
|
||||
}
|
||||
|
||||
type dialResult struct {
|
||||
conn net.Conn
|
||||
err error
|
||||
}
|
||||
resultCh := make(chan dialResult, 1)
|
||||
go func() {
|
||||
conn, err := dialer.Dial("tcp", targetAddr)
|
||||
resultCh <- dialResult{conn: conn, err: err}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case result := <-resultCh:
|
||||
return result.conn, result.err
|
||||
}
|
||||
}
|
||||
|
||||
func (b *lsProxyBridge) probeConnectivity(dialer proxy.Dialer, logger *slog.Logger) {
|
||||
for _, targetAddr := range lsProxyBridgeProbeTargets {
|
||||
startedAt := time.Now()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), lsProxyBridgeDialTimeout)
|
||||
conn, err := dialViaProxy(ctx, dialer, targetAddr)
|
||||
cancel()
|
||||
if err != nil {
|
||||
logger.Warn("LS proxy bridge probe failed",
|
||||
"upstream", b.upstream,
|
||||
"target", targetAddr,
|
||||
"elapsed", time.Since(startedAt).Truncate(time.Millisecond),
|
||||
"err", err)
|
||||
continue
|
||||
}
|
||||
_ = conn.Close()
|
||||
logger.Info("LS proxy bridge probe succeeded",
|
||||
"upstream", b.upstream,
|
||||
"target", targetAddr,
|
||||
"elapsed", time.Since(startedAt).Truncate(time.Millisecond))
|
||||
}
|
||||
}
|
||||
|
||||
func tunnelConns(clientConn net.Conn, targetConn net.Conn) {
|
||||
var once sync.Once
|
||||
closeBoth := func() {
|
||||
_ = clientConn.Close()
|
||||
_ = targetConn.Close()
|
||||
}
|
||||
|
||||
go func() {
|
||||
_, _ = io.Copy(targetConn, clientConn)
|
||||
once.Do(closeBoth)
|
||||
}()
|
||||
go func() {
|
||||
_, _ = io.Copy(clientConn, targetConn)
|
||||
once.Do(closeBoth)
|
||||
}()
|
||||
}
|
||||
|
||||
func readConnectResponse(br *bufio.Reader) (*http.Response, error) {
|
||||
return http.ReadResponse(br, &http.Request{Method: http.MethodConnect})
|
||||
}
|
||||
193
backend/internal/pkg/lspool/proxy_bridge_test.go
Normal file
193
backend/internal/pkg/lspool/proxy_bridge_test.go
Normal file
@ -0,0 +1,193 @@
|
||||
package lspool
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPrepareLSProxyURLPassesThroughHTTPProxy(t *testing.T) {
|
||||
t.Cleanup(closeAllLSProxyBridgesForTest)
|
||||
|
||||
got, err := prepareLSProxyURL("http://proxy.example.com:8080")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "http://proxy.example.com:8080", got)
|
||||
}
|
||||
|
||||
func TestPrepareLSProxyURLBridgesSOCKS5ForLS(t *testing.T) {
|
||||
t.Cleanup(closeAllLSProxyBridgesForTest)
|
||||
|
||||
targetAddr, closeTarget := startBridgeEchoServer(t)
|
||||
defer closeTarget()
|
||||
|
||||
socksURL, closeSOCKS := startBridgeSOCKS5Server(t)
|
||||
defer closeSOCKS()
|
||||
|
||||
bridgeURL, err := prepareLSProxyURL(socksURL)
|
||||
require.NoError(t, err)
|
||||
require.True(t, strings.HasPrefix(bridgeURL, "http://127.0.0.1:"))
|
||||
|
||||
// Same SOCKS upstream should reuse the same local bridge.
|
||||
reusedURL, err := prepareLSProxyURL(socksURL)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, bridgeURL, reusedURL)
|
||||
|
||||
bridgeAddr := strings.TrimPrefix(bridgeURL, "http://")
|
||||
conn, err := net.Dial("tcp", bridgeAddr)
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
_, err = fmt.Fprintf(conn, "CONNECT %s HTTP/1.1\r\nHost: %s\r\n\r\n", targetAddr, targetAddr)
|
||||
require.NoError(t, err)
|
||||
|
||||
reader := bufio.NewReader(conn)
|
||||
resp, err := readConnectResponse(reader)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 200, resp.StatusCode)
|
||||
|
||||
_, err = conn.Write([]byte("ping"))
|
||||
require.NoError(t, err)
|
||||
|
||||
reply := make([]byte, 4)
|
||||
_, err = io.ReadFull(reader, reply)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "pong", string(reply))
|
||||
}
|
||||
|
||||
func startBridgeEchoServer(t *testing.T) (string, func()) {
|
||||
t.Helper()
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
for {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go func(c net.Conn) {
|
||||
defer c.Close()
|
||||
buf := make([]byte, 4)
|
||||
if _, err := io.ReadFull(c, buf); err != nil {
|
||||
return
|
||||
}
|
||||
if string(buf) == "ping" {
|
||||
_, _ = c.Write([]byte("pong"))
|
||||
}
|
||||
}(conn)
|
||||
}
|
||||
}()
|
||||
|
||||
return ln.Addr().String(), func() {
|
||||
_ = ln.Close()
|
||||
<-done
|
||||
}
|
||||
}
|
||||
|
||||
func startBridgeSOCKS5Server(t *testing.T) (string, func()) {
|
||||
t.Helper()
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
for {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go handleBridgeSOCKS5Conn(conn)
|
||||
}
|
||||
}()
|
||||
|
||||
return "socks5://" + ln.Addr().String(), func() {
|
||||
_ = ln.Close()
|
||||
<-done
|
||||
}
|
||||
}
|
||||
|
||||
func handleBridgeSOCKS5Conn(conn net.Conn) {
|
||||
header := make([]byte, 2)
|
||||
if _, err := io.ReadFull(conn, header); err != nil {
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
methods := make([]byte, int(header[1]))
|
||||
if _, err := io.ReadFull(conn, methods); err != nil {
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
_, _ = conn.Write([]byte{0x05, 0x00})
|
||||
|
||||
reqHeader := make([]byte, 4)
|
||||
if _, err := io.ReadFull(conn, reqHeader); err != nil {
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
if reqHeader[0] != 0x05 || reqHeader[1] != 0x01 {
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
targetHost, ok := readSOCKS5Addr(conn, reqHeader[3])
|
||||
if !ok {
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
portBuf := make([]byte, 2)
|
||||
if _, err := io.ReadFull(conn, portBuf); err != nil {
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
targetAddr := fmt.Sprintf("%s:%d", targetHost, binary.BigEndian.Uint16(portBuf))
|
||||
|
||||
targetConn, err := net.Dial("tcp", targetAddr)
|
||||
if err != nil {
|
||||
_, _ = conn.Write([]byte{0x05, 0x01, 0x00, 0x01, 0, 0, 0, 0, 0, 0})
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
_, _ = conn.Write([]byte{0x05, 0x00, 0x00, 0x01, 0, 0, 0, 0, 0, 0})
|
||||
tunnelConns(conn, targetConn)
|
||||
}
|
||||
|
||||
func readSOCKS5Addr(conn net.Conn, atyp byte) (string, bool) {
|
||||
switch atyp {
|
||||
case 0x01:
|
||||
buf := make([]byte, 4)
|
||||
if _, err := io.ReadFull(conn, buf); err != nil {
|
||||
return "", false
|
||||
}
|
||||
return net.IP(buf).String(), true
|
||||
case 0x03:
|
||||
lenBuf := make([]byte, 1)
|
||||
if _, err := io.ReadFull(conn, lenBuf); err != nil {
|
||||
return "", false
|
||||
}
|
||||
buf := make([]byte, int(lenBuf[0]))
|
||||
if _, err := io.ReadFull(conn, buf); err != nil {
|
||||
return "", false
|
||||
}
|
||||
return string(buf), true
|
||||
case 0x04:
|
||||
buf := make([]byte, 16)
|
||||
if _, err := io.ReadFull(conn, buf); err != nil {
|
||||
return "", false
|
||||
}
|
||||
return net.IP(buf).String(), true
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
138
backend/internal/pkg/lspool/proxy_exec.go
Normal file
138
backend/internal/pkg/lspool/proxy_exec.go
Normal file
@ -0,0 +1,138 @@
|
||||
package lspool
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
|
||||
)
|
||||
|
||||
type lsLaunchPlan struct {
|
||||
cmd *exec.Cmd
|
||||
effectiveProxyURL string
|
||||
proxyMode string
|
||||
cleanup func()
|
||||
}
|
||||
|
||||
func prepareLSLaunchPlan(binPath string, args []string, rawProxyURL string) (*lsLaunchPlan, error) {
|
||||
normalized, parsed, err := proxyurl.Parse(rawProxyURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
plan := &lsLaunchPlan{
|
||||
cmd: exec.Command(binPath, args...),
|
||||
proxyMode: "direct",
|
||||
}
|
||||
|
||||
if parsed == nil {
|
||||
return plan, nil
|
||||
}
|
||||
|
||||
switch strings.ToLower(parsed.Scheme) {
|
||||
case "http", "https":
|
||||
plan.effectiveProxyURL = normalized
|
||||
plan.proxyMode = "env-http-proxy"
|
||||
return plan, nil
|
||||
|
||||
case "socks5", "socks5h":
|
||||
if proxychainsPath, err := exec.LookPath("proxychains4"); err == nil {
|
||||
cfgPath, err := writeProxychainsConfig(parsed)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
plan.cmd = exec.Command(proxychainsPath, append([]string{"-f", cfgPath, binPath}, args...)...)
|
||||
plan.proxyMode = "proxychains4"
|
||||
plan.cleanup = func() {
|
||||
_ = os.Remove(cfgPath)
|
||||
}
|
||||
return plan, nil
|
||||
}
|
||||
|
||||
effectiveProxyURL, err := prepareLSProxyURL(normalized)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
plan.effectiveProxyURL = effectiveProxyURL
|
||||
plan.proxyMode = "http-connect-bridge"
|
||||
return plan, nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported LS proxy scheme: %s", parsed.Scheme)
|
||||
}
|
||||
}
|
||||
|
||||
func writeProxychainsConfig(proxyURL *url.URL) (string, error) {
|
||||
content, err := buildProxychainsConfig(proxyURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
file, err := os.CreateTemp("", "sub2api-proxychains-*.conf")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create proxychains config: %w", err)
|
||||
}
|
||||
|
||||
if _, err := file.WriteString(content); err != nil {
|
||||
_ = file.Close()
|
||||
_ = os.Remove(file.Name())
|
||||
return "", fmt.Errorf("write proxychains config: %w", err)
|
||||
}
|
||||
if err := file.Close(); err != nil {
|
||||
_ = os.Remove(file.Name())
|
||||
return "", fmt.Errorf("close proxychains config: %w", err)
|
||||
}
|
||||
|
||||
return file.Name(), nil
|
||||
}
|
||||
|
||||
func buildProxychainsConfig(proxyURL *url.URL) (string, error) {
|
||||
if proxyURL == nil {
|
||||
return "", fmt.Errorf("proxy url is nil")
|
||||
}
|
||||
if scheme := strings.ToLower(proxyURL.Scheme); scheme != "socks5" && scheme != "socks5h" {
|
||||
return "", fmt.Errorf("proxychains only supports socks5/socks5h, got %s", proxyURL.Scheme)
|
||||
}
|
||||
|
||||
host := strings.TrimSpace(proxyURL.Hostname())
|
||||
port := strings.TrimSpace(proxyURL.Port())
|
||||
if host == "" {
|
||||
return "", fmt.Errorf("proxy host is empty")
|
||||
}
|
||||
if port == "" {
|
||||
port = "1080"
|
||||
}
|
||||
|
||||
username := proxyURL.User.Username()
|
||||
password, _ := proxyURL.User.Password()
|
||||
if strings.ContainsAny(username, " \t\r\n") || strings.ContainsAny(password, " \t\r\n") {
|
||||
return "", fmt.Errorf("proxychains credentials cannot contain whitespace")
|
||||
}
|
||||
|
||||
var builder strings.Builder
|
||||
builder.WriteString("strict_chain\n")
|
||||
builder.WriteString("proxy_dns\n")
|
||||
builder.WriteString("remote_dns_subnet 224\n")
|
||||
builder.WriteString("tcp_connect_time_out 8000\n")
|
||||
builder.WriteString("tcp_read_time_out 15000\n")
|
||||
builder.WriteString("localnet 127.0.0.0/255.0.0.0\n")
|
||||
builder.WriteString("localnet ::1/128\n")
|
||||
builder.WriteString("[ProxyList]\n")
|
||||
builder.WriteString("socks5 ")
|
||||
builder.WriteString(host)
|
||||
builder.WriteString(" ")
|
||||
builder.WriteString(port)
|
||||
if username != "" {
|
||||
builder.WriteString(" ")
|
||||
builder.WriteString(username)
|
||||
if password != "" {
|
||||
builder.WriteString(" ")
|
||||
builder.WriteString(password)
|
||||
}
|
||||
}
|
||||
builder.WriteString("\n")
|
||||
return builder.String(), nil
|
||||
}
|
||||
31
backend/internal/pkg/lspool/proxy_exec_test.go
Normal file
31
backend/internal/pkg/lspool/proxy_exec_test.go
Normal file
@ -0,0 +1,31 @@
|
||||
package lspool
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBuildProxychainsConfigIncludesAuthAndLocalBypass(t *testing.T) {
|
||||
proxyURL, err := url.Parse("socks5h://testuser:testpass@192.0.2.1:1080")
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg, err := buildProxychainsConfig(proxyURL)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, cfg, "proxy_dns\n")
|
||||
require.Contains(t, cfg, "localnet 127.0.0.0/255.0.0.0\n")
|
||||
require.Contains(t, cfg, "localnet ::1/128\n")
|
||||
require.Contains(t, cfg, "[ProxyList]\n")
|
||||
require.Contains(t, cfg, "socks5 192.0.2.1 1080 testuser testpass\n")
|
||||
}
|
||||
|
||||
func TestBuildProxychainsConfigRejectsWhitespaceCredentials(t *testing.T) {
|
||||
proxyURL, err := url.Parse("socks5h://user:bad%20pass@127.0.0.1:1080")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = buildProxychainsConfig(proxyURL)
|
||||
require.Error(t, err)
|
||||
require.True(t, strings.Contains(err.Error(), "whitespace"))
|
||||
}
|
||||
99
backend/internal/pkg/lspool/remote_instance.go
Normal file
99
backend/internal/pkg/lspool/remote_instance.go
Normal file
@ -0,0 +1,99 @@
|
||||
package lspool
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
func (i *Instance) callWorkerUnary(ctx context.Context, service, method, mode string, body []byte) ([]byte, error) {
|
||||
endpoint, err := i.workerEndpoint("/rpc/unary", service, method, mode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("X-Worker-Token", i.workerToken)
|
||||
if mode == "json" {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
} else {
|
||||
req.Header.Set("Content-Type", "application/octet-stream")
|
||||
}
|
||||
|
||||
resp, err := i.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("worker rpc %s/%s: %w", service, method, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("worker rpc read response: %w", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return respBody, fmt.Errorf("worker rpc %s/%s HTTP %d: %s", service, method, resp.StatusCode, truncate(string(respBody), 200))
|
||||
}
|
||||
return respBody, nil
|
||||
}
|
||||
|
||||
func (i *Instance) callWorkerStream(ctx context.Context, service, method, mode string, body []byte) (*http.Response, error) {
|
||||
endpoint, err := i.workerEndpoint("/rpc/stream", service, method, mode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("X-Worker-Token", i.workerToken)
|
||||
if mode == "json" {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
} else {
|
||||
req.Header.Set("Content-Type", "application/octet-stream")
|
||||
}
|
||||
|
||||
resp, err := i.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("worker stream rpc %s/%s: %w", service, method, err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
defer resp.Body.Close()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("worker stream rpc %s/%s HTTP %d: %s", service, method, resp.StatusCode, truncate(string(body), 200))
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (i *Instance) workerEndpoint(path, service, method, mode string) (string, error) {
|
||||
base := url.URL{
|
||||
Scheme: "http",
|
||||
Host: i.Address,
|
||||
Path: path,
|
||||
}
|
||||
values := url.Values{}
|
||||
values.Set("service", service)
|
||||
values.Set("method", method)
|
||||
values.Set("mode", mode)
|
||||
if i.routingKey != "" {
|
||||
values.Set("routing_key", i.routingKey)
|
||||
}
|
||||
base.RawQuery = values.Encode()
|
||||
return base.String(), nil
|
||||
}
|
||||
|
||||
func marshalWorkerJSONBody(input any) ([]byte, error) {
|
||||
if input == nil {
|
||||
return []byte("{}"), nil
|
||||
}
|
||||
body, err := json.Marshal(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return body, nil
|
||||
}
|
||||
1682
backend/internal/pkg/lspool/upstream_adapter.go
Normal file
1682
backend/internal/pkg/lspool/upstream_adapter.go
Normal file
File diff suppressed because it is too large
Load Diff
680
backend/internal/pkg/lspool/worker_manager.go
Normal file
680
backend/internal/pkg/lspool/worker_manager.go
Normal file
@ -0,0 +1,680 @@
|
||||
package lspool
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
|
||||
"github.com/docker/docker/api/types/container"
|
||||
"github.com/docker/docker/api/types/filters"
|
||||
"github.com/docker/docker/api/types/network"
|
||||
"github.com/docker/docker/client"
|
||||
ocispec "github.com/opencontainers/image-spec/specs-go/v1"
|
||||
)
|
||||
|
||||
const (
|
||||
lsWorkerManagedByLabel = "managed-by"
|
||||
lsWorkerManagedByValue = "sub2api"
|
||||
lsWorkerAccountLabel = "account_id"
|
||||
lsWorkerProxyHashLabel = "proxy_hash"
|
||||
lsWorkerImageTagLabel = "image_tag"
|
||||
lsWorkerControlPort = 18081
|
||||
)
|
||||
|
||||
type workerManagerConfig struct {
|
||||
Image string
|
||||
Network string
|
||||
DockerSocket string
|
||||
IdleTTL time.Duration
|
||||
MaxActive int
|
||||
StartupTimeout time.Duration
|
||||
RequestTimeout time.Duration
|
||||
}
|
||||
|
||||
type dockerClient interface {
|
||||
ContainerList(ctx context.Context, options container.ListOptions) ([]container.Summary, error)
|
||||
ContainerCreate(ctx context.Context, config *container.Config, hostConfig *container.HostConfig, networkingConfig *network.NetworkingConfig, platform *ocispec.Platform, containerName string) (container.CreateResponse, error)
|
||||
ContainerStart(ctx context.Context, containerID string, options container.StartOptions) error
|
||||
ContainerInspect(ctx context.Context, containerID string) (container.InspectResponse, error)
|
||||
ContainerStop(ctx context.Context, containerID string, options container.StopOptions) error
|
||||
ContainerRemove(ctx context.Context, containerID string, options container.RemoveOptions) error
|
||||
Close() error
|
||||
}
|
||||
|
||||
type workerManager struct {
|
||||
cfg workerManagerConfig
|
||||
docker dockerClient
|
||||
http *http.Client
|
||||
|
||||
mu sync.Mutex
|
||||
workers map[string]*workerHandle
|
||||
state map[string]*workerAccountState
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
type workerHandle struct {
|
||||
Key string
|
||||
AccountID string
|
||||
ProxyURL string
|
||||
ProxyHash string
|
||||
ContainerID string
|
||||
Container string
|
||||
Address string
|
||||
AuthToken string
|
||||
LastUsed time.Time
|
||||
LastStateSHA string
|
||||
}
|
||||
|
||||
type workerAccountState struct {
|
||||
HasToken bool `json:"has_token"`
|
||||
AccessToken string `json:"access_token,omitempty"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
ExpiresAt *time.Time `json:"expires_at,omitempty"`
|
||||
HasModelCredits bool `json:"has_model_credits"`
|
||||
UseAICredits bool `json:"use_ai_credits"`
|
||||
AvailableCredits *int32 `json:"available_credits,omitempty"`
|
||||
MinimumCreditAmount *int32 `json:"minimum_credit_amount,omitempty"`
|
||||
}
|
||||
|
||||
func NewWorkerManagerFromConfig(cfg *config.Config) (Backend, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("config is nil")
|
||||
}
|
||||
|
||||
managerCfg := workerManagerConfig{
|
||||
Image: strings.TrimSpace(cfg.Gateway.AntigravityLSWorker.Image),
|
||||
Network: strings.TrimSpace(cfg.Gateway.AntigravityLSWorker.Network),
|
||||
DockerSocket: strings.TrimSpace(cfg.Gateway.AntigravityLSWorker.DockerSocket),
|
||||
IdleTTL: cfg.Gateway.AntigravityLSWorker.IdleTTL,
|
||||
MaxActive: cfg.Gateway.AntigravityLSWorker.MaxActive,
|
||||
StartupTimeout: cfg.Gateway.AntigravityLSWorker.StartupTimeout,
|
||||
RequestTimeout: cfg.Gateway.AntigravityLSWorker.RequestTimeout,
|
||||
}
|
||||
|
||||
if managerCfg.Image == "" {
|
||||
managerCfg.Image = "weishaw/sub2api-lsworker:latest"
|
||||
}
|
||||
if managerCfg.Network == "" {
|
||||
managerCfg.Network = "sub2api-network"
|
||||
}
|
||||
if managerCfg.DockerSocket == "" {
|
||||
managerCfg.DockerSocket = "unix:///var/run/docker.sock"
|
||||
}
|
||||
if managerCfg.IdleTTL <= 0 {
|
||||
managerCfg.IdleTTL = 15 * time.Minute
|
||||
}
|
||||
if managerCfg.MaxActive < 1 {
|
||||
managerCfg.MaxActive = 50
|
||||
}
|
||||
if managerCfg.StartupTimeout <= 0 {
|
||||
managerCfg.StartupTimeout = 45 * time.Second
|
||||
}
|
||||
if managerCfg.RequestTimeout <= 0 {
|
||||
managerCfg.RequestTimeout = 60 * time.Second
|
||||
}
|
||||
|
||||
opts := []client.Opt{client.WithAPIVersionNegotiation()}
|
||||
if managerCfg.DockerSocket != "" {
|
||||
opts = append(opts, client.WithHost(managerCfg.DockerSocket))
|
||||
} else {
|
||||
opts = append(opts, client.FromEnv)
|
||||
}
|
||||
|
||||
dockerClient, err := client.NewClientWithOpts(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create docker client: %w", err)
|
||||
}
|
||||
|
||||
return newWorkerManager(managerCfg, dockerClient)
|
||||
}
|
||||
|
||||
func newWorkerManager(cfg workerManagerConfig, docker dockerClient) (*workerManager, error) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
mgr := &workerManager{
|
||||
cfg: cfg,
|
||||
docker: docker,
|
||||
http: &http.Client{
|
||||
Timeout: cfg.RequestTimeout,
|
||||
Transport: &http.Transport{
|
||||
Proxy: nil,
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 5 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).DialContext,
|
||||
MaxIdleConnsPerHost: 8,
|
||||
},
|
||||
},
|
||||
workers: make(map[string]*workerHandle),
|
||||
state: make(map[string]*workerAccountState),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
logger: slog.Default().With("component", "lspool-worker-manager"),
|
||||
}
|
||||
if err := mgr.reconcileManagedContainers(ctx); err != nil {
|
||||
cancel()
|
||||
_ = docker.Close()
|
||||
return nil, err
|
||||
}
|
||||
go mgr.cleanupLoop()
|
||||
return mgr, nil
|
||||
}
|
||||
|
||||
func (m *workerManager) Close() {
|
||||
m.cancel()
|
||||
|
||||
m.mu.Lock()
|
||||
workers := make([]*workerHandle, 0, len(m.workers))
|
||||
for _, handle := range m.workers {
|
||||
workers = append(workers, handle)
|
||||
}
|
||||
m.workers = make(map[string]*workerHandle)
|
||||
m.mu.Unlock()
|
||||
|
||||
for _, handle := range workers {
|
||||
m.removeWorkerContainer(context.Background(), handle)
|
||||
}
|
||||
if m.docker != nil {
|
||||
_ = m.docker.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *workerManager) Stats() map[string]any {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return map[string]any{
|
||||
"accounts": len(m.state),
|
||||
"total": len(m.workers),
|
||||
"active": len(m.workers),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *workerManager) SetAccountToken(accountID, accessToken, refreshToken string, expiresAt time.Time) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
state := m.ensureStateLocked(accountID)
|
||||
state.HasToken = true
|
||||
state.AccessToken = accessToken
|
||||
state.RefreshToken = refreshToken
|
||||
if expiresAt.IsZero() {
|
||||
state.ExpiresAt = nil
|
||||
} else {
|
||||
ts := expiresAt.UTC()
|
||||
state.ExpiresAt = &ts
|
||||
}
|
||||
}
|
||||
|
||||
func (m *workerManager) SetAccountModelCredits(accountID string, useAICredits bool, availableCredits, minimumCreditAmountForUsage *int32) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
state := m.ensureStateLocked(accountID)
|
||||
state.HasModelCredits = true
|
||||
state.UseAICredits = useAICredits
|
||||
state.AvailableCredits = cloneInt32Ptr(availableCredits)
|
||||
state.MinimumCreditAmount = cloneInt32Ptr(minimumCreditAmountForUsage)
|
||||
}
|
||||
|
||||
func (m *workerManager) GetOrCreate(accountID, routingKey string, proxyURL ...string) (*Instance, error) {
|
||||
rawProxy := ""
|
||||
if len(proxyURL) > 0 {
|
||||
rawProxy = proxyURL[0]
|
||||
}
|
||||
normalizedProxy, parsedProxy, err := resolveWorkerProxy(rawProxy)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if parsedProxy == nil {
|
||||
return nil, fmt.Errorf("ls worker requires a socks5/socks5h proxy for account %s", accountID)
|
||||
}
|
||||
|
||||
replica := replicaSlotIndex(routingKey, parseLSReplicaCount())
|
||||
proxyHash := proxyHash(normalizedProxy)
|
||||
workerKey := buildWorkerKey(accountID, proxyHash)
|
||||
|
||||
m.mu.Lock()
|
||||
state := cloneWorkerAccountState(m.state[accountID])
|
||||
if state == nil || !state.HasToken || strings.TrimSpace(state.AccessToken) == "" {
|
||||
m.mu.Unlock()
|
||||
return nil, fmt.Errorf("ls worker missing access token for account %s", accountID)
|
||||
}
|
||||
|
||||
handle := m.workers[workerKey]
|
||||
if handle == nil {
|
||||
if len(m.workers) >= m.cfg.MaxActive {
|
||||
m.mu.Unlock()
|
||||
return nil, fmt.Errorf("ls worker limit reached (%d active)", m.cfg.MaxActive)
|
||||
}
|
||||
handle, err = m.createWorkerLocked(accountID, normalizedProxy, proxyHash, parsedProxy)
|
||||
if err != nil {
|
||||
m.mu.Unlock()
|
||||
return nil, err
|
||||
}
|
||||
m.workers[workerKey] = handle
|
||||
}
|
||||
handle.LastUsed = time.Now()
|
||||
m.mu.Unlock()
|
||||
|
||||
if err := m.waitForWorkerHealthy(handle); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := m.syncWorkerState(handle, state); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := m.waitForWorkerReady(handle, routingKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
inst := &Instance{
|
||||
AccountID: accountID,
|
||||
Replica: replica,
|
||||
Address: handle.Address,
|
||||
client: m.http,
|
||||
healthy: true,
|
||||
lastUsed: time.Now(),
|
||||
modelMapReady: 1,
|
||||
remote: true,
|
||||
workerToken: handle.AuthToken,
|
||||
routingKey: routingKey,
|
||||
}
|
||||
return inst, nil
|
||||
}
|
||||
|
||||
func (m *workerManager) cleanupLoop() {
|
||||
ticker := time.NewTicker(time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
m.collectIdleWorkers()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *workerManager) collectIdleWorkers() {
|
||||
now := time.Now()
|
||||
var expired []*workerHandle
|
||||
|
||||
m.mu.Lock()
|
||||
for key, handle := range m.workers {
|
||||
if handle == nil {
|
||||
delete(m.workers, key)
|
||||
continue
|
||||
}
|
||||
if now.Sub(handle.LastUsed) <= m.cfg.IdleTTL {
|
||||
continue
|
||||
}
|
||||
expired = append(expired, handle)
|
||||
delete(m.workers, key)
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
for _, handle := range expired {
|
||||
m.removeWorkerContainer(context.Background(), handle)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *workerManager) reconcileManagedContainers(ctx context.Context) error {
|
||||
args := filters.NewArgs()
|
||||
args.Add("label", fmt.Sprintf("%s=%s", lsWorkerManagedByLabel, lsWorkerManagedByValue))
|
||||
|
||||
containers, err := m.docker.ContainerList(ctx, container.ListOptions{
|
||||
All: true,
|
||||
Filters: args,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("list managed ls workers: %w", err)
|
||||
}
|
||||
|
||||
for _, summary := range containers {
|
||||
handle := &workerHandle{
|
||||
ContainerID: summary.ID,
|
||||
Container: strings.TrimPrefix(firstContainerName(summary.Names), "/"),
|
||||
}
|
||||
if err := m.removeWorkerContainer(ctx, handle); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *workerManager) createWorkerLocked(accountID, proxyURL, proxyHash string, parsedProxy *url.URL) (*workerHandle, error) {
|
||||
containerName := fmt.Sprintf("sub2api-ls-%s-%s", accountID, proxyHash[:8])
|
||||
authToken := generateUUID()
|
||||
|
||||
proxyHost := parsedProxy.Hostname()
|
||||
proxyPort := parsedProxy.Port()
|
||||
if proxyPort == "" {
|
||||
proxyPort = "1080"
|
||||
}
|
||||
proxyUser := parsedProxy.User.Username()
|
||||
proxyPass, _ := parsedProxy.User.Password()
|
||||
|
||||
labels := map[string]string{
|
||||
lsWorkerManagedByLabel: lsWorkerManagedByValue,
|
||||
lsWorkerAccountLabel: accountID,
|
||||
lsWorkerProxyHashLabel: proxyHash,
|
||||
lsWorkerImageTagLabel: m.cfg.Image,
|
||||
}
|
||||
|
||||
env := []string{
|
||||
"ANTIGRAVITY_APP_ROOT=/app/ls",
|
||||
fmt.Sprintf("LSWORKER_ACCOUNT_ID=%s", accountID),
|
||||
fmt.Sprintf("LSWORKER_AUTH_TOKEN=%s", authToken),
|
||||
fmt.Sprintf("LSWORKER_LISTEN_ADDR=0.0.0.0:%d", lsWorkerControlPort),
|
||||
fmt.Sprintf("LSWORKER_NETWORK_READY_FILE=%s", "/run/lsworker/network-ready"),
|
||||
fmt.Sprintf("LSWORKER_PROXY_URL=%s", proxyURL),
|
||||
fmt.Sprintf("LSWORKER_PROXY_HOST=%s", proxyHost),
|
||||
fmt.Sprintf("LSWORKER_PROXY_PORT=%s", proxyPort),
|
||||
fmt.Sprintf("LSWORKER_PROXY_USER=%s", proxyUser),
|
||||
fmt.Sprintf("LSWORKER_PROXY_PASS=%s", proxyPass),
|
||||
fmt.Sprintf("LSWORKER_CONTROL_PORT=%d", lsWorkerControlPort),
|
||||
fmt.Sprintf("ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT=%d", parseLSReplicaCount()),
|
||||
}
|
||||
if tz := strings.TrimSpace(os.Getenv("TZ")); tz != "" {
|
||||
env = append(env, "TZ="+tz)
|
||||
}
|
||||
|
||||
createResp, err := m.docker.ContainerCreate(m.ctx, &container.Config{
|
||||
Image: m.cfg.Image,
|
||||
Labels: labels,
|
||||
Env: env,
|
||||
}, &container.HostConfig{
|
||||
CapAdd: []string{"NET_ADMIN"},
|
||||
}, &network.NetworkingConfig{
|
||||
EndpointsConfig: map[string]*network.EndpointSettings{
|
||||
m.cfg.Network: {},
|
||||
},
|
||||
}, nil, containerName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create ls worker container: %w", err)
|
||||
}
|
||||
|
||||
if err := m.docker.ContainerStart(m.ctx, createResp.ID, container.StartOptions{}); err != nil {
|
||||
_ = m.docker.ContainerRemove(m.ctx, createResp.ID, container.RemoveOptions{Force: true})
|
||||
return nil, fmt.Errorf("start ls worker container: %w", err)
|
||||
}
|
||||
|
||||
inspect, err := m.docker.ContainerInspect(m.ctx, createResp.ID)
|
||||
if err != nil {
|
||||
_ = m.docker.ContainerRemove(m.ctx, createResp.ID, container.RemoveOptions{Force: true})
|
||||
return nil, fmt.Errorf("inspect ls worker container: %w", err)
|
||||
}
|
||||
|
||||
address, err := workerAddressFromInspect(inspect, m.cfg.Network)
|
||||
if err != nil {
|
||||
_ = m.docker.ContainerRemove(m.ctx, createResp.ID, container.RemoveOptions{Force: true})
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m.logger.Info("created ls worker",
|
||||
"account", shortAccountID(accountID),
|
||||
"container", containerName,
|
||||
"address", address,
|
||||
"proxy_hash", proxyHash[:8])
|
||||
|
||||
return &workerHandle{
|
||||
Key: buildWorkerKey(accountID, proxyHash),
|
||||
AccountID: accountID,
|
||||
ProxyURL: proxyURL,
|
||||
ProxyHash: proxyHash,
|
||||
ContainerID: createResp.ID,
|
||||
Container: containerName,
|
||||
Address: address,
|
||||
AuthToken: authToken,
|
||||
LastUsed: time.Now(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func workerAddressFromInspect(inspect container.InspectResponse, networkName string) (string, error) {
|
||||
if inspect.NetworkSettings == nil {
|
||||
return "", fmt.Errorf("ls worker inspect missing network settings")
|
||||
}
|
||||
if endpoint, ok := inspect.NetworkSettings.Networks[networkName]; ok && endpoint != nil && strings.TrimSpace(endpoint.IPAddress) != "" {
|
||||
return net.JoinHostPort(strings.TrimSpace(endpoint.IPAddress), strconv.Itoa(lsWorkerControlPort)), nil
|
||||
}
|
||||
for _, endpoint := range inspect.NetworkSettings.Networks {
|
||||
if endpoint != nil && strings.TrimSpace(endpoint.IPAddress) != "" {
|
||||
return net.JoinHostPort(strings.TrimSpace(endpoint.IPAddress), strconv.Itoa(lsWorkerControlPort)), nil
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("ls worker missing IP address on network %s", networkName)
|
||||
}
|
||||
|
||||
func firstContainerName(names []string) string {
|
||||
if len(names) == 0 {
|
||||
return ""
|
||||
}
|
||||
return names[0]
|
||||
}
|
||||
|
||||
func (m *workerManager) waitForWorkerHealthy(handle *workerHandle) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), m.cfg.StartupTimeout)
|
||||
defer cancel()
|
||||
|
||||
for {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, workerURL(handle, "/healthz", nil), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("X-Worker-Token", handle.AuthToken)
|
||||
resp, err := m.http.Do(req)
|
||||
if err == nil {
|
||||
_ = resp.Body.Close()
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("worker %s failed health check: %w", handle.Container, ctx.Err())
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *workerManager) waitForWorkerReady(handle *workerHandle, routingKey string) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), m.cfg.StartupTimeout)
|
||||
defer cancel()
|
||||
|
||||
values := url.Values{}
|
||||
if strings.TrimSpace(routingKey) != "" {
|
||||
values.Set("routing_key", routingKey)
|
||||
}
|
||||
|
||||
var (
|
||||
lastStatus int
|
||||
lastBody string
|
||||
)
|
||||
|
||||
for {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, workerURL(handle, "/readyz", values), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("X-Worker-Token", handle.AuthToken)
|
||||
resp, err := m.http.Do(req)
|
||||
if err == nil {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
_ = resp.Body.Close()
|
||||
lastStatus = resp.StatusCode
|
||||
lastBody = truncate(string(body), 200)
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
return nil
|
||||
}
|
||||
if isWorkerModelMappingUnavailable(resp.StatusCode, lastBody) {
|
||||
return fmt.Errorf("%w: worker %s %s", errLSModelMapDenied, handle.Container, strings.TrimSpace(lastBody))
|
||||
}
|
||||
if len(body) > 0 && shouldWarnWorkerNotReady(resp.StatusCode, lastBody) {
|
||||
m.logger.Warn("ls worker not ready yet", "container", handle.Container, "status", resp.StatusCode, "body", truncate(string(body), 200))
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if lastStatus > 0 || lastBody != "" {
|
||||
return fmt.Errorf("worker %s not ready for routing key %q (last_status=%d last_body=%q): %w", handle.Container, routingKey, lastStatus, lastBody, ctx.Err())
|
||||
}
|
||||
return fmt.Errorf("worker %s not ready for routing key %q: %w", handle.Container, routingKey, ctx.Err())
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func shouldWarnWorkerNotReady(status int, body string) bool {
|
||||
if status == http.StatusServiceUnavailable {
|
||||
normalized := strings.ToLower(strings.TrimSpace(body))
|
||||
if strings.Contains(normalized, "model mapping not ready") {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func isWorkerModelMappingUnavailable(status int, body string) bool {
|
||||
if status != http.StatusServiceUnavailable {
|
||||
return false
|
||||
}
|
||||
normalized := strings.ToLower(strings.TrimSpace(body))
|
||||
return strings.Contains(normalized, errLSModelMapDenied.Error())
|
||||
}
|
||||
|
||||
func (m *workerManager) syncWorkerState(handle *workerHandle, state *workerAccountState) error {
|
||||
if state == nil {
|
||||
return fmt.Errorf("ls worker state is nil")
|
||||
}
|
||||
body, err := json.Marshal(state)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal worker state: %w", err)
|
||||
}
|
||||
|
||||
sum := fmt.Sprintf("%x", sha256.Sum256(body))
|
||||
if handle.LastStateSHA == sum {
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), m.cfg.RequestTimeout)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, workerURL(handle, "/account/state", nil), bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("X-Worker-Token", handle.AuthToken)
|
||||
|
||||
resp, err := m.http.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sync worker state: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("sync worker state HTTP %d: %s", resp.StatusCode, truncate(string(respBody), 200))
|
||||
}
|
||||
handle.LastStateSHA = sum
|
||||
return nil
|
||||
}
|
||||
|
||||
func workerURL(handle *workerHandle, path string, values url.Values) string {
|
||||
base := url.URL{
|
||||
Scheme: "http",
|
||||
Host: handle.Address,
|
||||
Path: path,
|
||||
}
|
||||
if values != nil {
|
||||
base.RawQuery = values.Encode()
|
||||
}
|
||||
return base.String()
|
||||
}
|
||||
|
||||
func (m *workerManager) removeWorkerContainer(ctx context.Context, handle *workerHandle) error {
|
||||
if handle == nil || strings.TrimSpace(handle.ContainerID) == "" {
|
||||
return nil
|
||||
}
|
||||
timeout := 5
|
||||
_ = m.docker.ContainerStop(ctx, handle.ContainerID, container.StopOptions{Timeout: &timeout})
|
||||
if err := m.docker.ContainerRemove(ctx, handle.ContainerID, container.RemoveOptions{Force: true}); err != nil {
|
||||
return fmt.Errorf("remove ls worker container %s: %w", handle.ContainerID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *workerManager) ensureStateLocked(accountID string) *workerAccountState {
|
||||
state := m.state[accountID]
|
||||
if state == nil {
|
||||
state = &workerAccountState{}
|
||||
m.state[accountID] = state
|
||||
}
|
||||
return state
|
||||
}
|
||||
|
||||
func resolveWorkerProxy(proxyURL string) (string, *url.URL, error) {
|
||||
resolved := resolveLSProxy(proxyURL)
|
||||
normalized, parsed, err := proxyurl.Parse(resolved)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
if parsed == nil {
|
||||
return "", nil, nil
|
||||
}
|
||||
switch strings.ToLower(parsed.Scheme) {
|
||||
case "socks5", "socks5h":
|
||||
return normalized, parsed, nil
|
||||
default:
|
||||
return "", nil, fmt.Errorf("ls worker only supports socks5/socks5h proxies, got %s", parsed.Scheme)
|
||||
}
|
||||
}
|
||||
|
||||
func proxyHash(proxyURL string) string {
|
||||
if strings.TrimSpace(proxyURL) == "" {
|
||||
return "direct"
|
||||
}
|
||||
sum := sha256.Sum256([]byte(strings.TrimSpace(proxyURL)))
|
||||
return fmt.Sprintf("%x", sum[:])
|
||||
}
|
||||
|
||||
func buildWorkerKey(accountID, proxyHash string) string {
|
||||
return accountID + ":" + proxyHash
|
||||
}
|
||||
|
||||
func cloneInt32Ptr(v *int32) *int32 {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
cp := *v
|
||||
return &cp
|
||||
}
|
||||
|
||||
func cloneWorkerAccountState(state *workerAccountState) *workerAccountState {
|
||||
if state == nil {
|
||||
return nil
|
||||
}
|
||||
cp := *state
|
||||
cp.AvailableCredits = cloneInt32Ptr(state.AvailableCredits)
|
||||
cp.MinimumCreditAmount = cloneInt32Ptr(state.MinimumCreditAmount)
|
||||
if state.ExpiresAt != nil {
|
||||
ts := *state.ExpiresAt
|
||||
cp.ExpiresAt = &ts
|
||||
}
|
||||
return &cp
|
||||
}
|
||||
335
backend/internal/pkg/lspool/worker_manager_test.go
Normal file
335
backend/internal/pkg/lspool/worker_manager_test.go
Normal file
@ -0,0 +1,335 @@
|
||||
package lspool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/docker/docker/api/types/container"
|
||||
"github.com/docker/docker/api/types/filters"
|
||||
"github.com/docker/docker/api/types/network"
|
||||
ocispec "github.com/opencontainers/image-spec/specs-go/v1"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type fakeDockerClient struct {
|
||||
mu sync.Mutex
|
||||
|
||||
listResp []container.Summary
|
||||
listCalls int
|
||||
createCalls int
|
||||
startCalls int
|
||||
stopCalls int
|
||||
removeCalls int
|
||||
inspectCalls int
|
||||
removedIDs []string
|
||||
createdConfigs []*container.Config
|
||||
inspectResp container.InspectResponse
|
||||
}
|
||||
|
||||
func (f *fakeDockerClient) ContainerList(ctx context.Context, options container.ListOptions) ([]container.Summary, error) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.listCalls++
|
||||
return append([]container.Summary(nil), f.listResp...), nil
|
||||
}
|
||||
|
||||
func (f *fakeDockerClient) ContainerCreate(ctx context.Context, cfg *container.Config, hostConfig *container.HostConfig, networkingConfig *network.NetworkingConfig, platform *ocispec.Platform, containerName string) (container.CreateResponse, error) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.createCalls++
|
||||
f.createdConfigs = append(f.createdConfigs, cfg)
|
||||
return container.CreateResponse{ID: "worker-created"}, nil
|
||||
}
|
||||
|
||||
func (f *fakeDockerClient) ContainerStart(ctx context.Context, containerID string, options container.StartOptions) error {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.startCalls++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeDockerClient) ContainerInspect(ctx context.Context, containerID string) (container.InspectResponse, error) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.inspectCalls++
|
||||
return f.inspectResp, nil
|
||||
}
|
||||
|
||||
func (f *fakeDockerClient) ContainerStop(ctx context.Context, containerID string, options container.StopOptions) error {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.stopCalls++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeDockerClient) ContainerRemove(ctx context.Context, containerID string, options container.RemoveOptions) error {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.removeCalls++
|
||||
f.removedIDs = append(f.removedIDs, containerID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeDockerClient) Close() error { return nil }
|
||||
|
||||
func TestResolveWorkerProxyRejectsHTTP(t *testing.T) {
|
||||
_, _, err := resolveWorkerProxy("http://127.0.0.1:7890")
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "only supports socks5/socks5h")
|
||||
}
|
||||
|
||||
func TestProxyHashUsesNormalizedProxy(t *testing.T) {
|
||||
normalized, _, err := resolveWorkerProxy("socks5://user:pass@127.0.0.1:1080")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "socks5h://user:pass@127.0.0.1:1080", normalized)
|
||||
|
||||
hash1 := proxyHash(normalized)
|
||||
hash2 := proxyHash("socks5h://user:pass@127.0.0.1:1080")
|
||||
require.Equal(t, hash1, hash2)
|
||||
}
|
||||
|
||||
func TestWorkerManagerRequiresToken(t *testing.T) {
|
||||
fakeDocker := &fakeDockerClient{}
|
||||
manager, err := newWorkerManager(workerManagerConfig{
|
||||
Image: "worker:latest",
|
||||
Network: "sub2api-network",
|
||||
DockerSocket: "unix:///var/run/docker.sock",
|
||||
IdleTTL: time.Minute,
|
||||
MaxActive: 2,
|
||||
StartupTimeout: time.Second,
|
||||
RequestTimeout: time.Second,
|
||||
}, fakeDocker)
|
||||
require.NoError(t, err)
|
||||
defer manager.Close()
|
||||
|
||||
_, err = manager.GetOrCreate("9", "rk-1", "socks5h://user:pass@127.0.0.1:1080")
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "missing access token")
|
||||
}
|
||||
|
||||
func TestWorkerManagerReusesExistingHandleAndDedupesStateSync(t *testing.T) {
|
||||
var mu sync.Mutex
|
||||
var healthCalls int
|
||||
var readyCalls int
|
||||
var stateCalls int
|
||||
var stateBodies [][]byte
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/healthz":
|
||||
mu.Lock()
|
||||
healthCalls++
|
||||
mu.Unlock()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
case "/readyz":
|
||||
mu.Lock()
|
||||
readyCalls++
|
||||
mu.Unlock()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ready"))
|
||||
case "/account/state":
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
mu.Lock()
|
||||
stateCalls++
|
||||
stateBodies = append(stateBodies, body)
|
||||
mu.Unlock()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
fakeDocker := &fakeDockerClient{}
|
||||
manager, err := newWorkerManager(workerManagerConfig{
|
||||
Image: "worker:latest",
|
||||
Network: "sub2api-network",
|
||||
DockerSocket: "unix:///var/run/docker.sock",
|
||||
IdleTTL: time.Minute,
|
||||
MaxActive: 4,
|
||||
StartupTimeout: time.Second,
|
||||
RequestTimeout: time.Second,
|
||||
}, fakeDocker)
|
||||
require.NoError(t, err)
|
||||
defer manager.Close()
|
||||
|
||||
accountID := "9"
|
||||
proxyURL := "socks5h://user:pass@127.0.0.1:1080"
|
||||
hash := proxyHash(proxyURL)
|
||||
key := buildWorkerKey(accountID, hash)
|
||||
|
||||
manager.SetAccountToken(accountID, "ya29.test", "refresh", time.Now().Add(time.Hour))
|
||||
manager.mu.Lock()
|
||||
manager.workers[key] = &workerHandle{
|
||||
Key: key,
|
||||
AccountID: accountID,
|
||||
ProxyURL: proxyURL,
|
||||
ProxyHash: hash,
|
||||
ContainerID: "existing-worker",
|
||||
Container: "sub2api-ls-9-test",
|
||||
Address: strings.TrimPrefix(server.URL, "http://"),
|
||||
AuthToken: "worker-token",
|
||||
LastUsed: time.Now(),
|
||||
}
|
||||
manager.mu.Unlock()
|
||||
|
||||
inst1, err := manager.GetOrCreate(accountID, "rk-1", proxyURL)
|
||||
require.NoError(t, err)
|
||||
require.True(t, inst1.remote)
|
||||
require.Equal(t, replicaSlotIndex("rk-1", parseLSReplicaCount()), inst1.Replica)
|
||||
|
||||
inst2, err := manager.GetOrCreate(accountID, "rk-1", proxyURL)
|
||||
require.NoError(t, err)
|
||||
require.True(t, inst2.remote)
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
require.GreaterOrEqual(t, healthCalls, 2)
|
||||
require.GreaterOrEqual(t, readyCalls, 2)
|
||||
require.Equal(t, 1, stateCalls, "state sync should be skipped when the payload hash is unchanged")
|
||||
require.Len(t, stateBodies, 1)
|
||||
|
||||
var synced workerAccountState
|
||||
require.NoError(t, json.Unmarshal(stateBodies[0], &synced))
|
||||
require.True(t, synced.HasToken)
|
||||
require.Equal(t, "ya29.test", synced.AccessToken)
|
||||
}
|
||||
|
||||
func TestWorkerManagerMaxActiveStopsNewWorkerCreation(t *testing.T) {
|
||||
fakeDocker := &fakeDockerClient{}
|
||||
manager, err := newWorkerManager(workerManagerConfig{
|
||||
Image: "worker:latest",
|
||||
Network: "sub2api-network",
|
||||
DockerSocket: "unix:///var/run/docker.sock",
|
||||
IdleTTL: time.Minute,
|
||||
MaxActive: 1,
|
||||
StartupTimeout: time.Second,
|
||||
RequestTimeout: time.Second,
|
||||
}, fakeDocker)
|
||||
require.NoError(t, err)
|
||||
defer manager.Close()
|
||||
|
||||
manager.SetAccountToken("9", "ya29.test", "refresh", time.Now().Add(time.Hour))
|
||||
manager.mu.Lock()
|
||||
manager.workers["existing"] = &workerHandle{ContainerID: "existing", Container: "existing", LastUsed: time.Now()}
|
||||
manager.mu.Unlock()
|
||||
|
||||
_, err = manager.GetOrCreate("9", "rk-new", "socks5h://user:pass@127.0.0.1:1080")
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "limit reached")
|
||||
require.Equal(t, 0, fakeDocker.createCalls)
|
||||
}
|
||||
|
||||
func TestWorkerManagerReconcileRemovesManagedContainers(t *testing.T) {
|
||||
fakeDocker := &fakeDockerClient{
|
||||
listResp: []container.Summary{
|
||||
{
|
||||
ID: "old-worker-1",
|
||||
Names: []string{"/sub2api-ls-9-deadbeef"},
|
||||
},
|
||||
{
|
||||
ID: "old-worker-2",
|
||||
Names: []string{"/sub2api-ls-10-beadfeed"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := newWorkerManager(workerManagerConfig{
|
||||
Image: "worker:latest",
|
||||
Network: "sub2api-network",
|
||||
DockerSocket: "unix:///var/run/docker.sock",
|
||||
IdleTTL: time.Minute,
|
||||
MaxActive: 4,
|
||||
StartupTimeout: time.Second,
|
||||
RequestTimeout: time.Second,
|
||||
}, fakeDocker)
|
||||
require.NoError(t, err)
|
||||
defer manager.Close()
|
||||
|
||||
require.Equal(t, 1, fakeDocker.listCalls)
|
||||
require.ElementsMatch(t, []string{"old-worker-1", "old-worker-2"}, fakeDocker.removedIDs)
|
||||
}
|
||||
|
||||
func TestFakeDockerClientImplementsFilterAwareList(t *testing.T) {
|
||||
fakeDocker := &fakeDockerClient{}
|
||||
_, err := fakeDocker.ContainerList(context.Background(), container.ListOptions{Filters: filters.NewArgs()})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestShouldWarnWorkerNotReadySuppressesModelMappingPending(t *testing.T) {
|
||||
require.False(t, shouldWarnWorkerNotReady(http.StatusServiceUnavailable, "worker model mapping not ready for replica 0"))
|
||||
require.True(t, shouldWarnWorkerNotReady(http.StatusServiceUnavailable, "worker access token not configured"))
|
||||
require.True(t, shouldWarnWorkerNotReady(http.StatusBadGateway, "upstream failed"))
|
||||
}
|
||||
|
||||
func TestWorkerManagerWaitForWorkerReadyStopsOnModelMappingUnavailable(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "/readyz", r.URL.Path)
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
_, _ = w.Write([]byte(`model mapping unavailable for replica 0: oauth2: "unauthorized_client" "Unauthorized"`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
manager, err := newWorkerManager(workerManagerConfig{
|
||||
Image: "worker:latest",
|
||||
Network: "sub2api-network",
|
||||
DockerSocket: "unix:///var/run/docker.sock",
|
||||
IdleTTL: time.Minute,
|
||||
MaxActive: 1,
|
||||
StartupTimeout: time.Second,
|
||||
RequestTimeout: time.Second,
|
||||
}, &fakeDockerClient{})
|
||||
require.NoError(t, err)
|
||||
defer manager.Close()
|
||||
|
||||
handle := &workerHandle{
|
||||
Container: "sub2api-ls-test",
|
||||
Address: strings.TrimPrefix(server.URL, "http://"),
|
||||
AuthToken: "worker-token",
|
||||
}
|
||||
|
||||
err = manager.waitForWorkerReady(handle, "")
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, errLSModelMapDenied)
|
||||
}
|
||||
|
||||
func TestWorkerManagerWaitForWorkerReadyIncludesLastBodyOnTimeout(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "/readyz", r.URL.Path)
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
_, _ = w.Write([]byte("worker model mapping not ready for replica 0\n"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
manager, err := newWorkerManager(workerManagerConfig{
|
||||
Image: "worker:latest",
|
||||
Network: "sub2api-network",
|
||||
DockerSocket: "unix:///var/run/docker.sock",
|
||||
IdleTTL: time.Minute,
|
||||
MaxActive: 1,
|
||||
StartupTimeout: 100 * time.Millisecond,
|
||||
RequestTimeout: time.Second,
|
||||
}, &fakeDockerClient{})
|
||||
require.NoError(t, err)
|
||||
defer manager.Close()
|
||||
|
||||
handle := &workerHandle{
|
||||
Container: "sub2api-ls-test",
|
||||
Address: strings.TrimPrefix(server.URL, "http://"),
|
||||
AuthToken: "worker-token",
|
||||
}
|
||||
|
||||
err = manager.waitForWorkerReady(handle, "")
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), `last_status=503`)
|
||||
require.Contains(t, err.Error(), `last_body="worker model mapping not ready for replica 0`)
|
||||
}
|
||||
374
backend/internal/pkg/lspool/worker_server.go
Normal file
374
backend/internal/pkg/lspool/worker_server.go
Normal file
@ -0,0 +1,374 @@
|
||||
package lspool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type WorkerServerConfig struct {
|
||||
AccountID string
|
||||
AuthToken string
|
||||
ListenAddr string
|
||||
AppRoot string
|
||||
NetworkReadyFile string
|
||||
MaxIdleTime time.Duration
|
||||
HealthInterval time.Duration
|
||||
}
|
||||
|
||||
type WorkerServer struct {
|
||||
cfg WorkerServerConfig
|
||||
pool *Pool
|
||||
logger *slog.Logger
|
||||
|
||||
mu sync.RWMutex
|
||||
state workerAccountState
|
||||
}
|
||||
|
||||
func NewWorkerServer(cfg WorkerServerConfig) (*WorkerServer, error) {
|
||||
if strings.TrimSpace(cfg.AccountID) == "" {
|
||||
return nil, fmt.Errorf("worker account id is required")
|
||||
}
|
||||
if strings.TrimSpace(cfg.AuthToken) == "" {
|
||||
return nil, fmt.Errorf("worker auth token is required")
|
||||
}
|
||||
if strings.TrimSpace(cfg.ListenAddr) == "" {
|
||||
cfg.ListenAddr = fmt.Sprintf("0.0.0.0:%d", lsWorkerControlPort)
|
||||
}
|
||||
if strings.TrimSpace(cfg.AppRoot) == "" {
|
||||
cfg.AppRoot = "/app/ls"
|
||||
}
|
||||
if cfg.MaxIdleTime <= 0 {
|
||||
cfg.MaxIdleTime = 15 * time.Minute
|
||||
}
|
||||
if cfg.HealthInterval <= 0 {
|
||||
cfg.HealthInterval = 30 * time.Second
|
||||
}
|
||||
|
||||
poolCfg := DefaultConfig()
|
||||
poolCfg.AppRoot = cfg.AppRoot
|
||||
poolCfg.MaxIdleTime = cfg.MaxIdleTime
|
||||
poolCfg.HealthCheckInterval = cfg.HealthInterval
|
||||
|
||||
return &WorkerServer{
|
||||
cfg: cfg,
|
||||
pool: NewPool(poolCfg),
|
||||
logger: slog.Default().With("component", "lsworker"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func NewWorkerServerFromEnv() (*WorkerServer, error) {
|
||||
maxIdleTime := 15 * time.Minute
|
||||
if raw := strings.TrimSpace(os.Getenv("LSWORKER_POOL_MAX_IDLE_TIME")); raw != "" {
|
||||
if parsed, err := time.ParseDuration(raw); err == nil {
|
||||
maxIdleTime = parsed
|
||||
}
|
||||
}
|
||||
healthInterval := 30 * time.Second
|
||||
if raw := strings.TrimSpace(os.Getenv("LSWORKER_POOL_HEALTH_INTERVAL")); raw != "" {
|
||||
if parsed, err := time.ParseDuration(raw); err == nil {
|
||||
healthInterval = parsed
|
||||
}
|
||||
}
|
||||
|
||||
return NewWorkerServer(WorkerServerConfig{
|
||||
AccountID: strings.TrimSpace(os.Getenv("LSWORKER_ACCOUNT_ID")),
|
||||
AuthToken: strings.TrimSpace(os.Getenv("LSWORKER_AUTH_TOKEN")),
|
||||
ListenAddr: strings.TrimSpace(os.Getenv("LSWORKER_LISTEN_ADDR")),
|
||||
AppRoot: strings.TrimSpace(os.Getenv("ANTIGRAVITY_APP_ROOT")),
|
||||
NetworkReadyFile: strings.TrimSpace(os.Getenv("LSWORKER_NETWORK_READY_FILE")),
|
||||
MaxIdleTime: maxIdleTime,
|
||||
HealthInterval: healthInterval,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *WorkerServer) Close() {
|
||||
if s.pool != nil {
|
||||
s.pool.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *WorkerServer) Handler() http.Handler {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/healthz", s.handleHealthz)
|
||||
mux.HandleFunc("/readyz", s.handleReadyz)
|
||||
mux.HandleFunc("/account/state", s.handleAccountState)
|
||||
mux.HandleFunc("/rpc/unary", s.handleRPCUnary)
|
||||
mux.HandleFunc("/rpc/stream", s.handleRPCStream)
|
||||
return mux
|
||||
}
|
||||
|
||||
func (s *WorkerServer) handleHealthz(w http.ResponseWriter, r *http.Request) {
|
||||
if !s.authorize(w, r) {
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
}
|
||||
|
||||
func (s *WorkerServer) handleReadyz(w http.ResponseWriter, r *http.Request) {
|
||||
if !s.authorize(w, r) {
|
||||
return
|
||||
}
|
||||
routingKey := strings.TrimSpace(r.URL.Query().Get("routing_key"))
|
||||
inst, err := s.ensureReady(r.Context(), routingKey)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(fmt.Sprintf("ready replica=%d", inst.Replica)))
|
||||
}
|
||||
|
||||
func (s *WorkerServer) handleAccountState(w http.ResponseWriter, r *http.Request) {
|
||||
if !s.authorize(w, r) {
|
||||
return
|
||||
}
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
defer r.Body.Close()
|
||||
var payload workerAccountState
|
||||
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
|
||||
http.Error(w, "invalid account state payload", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.state = *cloneWorkerAccountState(&payload)
|
||||
s.mu.Unlock()
|
||||
s.applyState()
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
}
|
||||
|
||||
func (s *WorkerServer) handleRPCUnary(w http.ResponseWriter, r *http.Request) {
|
||||
if !s.authorize(w, r) {
|
||||
return
|
||||
}
|
||||
service, method, mode, routingKey, ok := parseRPCRequest(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
inst, err := s.ensureReady(r.Context(), routingKey)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, "read request body failed", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if len(body) == 0 {
|
||||
body = []byte("{}")
|
||||
}
|
||||
|
||||
var respBody []byte
|
||||
switch mode {
|
||||
case "json":
|
||||
var input any
|
||||
if err := json.Unmarshal(body, &input); err != nil {
|
||||
http.Error(w, "invalid json rpc body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
respBody, err = inst.CallUnaryJSON(r.Context(), service, method, input)
|
||||
case "proto":
|
||||
respBody, err = inst.CallRPC(r.Context(), service, method, body)
|
||||
default:
|
||||
http.Error(w, "unsupported rpc mode", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(respBody)
|
||||
}
|
||||
|
||||
func (s *WorkerServer) handleRPCStream(w http.ResponseWriter, r *http.Request) {
|
||||
if !s.authorize(w, r) {
|
||||
return
|
||||
}
|
||||
service, method, mode, routingKey, ok := parseRPCRequest(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
inst, err := s.ensureReady(r.Context(), routingKey)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, "read request body failed", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var resp *http.Response
|
||||
switch mode {
|
||||
case "json":
|
||||
var input any
|
||||
if len(body) == 0 {
|
||||
body = []byte("{}")
|
||||
}
|
||||
if err := json.Unmarshal(body, &input); err != nil {
|
||||
http.Error(w, "invalid json rpc body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
resp, err = inst.StreamRPCJSON(r.Context(), service, method, input)
|
||||
case "proto":
|
||||
resp, err = inst.StreamRPC(r.Context(), service, method, body)
|
||||
default:
|
||||
http.Error(w, "unsupported rpc mode", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
for key, values := range resp.Header {
|
||||
for _, value := range values {
|
||||
w.Header().Add(key, value)
|
||||
}
|
||||
}
|
||||
w.WriteHeader(resp.StatusCode)
|
||||
_, _ = io.Copy(w, resp.Body)
|
||||
}
|
||||
|
||||
func (s *WorkerServer) authorize(w http.ResponseWriter, r *http.Request) bool {
|
||||
if subtleHeaderEqual(r.Header.Get("X-Worker-Token"), s.cfg.AuthToken) {
|
||||
return true
|
||||
}
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
return false
|
||||
}
|
||||
|
||||
func subtleHeaderEqual(left, right string) bool {
|
||||
if left == "" || right == "" {
|
||||
return false
|
||||
}
|
||||
return left == right
|
||||
}
|
||||
|
||||
func parseRPCRequest(w http.ResponseWriter, r *http.Request) (service, method, mode, routingKey string, ok bool) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return "", "", "", "", false
|
||||
}
|
||||
query := r.URL.Query()
|
||||
service = strings.TrimSpace(query.Get("service"))
|
||||
method = strings.TrimSpace(query.Get("method"))
|
||||
mode = strings.ToLower(strings.TrimSpace(query.Get("mode")))
|
||||
routingKey = strings.TrimSpace(query.Get("routing_key"))
|
||||
if service == "" || method == "" {
|
||||
http.Error(w, "missing rpc target", http.StatusBadRequest)
|
||||
return "", "", "", "", false
|
||||
}
|
||||
if mode == "" {
|
||||
mode = "proto"
|
||||
}
|
||||
return service, method, mode, routingKey, true
|
||||
}
|
||||
|
||||
func (s *WorkerServer) ensureReady(ctx context.Context, routingKey string) (*Instance, error) {
|
||||
if path := strings.TrimSpace(s.cfg.NetworkReadyFile); path != "" {
|
||||
if _, err := os.Stat(path); err != nil {
|
||||
return nil, fmt.Errorf("worker network not ready: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
s.applyState()
|
||||
s.mu.RLock()
|
||||
state := cloneWorkerAccountState(&s.state)
|
||||
s.mu.RUnlock()
|
||||
if state == nil || !state.HasToken || strings.TrimSpace(state.AccessToken) == "" {
|
||||
return nil, fmt.Errorf("worker access token not configured")
|
||||
}
|
||||
|
||||
inst, err := s.pool.GetOrCreate(s.cfg.AccountID, routingKey, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if inst.HasModelMappingUnavailable() {
|
||||
return nil, fmt.Errorf("%w for replica %d: %s", errLSModelMapDenied, inst.Replica, inst.ModelMappingUnavailableReason())
|
||||
}
|
||||
if inst.HasModelMappingReady() {
|
||||
return inst, nil
|
||||
}
|
||||
|
||||
modelCtx, cancel := context.WithTimeout(ctx, lsModelConfigTimeout)
|
||||
defer cancel()
|
||||
_ = modelCtx
|
||||
if !RefreshModelMapping(inst) {
|
||||
if inst.HasModelMappingUnavailable() {
|
||||
return nil, fmt.Errorf("%w for replica %d: %s", errLSModelMapDenied, inst.Replica, inst.ModelMappingUnavailableReason())
|
||||
}
|
||||
return nil, fmt.Errorf("worker model mapping not ready for replica %d", inst.Replica)
|
||||
}
|
||||
return inst, nil
|
||||
}
|
||||
|
||||
func (s *WorkerServer) applyState() {
|
||||
s.mu.RLock()
|
||||
state := cloneWorkerAccountState(&s.state)
|
||||
s.mu.RUnlock()
|
||||
if state == nil {
|
||||
return
|
||||
}
|
||||
if state.HasToken {
|
||||
expiresAt := time.Time{}
|
||||
if state.ExpiresAt != nil {
|
||||
expiresAt = state.ExpiresAt.UTC()
|
||||
}
|
||||
s.pool.SetAccountToken(s.cfg.AccountID, state.AccessToken, state.RefreshToken, expiresAt)
|
||||
}
|
||||
if state.HasModelCredits {
|
||||
s.pool.SetAccountModelCredits(s.cfg.AccountID, state.UseAICredits, state.AvailableCredits, state.MinimumCreditAmount)
|
||||
}
|
||||
}
|
||||
|
||||
func workerHTTPServer(listenAddr string, handler http.Handler) *http.Server {
|
||||
return &http.Server{
|
||||
Addr: listenAddr,
|
||||
Handler: handler,
|
||||
ReadHeaderTimeout: 10 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
func workerExitCode(err error) int {
|
||||
if err == nil {
|
||||
return 0
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
func parseWorkerControlPort() int {
|
||||
raw := strings.TrimSpace(os.Getenv("LSWORKER_CONTROL_PORT"))
|
||||
if raw == "" {
|
||||
return lsWorkerControlPort
|
||||
}
|
||||
port, err := strconv.Atoi(raw)
|
||||
if err != nil || port < 1 {
|
||||
return lsWorkerControlPort
|
||||
}
|
||||
return port
|
||||
}
|
||||
@ -13,9 +13,12 @@ import (
|
||||
)
|
||||
|
||||
// allowedSchemes 代理协议白名单
|
||||
// 注意: https 代理已被移除。当前实现(Go dialer.go 和 Node proxy.js)
|
||||
// 对 https:// 代理仅做 TCP 连接后发明文 CONNECT,不建立外层 TLS,
|
||||
// 导致 Proxy-Authorization 凭据在首跳明文传输。
|
||||
// 若需 https 代理支持,须先在 dialer.go 和 proxy.js 中实现 TLS-to-proxy。
|
||||
var allowedSchemes = map[string]bool{
|
||||
"http": true,
|
||||
"https": true,
|
||||
"socks5": true,
|
||||
"socks5h": true,
|
||||
}
|
||||
@ -31,7 +34,7 @@ var allowedSchemes = map[string]bool{
|
||||
// - TrimSpace 后为空视为直连
|
||||
// - url.Parse 失败返回 error(不含原始 URL,防凭据泄露)
|
||||
// - Host 为空返回 error(用 Redacted() 脱敏)
|
||||
// - Scheme 必须为 http/https/socks5/socks5h
|
||||
// - Scheme 必须为 http/socks5/socks5h(https 不支持,因 CONNECT 明文传输)
|
||||
// - socks5:// 自动升级为 socks5h://(确保 DNS 由代理端解析,防止 DNS 泄漏)
|
||||
func Parse(raw string) (trimmed string, parsed *url.URL, err error) {
|
||||
trimmed = strings.TrimSpace(raw)
|
||||
@ -51,7 +54,10 @@ func Parse(raw string) (trimmed string, parsed *url.URL, err error) {
|
||||
|
||||
scheme := strings.ToLower(parsed.Scheme)
|
||||
if !allowedSchemes[scheme] {
|
||||
return "", nil, fmt.Errorf("unsupported proxy scheme %q (allowed: http, https, socks5, socks5h)", scheme)
|
||||
if scheme == "https" {
|
||||
return "", nil, fmt.Errorf("https proxy scheme is not supported: current implementation sends CONNECT in plaintext (use http:// or socks5:// instead)")
|
||||
}
|
||||
return "", nil, fmt.Errorf("unsupported proxy scheme %q (allowed: http, socks5, socks5h)", scheme)
|
||||
}
|
||||
|
||||
// 自动升级 socks5 → socks5h,确保 DNS 由代理端解析,防止 DNS 泄漏。
|
||||
|
||||
@ -47,13 +47,13 @@ func TestParse_有效HTTP代理(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse_有效HTTPS代理(t *testing.T) {
|
||||
_, parsed, err := Parse("https://proxy.example.com:443")
|
||||
if err != nil {
|
||||
t.Fatalf("有效 HTTPS 代理应成功: %v", err)
|
||||
func TestParse_HTTPS代理被拒绝(t *testing.T) {
|
||||
_, _, err := Parse("https://proxy.example.com:443")
|
||||
if err == nil {
|
||||
t.Fatal("https 代理应返回错误(当前实现不支持 TLS-to-proxy)")
|
||||
}
|
||||
if parsed.Scheme != "https" {
|
||||
t.Errorf("Scheme 不匹配: got %q", parsed.Scheme)
|
||||
if !strings.Contains(err.Error(), "https proxy scheme is not supported") {
|
||||
t.Errorf("错误信息应包含 'https proxy scheme is not supported': got %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
571
backend/internal/pkg/telemetry/telemetry.go
Normal file
571
backend/internal/pkg/telemetry/telemetry.go
Normal file
@ -0,0 +1,571 @@
|
||||
// Package telemetry simulates the real Claude Code CLI's OTEL telemetry events.
|
||||
//
|
||||
// Real CLI emits events to two channels:
|
||||
// 1. Anthropic event_logging/batch (first-party events)
|
||||
// 2. Datadog log intake (third-party observability)
|
||||
//
|
||||
// Ported from antigravity/node-tls-proxy/proxy.js — see that file for JS original.
|
||||
package telemetry
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"math"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
claude "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
)
|
||||
|
||||
// ─── Constants ───────────────────────────────────────────
|
||||
|
||||
const (
|
||||
ddAPIKey = "pubbbf48e6d78dae54bceaa4acf463299bf"
|
||||
fakeNodeVersion = "v24.3.0"
|
||||
buildTime = "2026-03-31T01:39:46Z"
|
||||
sessionMaxAge = time.Hour
|
||||
sessionCleanup = 5 * time.Minute
|
||||
telemetryTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
// ─── Virtual Host Identity ───────────────────────────────
|
||||
|
||||
var (
|
||||
mbpNames = []string{"alex", "sam", "chris", "max", "lee", "kai", "jamie", "taylor", "morgan", "casey", "drew", "avery", "riley", "blake", "jordan", "ryan", "parker", "quinn", "reese", "cameron"}
|
||||
mbpSuffix = []string{"-MBP", "-MacBook", "-MacBook-Pro", "-MacBook-Air", "s-MBP", "s-MacBook", "s-MacBook-Pro"}
|
||||
)
|
||||
|
||||
type hostIdentity struct {
|
||||
Hostname string
|
||||
Username string
|
||||
Terminal string
|
||||
Shell string
|
||||
MachineID string
|
||||
Arch string
|
||||
OSVersion string
|
||||
KernelRelease string
|
||||
ExecPath string
|
||||
RipgrepVersion string
|
||||
RipgrepPath string
|
||||
McpServerCount int
|
||||
McpFailCount int
|
||||
}
|
||||
|
||||
func hashField(seed, field string) []byte {
|
||||
h := sha256.Sum256([]byte(seed + ":" + field))
|
||||
return h[:]
|
||||
}
|
||||
|
||||
func generateHostIdentity(seed string) hostIdentity {
|
||||
hb := hashField(seed, "hostname")
|
||||
name := mbpNames[int(hb[0])%len(mbpNames)]
|
||||
sfx := mbpSuffix[int(hb[1])%len(mbpSuffix)]
|
||||
|
||||
termRoll := int(hashField(seed, "terminal")[0]) % 100
|
||||
var terminal string
|
||||
switch {
|
||||
case termRoll < 75:
|
||||
terminal = "xterm-256color"
|
||||
case termRoll < 88:
|
||||
terminal = "screen-256color"
|
||||
case termRoll < 96:
|
||||
terminal = "alacritty"
|
||||
default:
|
||||
terminal = "kitty"
|
||||
}
|
||||
|
||||
shellRoll := int(hashField(seed, "shell")[0]) % 100
|
||||
var shell string
|
||||
switch {
|
||||
case shellRoll < 65:
|
||||
shell = "/bin/zsh"
|
||||
case shellRoll < 82:
|
||||
shell = "/usr/local/bin/zsh"
|
||||
case shellRoll < 93:
|
||||
shell = "/bin/bash"
|
||||
default:
|
||||
shell = "/opt/homebrew/bin/fish"
|
||||
}
|
||||
|
||||
mid := hashField(seed, "machine-id")
|
||||
machineID := fmt.Sprintf("%s-%s-%s-%s-%s",
|
||||
strings.ToUpper(hex.EncodeToString(mid[0:4])),
|
||||
strings.ToUpper(hex.EncodeToString(mid[4:6])),
|
||||
strings.ToUpper(hex.EncodeToString(mid[6:8])),
|
||||
strings.ToUpper(hex.EncodeToString(mid[8:10])),
|
||||
strings.ToUpper(hex.EncodeToString(mid[10:16])),
|
||||
)
|
||||
|
||||
osb := hashField(seed, "os")
|
||||
major := 13 + int(osb[0])%3
|
||||
minor := int(osb[1]) % 8
|
||||
patch := int(osb[2]) % 5
|
||||
darwinMajor := major + 9 // macOS 13 = Darwin 22
|
||||
darwinMinor := int(osb[3]) % 7
|
||||
darwinPatch := int(osb[4]) % 3
|
||||
|
||||
archRoll := int(hashField(seed, "arch")[0]) % 100
|
||||
arch := "arm64"
|
||||
if archRoll >= 70 {
|
||||
arch = "x64"
|
||||
}
|
||||
|
||||
execRoll := int(hashField(seed, "exec")[0]) % 100
|
||||
var execPath string
|
||||
switch {
|
||||
case execRoll < 40:
|
||||
execPath = "/usr/local/bin/claude"
|
||||
case execRoll < 70:
|
||||
execPath = "/opt/homebrew/bin/claude"
|
||||
case execRoll < 90:
|
||||
execPath = fmt.Sprintf("/Users/%s/.npm-global/bin/claude", name)
|
||||
default:
|
||||
execPath = fmt.Sprintf("/Users/%s/.local/bin/claude", name)
|
||||
}
|
||||
|
||||
rgVersions := []string{"14.1.1", "14.1.0", "14.0.3", "14.0.2", "13.0.0", "14.1.2", "14.0.1"}
|
||||
rgPaths := []string{"/opt/homebrew/bin/rg", "/usr/local/bin/rg", "/Users/" + name + "/.cargo/bin/rg", "/usr/bin/rg"}
|
||||
rb := hashField(seed, "ripgrep")
|
||||
|
||||
return hostIdentity{
|
||||
Hostname: name + sfx,
|
||||
Username: name,
|
||||
Terminal: terminal,
|
||||
Shell: shell,
|
||||
MachineID: machineID,
|
||||
Arch: arch,
|
||||
OSVersion: fmt.Sprintf("%d.%d.%d", major, minor, patch),
|
||||
KernelRelease: fmt.Sprintf("%d.%d.%d", darwinMajor, darwinMinor, darwinPatch),
|
||||
ExecPath: execPath,
|
||||
RipgrepVersion: rgVersions[int(rb[0])%len(rgVersions)],
|
||||
RipgrepPath: rgPaths[int(rb[1])%len(rgPaths)],
|
||||
McpServerCount: int(rb[2])%5 + 1,
|
||||
McpFailCount: int(rb[3]) % 3,
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Session State ───────────────────────────────────────
|
||||
|
||||
type sessionState struct {
|
||||
SessionID string
|
||||
DeviceID string
|
||||
HostID hostIdentity
|
||||
StartTime time.Time
|
||||
RequestCount int64
|
||||
RipgrepReported bool
|
||||
}
|
||||
|
||||
var (
|
||||
sessions = make(map[string]*sessionState)
|
||||
sessionsMu sync.Mutex
|
||||
)
|
||||
|
||||
func init() {
|
||||
go func() {
|
||||
ticker := time.NewTicker(sessionCleanup)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
now := time.Now()
|
||||
sessionsMu.Lock()
|
||||
for k, s := range sessions {
|
||||
if now.Sub(s.StartTime) > sessionMaxAge {
|
||||
delete(sessions, k)
|
||||
}
|
||||
}
|
||||
sessionsMu.Unlock()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func generateDeviceID(accountSeed string) string {
|
||||
h := sha256.Sum256([]byte("device:" + accountSeed))
|
||||
return hex.EncodeToString(h[:])
|
||||
}
|
||||
|
||||
func getOrCreateSession(deviceID string) *sessionState {
|
||||
sessionsMu.Lock()
|
||||
defer sessionsMu.Unlock()
|
||||
|
||||
if s, ok := sessions[deviceID]; ok {
|
||||
return s
|
||||
}
|
||||
s := &sessionState{
|
||||
SessionID: generateUUID(),
|
||||
DeviceID: deviceID,
|
||||
HostID: generateHostIdentity(deviceID),
|
||||
StartTime: time.Now(),
|
||||
}
|
||||
sessions[deviceID] = s
|
||||
return s
|
||||
}
|
||||
|
||||
func generateUUID() string {
|
||||
b := make([]byte, 16)
|
||||
rand.Read(b)
|
||||
b[6] = (b[6] & 0x0f) | 0x40 // version 4
|
||||
b[8] = (b[8] & 0x3f) | 0x80 // variant
|
||||
return fmt.Sprintf("%x-%x-%x-%x-%x", b[0:4], b[4:6], b[6:8], b[8:10], b[10:16])
|
||||
}
|
||||
|
||||
// ─── Process Metrics Simulation ──────────────────────────
|
||||
|
||||
func buildProcessMetrics(uptime float64) string {
|
||||
baseRss := 180_000_000.0 + math.Min(uptime*50_000, 200_000_000)
|
||||
rss := int64(baseRss + rand.Float64()*80_000_000)
|
||||
heapTotal := int64(float64(rss)*0.6 + rand.Float64()*10_000_000)
|
||||
heapUsed := int64(float64(heapTotal)*0.5 + rand.Float64()*float64(heapTotal)*0.3)
|
||||
|
||||
metrics := map[string]any{
|
||||
"uptime": uptime,
|
||||
"rss": rss,
|
||||
"heapTotal": heapTotal,
|
||||
"heapUsed": heapUsed,
|
||||
"external": 14_000_000 + rand.Intn(2_000_000),
|
||||
"arrayBuffers": rand.Intn(200_000),
|
||||
"constrainedMemory": 51539607552,
|
||||
"cpuUsage": map[string]int64{
|
||||
"user": int64(uptime*10_000 + rand.Float64()*300_000),
|
||||
"system": int64(uptime*2_000 + rand.Float64()*80_000),
|
||||
},
|
||||
"cpuPercent": rand.Float64() * 200,
|
||||
}
|
||||
data, _ := json.Marshal(metrics)
|
||||
return base64.StdEncoding.EncodeToString(data)
|
||||
}
|
||||
|
||||
// ─── Env Block ───────────────────────────────────────────
|
||||
|
||||
func buildEnvBlock(hostID hostIdentity) map[string]any {
|
||||
return map[string]any{
|
||||
"platform": "darwin",
|
||||
"node_version": fakeNodeVersion,
|
||||
"terminal": hostID.Terminal,
|
||||
"package_managers": "npm,pnpm",
|
||||
"runtimes": "deno,node",
|
||||
"is_running_with_bun": true,
|
||||
"is_ci": false,
|
||||
"is_claubbit": false,
|
||||
"is_github_action": false,
|
||||
"is_claude_code_action": false,
|
||||
"is_claude_ai_auth": false,
|
||||
"version": claude.DefaultCLIVersion,
|
||||
"arch": hostID.Arch,
|
||||
"is_claude_code_remote": false,
|
||||
"deployment_environment": "unknown-darwin",
|
||||
"is_conductor": false,
|
||||
"version_base": claude.DefaultCLIVersion,
|
||||
"build_time": buildTime,
|
||||
"is_local_agent_mode": false,
|
||||
"vcs": "git",
|
||||
"platform_raw": "darwin",
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Event Building ──────────────────────────────────────
|
||||
|
||||
type eventWrapper struct {
|
||||
EventType string `json:"event_type"`
|
||||
EventData map[string]any `json:"event_data"`
|
||||
}
|
||||
|
||||
func buildEvent(eventName string, session *sessionState, model, betas string, extraData map[string]any, tsOverride string) eventWrapper {
|
||||
uptime := time.Since(session.StartTime).Seconds()
|
||||
pm := buildProcessMetrics(uptime)
|
||||
|
||||
ts := tsOverride
|
||||
if ts == "" {
|
||||
ts = time.Now().UTC().Format(time.RFC3339Nano)
|
||||
}
|
||||
|
||||
if model == "" {
|
||||
model = "claude-sonnet-4-6"
|
||||
}
|
||||
if betas == "" {
|
||||
betas = "claude-code-20250219,interleaved-thinking-2025-05-14"
|
||||
}
|
||||
|
||||
data := map[string]any{
|
||||
"event_name": eventName,
|
||||
"client_timestamp": ts,
|
||||
"model": model,
|
||||
"session_id": session.SessionID,
|
||||
"user_type": "external",
|
||||
"betas": betas,
|
||||
"env": buildEnvBlock(session.HostID),
|
||||
"entrypoint": "cli",
|
||||
"is_interactive": true,
|
||||
"client_type": "cli",
|
||||
"process": pm,
|
||||
"event_id": generateUUID(),
|
||||
"device_id": session.DeviceID,
|
||||
}
|
||||
|
||||
for k, v := range extraData {
|
||||
data[k] = v
|
||||
}
|
||||
|
||||
return eventWrapper{
|
||||
EventType: "ClaudeCodeInternalEvent",
|
||||
EventData: data,
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Send Functions ──────────────────────────────────────
|
||||
|
||||
var httpClient = &http.Client{Timeout: telemetryTimeout}
|
||||
|
||||
func sendTelemetryEvents(events []eventWrapper, session *sessionState, authToken string) {
|
||||
if len(events) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
payload := map[string]any{"events": events}
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", "https://api.anthropic.com/api/event_logging/batch", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
req.Header.Set("Accept", "application/json, text/plain, */*")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", "claude-code/"+claude.DefaultCLIVersion)
|
||||
req.Header.Set("x-service-name", "claude-code")
|
||||
if authToken != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+authToken)
|
||||
}
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
slog.Debug("telemetry_error", "error", err.Error())
|
||||
return
|
||||
}
|
||||
resp.Body.Close()
|
||||
slog.Debug("telemetry_sent", "status", resp.StatusCode, "events", len(events))
|
||||
}
|
||||
|
||||
func sendDatadogLog(eventName string, session *sessionState, model string) {
|
||||
hostID := session.HostID
|
||||
uptime := time.Since(session.StartTime).Seconds()
|
||||
|
||||
if model == "" {
|
||||
model = "claude-sonnet-4-6"
|
||||
}
|
||||
|
||||
baseRss := 180_000_000.0 + math.Min(uptime*50_000, 200_000_000)
|
||||
rss := int64(baseRss + rand.Float64()*80_000_000)
|
||||
heapTotal := int64(float64(rss)*0.6 + rand.Float64()*10_000_000)
|
||||
heapUsed := int64(float64(heapTotal)*0.5 + rand.Float64()*float64(heapTotal)*0.3)
|
||||
|
||||
pm := map[string]any{
|
||||
"uptime": uptime,
|
||||
"rss": rss,
|
||||
"heapTotal": heapTotal,
|
||||
"heapUsed": heapUsed,
|
||||
"external": 14_000_000 + rand.Intn(2_000_000),
|
||||
"arrayBuffers": rand.Intn(10_000),
|
||||
"constrainedMemory": 0,
|
||||
"cpuUsage": map[string]int64{
|
||||
"user": int64(uptime*10_000 + rand.Float64()*300_000),
|
||||
"system": int64(uptime*2_000 + rand.Float64()*80_000),
|
||||
},
|
||||
}
|
||||
|
||||
entry := map[string]any{
|
||||
"ddsource": "nodejs",
|
||||
"ddtags": fmt.Sprintf("event:%s,arch:%s,client_type:cli,model:%s,platform:darwin,user_type:external,version:%s,version_base:%s", eventName, hostID.Arch, model, claude.DefaultCLIVersion, claude.DefaultCLIVersion),
|
||||
"message": eventName,
|
||||
"service": "claude-code",
|
||||
"hostname": "claude-code",
|
||||
"env": "external",
|
||||
"model": model,
|
||||
"session_id": session.SessionID,
|
||||
"user_type": "external",
|
||||
"entrypoint": "cli",
|
||||
"is_interactive": "true",
|
||||
"client_type": "cli",
|
||||
"process_metrics": pm,
|
||||
"platform": "darwin",
|
||||
"platform_raw": "darwin",
|
||||
"arch": hostID.Arch,
|
||||
"node_version": fakeNodeVersion,
|
||||
"version": claude.DefaultCLIVersion,
|
||||
"version_base": claude.DefaultCLIVersion,
|
||||
"build_time": buildTime,
|
||||
"deployment_environment": "unknown-darwin",
|
||||
"vcs": "git",
|
||||
}
|
||||
|
||||
body, err := json.Marshal([]any{entry})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", "https://http-intake.logs.us5.datadoghq.com/api/v2/logs", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
req.Header.Set("Accept", "application/json, text/plain, */*")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", "axios/1.13.6")
|
||||
req.Header.Set("DD-API-KEY", ddAPIKey)
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
resp.Body.Close()
|
||||
}
|
||||
|
||||
// ─── Public API ──────────────────────────────────────────
|
||||
|
||||
// EmitPreRequest fires pre-request telemetry events for a /v1/messages request.
|
||||
// accountSeed should be a stable identifier for the account (e.g. account ID or OAuth token suffix).
|
||||
// authHeader is the Authorization header value (used for device ID derivation).
|
||||
// authToken is the raw OAuth token (without "Bearer " prefix) for 1P auth.
|
||||
// model is the model name from the request body (e.g. "claude-sonnet-4-6").
|
||||
// betaHeader is the anthropic-beta header value.
|
||||
func EmitPreRequest(accountSeed, authHeader, authToken, model, betaHeader string) {
|
||||
authSuffix := authHeader
|
||||
if len(authSuffix) > 16 {
|
||||
authSuffix = authSuffix[len(authSuffix)-16:]
|
||||
}
|
||||
deviceID := generateDeviceID(accountSeed + ":" + authSuffix)
|
||||
session := getOrCreateSession(deviceID)
|
||||
session.RequestCount++
|
||||
|
||||
if model == "" {
|
||||
model = "claude-sonnet-4-6"
|
||||
}
|
||||
betas := betaHeader
|
||||
if betas == "" {
|
||||
betas = claude.DefaultBetaHeader
|
||||
}
|
||||
|
||||
// First request: full startup sequence
|
||||
if session.RequestCount == 1 {
|
||||
hostID := session.HostID
|
||||
baseTime := time.Now()
|
||||
ts := func(offsetMs int) string {
|
||||
return baseTime.Add(time.Duration(offsetMs) * time.Millisecond).UTC().Format(time.RFC3339Nano)
|
||||
}
|
||||
|
||||
batch1 := []eventWrapper{
|
||||
buildEvent("tengu_started", session, model, betas, nil, ts(0)),
|
||||
buildEvent("tengu_init", session, model, betas, nil, ts(80+rand.Intn(120))),
|
||||
buildEvent("tengu_ripgrep_availability", session, model, betas, map[string]any{
|
||||
"ripgrep_available": true,
|
||||
"ripgrep_version": hostID.RipgrepVersion,
|
||||
"ripgrep_path": hostID.RipgrepPath,
|
||||
}, ts(200+rand.Intn(150))),
|
||||
}
|
||||
|
||||
// MCP connection events
|
||||
mcpOffset := 400
|
||||
mcpSuccessCount := hostID.McpServerCount - hostID.McpFailCount
|
||||
for i := 0; i < hostID.McpFailCount; i++ {
|
||||
mcpOffset += 100 + rand.Intn(300)
|
||||
batch1 = append(batch1, buildEvent("tengu_mcp_server_connection_failed", session, model, betas, nil, ts(mcpOffset)))
|
||||
}
|
||||
for i := 0; i < mcpSuccessCount; i++ {
|
||||
mcpOffset += 200 + rand.Intn(500)
|
||||
batch1 = append(batch1, buildEvent("tengu_mcp_server_connection_succeeded", session, model, betas, nil, ts(mcpOffset)))
|
||||
}
|
||||
|
||||
session.RipgrepReported = true
|
||||
go sendTelemetryEvents(batch1, session, authToken)
|
||||
go sendDatadogLog("tengu_started", session, model)
|
||||
go sendDatadogLog("tengu_init", session, model)
|
||||
|
||||
// Delayed batch (~25-35s later, matches real CLI timing)
|
||||
go func() {
|
||||
time.Sleep(time.Duration(25000+rand.Intn(10000)) * time.Millisecond)
|
||||
sendTelemetryEvents([]eventWrapper{
|
||||
buildEvent("tengu_session_init", session, model, betas, nil, ""),
|
||||
buildEvent("tengu_context_loaded", session, model, betas, nil, ""),
|
||||
}, session, authToken)
|
||||
}()
|
||||
}
|
||||
|
||||
// Every request: tengu_api_query (real CLI event name)
|
||||
go sendTelemetryEvents([]eventWrapper{
|
||||
buildEvent("tengu_api_query", session, model, betas, nil, ""),
|
||||
}, session, authToken)
|
||||
}
|
||||
|
||||
// EmitPostRequest fires post-request telemetry events after upstream response.
|
||||
func EmitPostRequest(accountSeed, authHeader, authToken, model, betaHeader string, statusCode int) {
|
||||
authSuffix := authHeader
|
||||
if len(authSuffix) > 16 {
|
||||
authSuffix = authSuffix[len(authSuffix)-16:]
|
||||
}
|
||||
deviceID := generateDeviceID(accountSeed + ":" + authSuffix)
|
||||
session := getOrCreateSession(deviceID)
|
||||
|
||||
if model == "" {
|
||||
model = "claude-sonnet-4-6"
|
||||
}
|
||||
betas := betaHeader
|
||||
if betas == "" {
|
||||
betas = claude.DefaultBetaHeader
|
||||
}
|
||||
|
||||
// Real CLI uses tengu_api_success on success, tengu_api_error on failure
|
||||
if statusCode < 400 {
|
||||
events := []eventWrapper{
|
||||
buildEvent("tengu_api_success", session, model, betas, nil, ""),
|
||||
}
|
||||
go sendTelemetryEvents(events, session, authToken)
|
||||
go sendDatadogLog("tengu_api_success", session, model)
|
||||
} else {
|
||||
var errMsg string
|
||||
switch {
|
||||
case statusCode == 429:
|
||||
errMsg = "rate_limit_exceeded"
|
||||
case statusCode == 529:
|
||||
errMsg = "overloaded"
|
||||
case statusCode >= 500:
|
||||
errMsg = "server_error"
|
||||
default:
|
||||
errMsg = "client_error"
|
||||
}
|
||||
errEvent := buildEvent("tengu_api_error", session, model, betas, map[string]any{
|
||||
"error_type": "TelemetrySafeError",
|
||||
"error_code": statusCode,
|
||||
"error_message": errMsg,
|
||||
}, "")
|
||||
go sendTelemetryEvents([]eventWrapper{errEvent}, session, authToken)
|
||||
go sendDatadogLog("tengu_api_error", session, model)
|
||||
}
|
||||
|
||||
// Random tool_use event (30% probability, 2-7s delay)
|
||||
if rand.Float64() < 0.3 {
|
||||
go func() {
|
||||
time.Sleep(time.Duration(2000+rand.Intn(5000)) * time.Millisecond)
|
||||
sendTelemetryEvents([]eventWrapper{
|
||||
buildEvent("tengu_tool_use_success", session, model, betas, nil, ""),
|
||||
}, session, authToken)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// Jitter returns a random delay to inject before forwarding a request.
|
||||
// 80% fast (80-300ms exponential), 20% slow (400-1200ms uniform).
|
||||
func Jitter() time.Duration {
|
||||
if rand.Float64() < 0.80 {
|
||||
ms := 80.0 + (-math.Log(rand.Float64()) * 90.0)
|
||||
return time.Duration(ms) * time.Millisecond
|
||||
}
|
||||
ms := 400.0 + rand.Float64()*800.0
|
||||
return time.Duration(ms) * time.Millisecond
|
||||
}
|
||||
@ -141,7 +141,7 @@ func NewSOCKS5ProxyDialer(profile *Profile, proxyURL *url.URL) *SOCKS5ProxyDiale
|
||||
// DialTLSContext establishes a TLS connection through SOCKS5 proxy with the configured fingerprint.
|
||||
// Flow: SOCKS5 CONNECT to target -> TLS handshake with utls on the tunnel
|
||||
func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
slog.Debug("tls_fingerprint_socks5_connecting", "proxy", d.proxyURL.Host, "target", addr)
|
||||
slog.Info("tls_fingerprint_socks5_connecting", "proxy", d.proxyURL.Host, "target", addr)
|
||||
|
||||
// Step 1: Create SOCKS5 dialer
|
||||
var auth *proxy.Auth
|
||||
@ -160,20 +160,24 @@ func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr st
|
||||
proxyAddr = net.JoinHostPort(d.proxyURL.Hostname(), "1080") // Default SOCKS5 port
|
||||
}
|
||||
|
||||
socksDialer, err := proxy.SOCKS5("tcp", proxyAddr, auth, proxy.Direct)
|
||||
// Use a TCP-only forward dialer (no DNS resolution) so the SOCKS5 protocol
|
||||
// sends the target hostname to the proxy for remote DNS resolution (socks5h semantics).
|
||||
// proxy.Direct would attempt local DNS first, which fails inside Docker.
|
||||
tcpDialer := &net.Dialer{}
|
||||
socksDialer, err := proxy.SOCKS5("tcp", proxyAddr, auth, tcpDialer)
|
||||
if err != nil {
|
||||
slog.Debug("tls_fingerprint_socks5_dialer_failed", "error", err)
|
||||
slog.Info("tls_fingerprint_socks5_dialer_failed", "error", err)
|
||||
return nil, fmt.Errorf("create SOCKS5 dialer: %w", err)
|
||||
}
|
||||
|
||||
// Step 2: Establish SOCKS5 tunnel to target
|
||||
slog.Debug("tls_fingerprint_socks5_establishing_tunnel", "target", addr)
|
||||
conn, err := socksDialer.Dial("tcp", addr)
|
||||
slog.Info("tls_fingerprint_socks5_establishing_tunnel", "target", addr)
|
||||
conn, err := socksDialer.(proxy.ContextDialer).DialContext(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
slog.Debug("tls_fingerprint_socks5_connect_failed", "error", err)
|
||||
slog.Info("tls_fingerprint_socks5_connect_failed", "error", err)
|
||||
return nil, fmt.Errorf("SOCKS5 connect: %w", err)
|
||||
}
|
||||
slog.Debug("tls_fingerprint_socks5_tunnel_established")
|
||||
slog.Info("tls_fingerprint_socks5_tunnel_established", "target", addr)
|
||||
|
||||
// Step 3: Perform TLS handshake on the tunnel with utls fingerprint
|
||||
return performTLSHandshake(ctx, conn, d.profile, addr)
|
||||
|
||||
@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
@ -12,6 +13,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/logredact"
|
||||
|
||||
@ -267,18 +269,42 @@ func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
|
||||
}
|
||||
|
||||
func createReqClient(proxyURL string) (*req.Client, error) {
|
||||
// 禁用 CookieJar,确保每次授权都是干净的会话
|
||||
// Use Node.js 24.x TLS fingerprint (same as API requests) instead of Chrome
|
||||
// to ensure TLS fingerprint consistency across the entire token lifecycle.
|
||||
// Previously used ImpersonateChrome() which created a JA3 mismatch between
|
||||
// OAuth token exchange/refresh and API calls.
|
||||
profile := &tlsfingerprint.Profile{
|
||||
Name: "oauth-nodejs24",
|
||||
EnableGREASE: true,
|
||||
}
|
||||
|
||||
client := req.C().
|
||||
SetTimeout(60 * time.Second).
|
||||
ImpersonateChrome().
|
||||
SetCookieJar(nil) // 禁用 CookieJar
|
||||
SetCookieJar(nil). // 禁用 CookieJar,确保每次授权都是干净的会话
|
||||
EnableForceHTTP1() // 强制 HTTP/1.1,避免 H2 升级与自定义 TLS dialer 冲突
|
||||
|
||||
trimmed, _, err := proxyurl.Parse(proxyURL)
|
||||
trimmed, parsedProxy, err := proxyurl.Parse(proxyURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if trimmed != "" {
|
||||
client.SetProxyURL(trimmed)
|
||||
|
||||
if trimmed != "" && parsedProxy != nil {
|
||||
scheme := strings.ToLower(parsedProxy.Scheme)
|
||||
slog.Info("oauth_create_client", "proxy_scheme", scheme, "proxy_host", parsedProxy.Hostname())
|
||||
switch scheme {
|
||||
case "socks5", "socks5h":
|
||||
socks5Dialer := tlsfingerprint.NewSOCKS5ProxyDialer(profile, parsedProxy)
|
||||
client.SetDialTLS(socks5Dialer.DialTLSContext)
|
||||
case "http", "https":
|
||||
httpDialer := tlsfingerprint.NewHTTPProxyDialer(profile, parsedProxy)
|
||||
client.SetDialTLS(httpDialer.DialTLSContext)
|
||||
default:
|
||||
client.SetProxyURL(trimmed)
|
||||
}
|
||||
} else {
|
||||
slog.Info("oauth_create_client", "proxy_scheme", "none", "raw_proxy_url", proxyURL)
|
||||
dialer := tlsfingerprint.NewDialer(profile, nil)
|
||||
client.SetDialTLS(dialer.DialTLSContext)
|
||||
}
|
||||
|
||||
return client, nil
|
||||
|
||||
@ -13,6 +13,8 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/lspool"
|
||||
"time"
|
||||
|
||||
"github.com/andybalholm/brotli"
|
||||
@ -22,6 +24,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/logredact"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
)
|
||||
|
||||
@ -43,7 +46,7 @@ const (
|
||||
defaultIdleConnTimeout = 90 * time.Second
|
||||
// defaultResponseHeaderTimeout: 默认等待响应头超时时间(5分钟)
|
||||
// LLM 请求可能排队较久,需要较长超时
|
||||
defaultResponseHeaderTimeout = 300 * time.Second
|
||||
defaultResponseHeaderTimeout = 600 * time.Second
|
||||
// defaultMaxUpstreamClients: 默认最大客户端缓存数量
|
||||
// 超出后会淘汰最久未使用的客户端
|
||||
defaultMaxUpstreamClients = 5000
|
||||
@ -99,16 +102,34 @@ type httpUpstreamService struct {
|
||||
// NewHTTPUpstream 创建通用 HTTP 上游服务
|
||||
// 使用配置中的连接池参数构建 Transport
|
||||
//
|
||||
// 当环境变量 ANTIGRAVITY_LS_MODE=true 时,自动包装 LS 池拦截层:
|
||||
// - 仅对已知兼容的 LS 请求形态启用转发
|
||||
// - 对普通 streamGenerateContent 请求保留原有直连路径,避免误送到不兼容的 LS RPC
|
||||
//
|
||||
// 参数:
|
||||
// - cfg: 全局配置,包含连接池参数和隔离策略
|
||||
//
|
||||
// 返回:
|
||||
// - service.HTTPUpstream 接口实现
|
||||
func NewHTTPUpstream(cfg *config.Config) service.HTTPUpstream {
|
||||
return &httpUpstreamService{
|
||||
base := &httpUpstreamService{
|
||||
cfg: cfg,
|
||||
clients: make(map[string]*upstreamClientEntry),
|
||||
}
|
||||
|
||||
// LS 池模式: 包装一层拦截, streamGenerateContent 走 LS
|
||||
if lspool.IsLSModeEnabled() {
|
||||
pool := lspool.GlobalPool(cfg)
|
||||
if pool != nil {
|
||||
slog.Info("LS pool mode enabled — streamGenerateContent will route through Language Server",
|
||||
"component", "http_upstream")
|
||||
return lspool.NewLSPoolUpstream(pool, base)
|
||||
}
|
||||
slog.Warn("LS pool mode enabled but pool is nil — falling back to direct mode",
|
||||
"component", "http_upstream")
|
||||
}
|
||||
|
||||
return base
|
||||
}
|
||||
|
||||
// Do 执行 HTTP 请求
|
||||
@ -128,6 +149,14 @@ func NewHTTPUpstream(cfg *config.Config) service.HTTPUpstream {
|
||||
// - 调用方必须关闭 resp.Body,否则会导致 inFlight 计数泄漏
|
||||
// - inFlight > 0 的客户端不会被淘汰,确保活跃请求不被中断
|
||||
func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
|
||||
// TLS 指纹路由:对匹配主机使用 Go 原生 utls 指纹
|
||||
// 使用 utls 模拟 Claude CLI 的 JA3/JA4 指纹,支持直连和代理
|
||||
if s.isTLSFingerprintRoutingEnabled() && req != nil && req.URL != nil && req.URL.Scheme == "https" {
|
||||
if s.shouldRouteWithTLSFingerprint(req) {
|
||||
return s.doWithTLSFingerprint(req, proxyURL, accountID, accountConcurrency)
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.validateRequestHost(req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -175,7 +204,7 @@ func (s *httpUpstreamService) DoWithTLS(req *http.Request, proxyURL string, acco
|
||||
}
|
||||
proxyInfo := "direct"
|
||||
if proxyURL != "" {
|
||||
proxyInfo = proxyURL
|
||||
proxyInfo = logredact.RedactProxyURL(proxyURL)
|
||||
}
|
||||
slog.Debug("tls_fingerprint_enabled", "account_id", accountID, "target", targetHost, "proxy", proxyInfo, "profile", profile.Name)
|
||||
|
||||
@ -272,7 +301,7 @@ func (s *httpUpstreamService) getClientEntryWithTLS(proxyURL string, accountID i
|
||||
}
|
||||
|
||||
// 创建带 TLS 指纹的 Transport
|
||||
slog.Debug("tls_fingerprint_creating_new_client", "account_id", accountID, "cache_key", cacheKey, "proxy", proxyKey)
|
||||
slog.Debug("tls_fingerprint_creating_new_client", "account_id", accountID, "cache_key", cacheKey, "proxy", logredact.RedactProxyURL(proxyKey))
|
||||
settings := s.resolvePoolSettings(isolation, accountConcurrency)
|
||||
transport, err := buildUpstreamTransportWithTLSFingerprint(settings, parsedProxy, profile)
|
||||
if err != nil {
|
||||
|
||||
85
backend/internal/repository/http_upstream_antigravity.go
Normal file
85
backend/internal/repository/http_upstream_antigravity.go
Normal file
@ -0,0 +1,85 @@
|
||||
package repository
|
||||
|
||||
// ==============================================================
|
||||
// antigravity — Go 原生 TLS 指纹扩展
|
||||
//
|
||||
// 此文件包含 Antigravity fork 新增的 TLS 指纹代理功能,
|
||||
// 与 upstream 代码完全隔离,便于 upstream 更新时的合并维护。
|
||||
//
|
||||
// 上游文件 http_upstream.go 中的钩子调用点:
|
||||
// Do() — 匹配主机时路由到 doWithTLSFingerprint
|
||||
// DoWithTLS() — profile==nil 时回退到 Do(),触发同样的路由
|
||||
//
|
||||
// 替代原先的 Node.js TLS 代理(node-tls-proxy),
|
||||
// 直接使用 Go utls 库模拟 Claude CLI 的 TLS 指纹。
|
||||
// ==============================================================
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/logredact"
|
||||
)
|
||||
|
||||
// isTLSFingerprintRoutingEnabled 检查 TLS 指纹路由是否启用
|
||||
// 复用 NodeTLSProxy.Enabled 配置项,保持配置兼容
|
||||
func (s *httpUpstreamService) isTLSFingerprintRoutingEnabled() bool {
|
||||
if s.cfg == nil {
|
||||
return false
|
||||
}
|
||||
return s.cfg.Gateway.NodeTLSProxy.Enabled
|
||||
}
|
||||
|
||||
// shouldRouteWithTLSFingerprint 判断请求是否应该使用 TLS 指纹
|
||||
// 仅拦截目标主机在 proxy_hosts 白名单中的 HTTPS 请求,
|
||||
// 白名单为空时默认只代理 api.anthropic.com。
|
||||
func (s *httpUpstreamService) shouldRouteWithTLSFingerprint(req *http.Request) bool {
|
||||
if req == nil || req.URL == nil || req.URL.Scheme != "https" {
|
||||
return false
|
||||
}
|
||||
reqHost := req.URL.Hostname()
|
||||
if reqHost == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
hosts := s.cfg.Gateway.NodeTLSProxy.ProxyHosts
|
||||
if len(hosts) == 0 {
|
||||
return reqHost == "api.anthropic.com"
|
||||
}
|
||||
for _, h := range hosts {
|
||||
if reqHost == h {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// defaultTLSProfile 返回模拟 Claude CLI (Node.js 24.x) 的默认 TLS 指纹配置
|
||||
// 所有 slice 字段留空 → dialer.go 自动使用内置的 Node.js 24.x 默认值
|
||||
// ALPN 仅声明 http/1.1,与真实 CLI 行为一致(undici allowH2=false)
|
||||
func defaultTLSProfile() *tlsfingerprint.Profile {
|
||||
return &tlsfingerprint.Profile{
|
||||
Name: "claude_cli_builtin",
|
||||
EnableGREASE: true,
|
||||
}
|
||||
}
|
||||
|
||||
// doWithTLSFingerprint 使用 Go 原生 utls TLS 指纹发送请求
|
||||
// 直接通过 DoWithTLS 路径,利用已有的 utls dialer 基础设施:
|
||||
// - 直连:Dialer (TCP → utls handshake)
|
||||
// - HTTP 代理:HTTPProxyDialer (CONNECT 隧道 → utls handshake)
|
||||
// - SOCKS5 代理:SOCKS5ProxyDialer (SOCKS5 隧道 → utls handshake)
|
||||
func (s *httpUpstreamService) doWithTLSFingerprint(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
|
||||
proxyInfo := "direct"
|
||||
if proxyURL != "" {
|
||||
proxyInfo = logredact.RedactProxyURL(proxyURL)
|
||||
}
|
||||
slog.Debug("tls_fingerprint_routing",
|
||||
"account_id", accountID,
|
||||
"target", req.URL.Host,
|
||||
"proxy", proxyInfo,
|
||||
)
|
||||
|
||||
return s.DoWithTLS(req, proxyURL, accountID, accountConcurrency, defaultTLSProfile())
|
||||
}
|
||||
@ -53,8 +53,9 @@ const migrationsLockRetryInterval = 500 * time.Millisecond
|
||||
const nonTransactionalMigrationSuffix = "_notx.sql"
|
||||
|
||||
type migrationChecksumCompatibilityRule struct {
|
||||
fileChecksum string
|
||||
acceptedDBChecksum map[string]struct{}
|
||||
fileChecksum string
|
||||
acceptedFileChecksums map[string]struct{}
|
||||
acceptedDBChecksum map[string]struct{}
|
||||
}
|
||||
|
||||
// migrationChecksumCompatibilityRules 仅用于兼容历史上误修改过的迁移文件 checksum。
|
||||
@ -73,6 +74,15 @@ var migrationChecksumCompatibilityRules = map[string]migrationChecksumCompatibil
|
||||
"222b4a09c797c22e5922b6b172327c824f5463aaa8760e4f621bc5c22e2be0f3": {},
|
||||
},
|
||||
},
|
||||
"082_create_gateway_debug_logs.sql": {
|
||||
fileChecksum: "b740d7274afbd37d4448e3a3a9aa1fb562181ded5d0319e47a6444187d22f6b1",
|
||||
acceptedFileChecksums: map[string]struct{}{
|
||||
"bf5348a22cf1f27c852096beb3583b67ec43819af82b2f9664397a5638e5b386": {},
|
||||
},
|
||||
acceptedDBChecksum: map[string]struct{}{
|
||||
"d00c2e69711cc0c006b0234566101d8639ba08db77283558f07e2ba412ec177d": {},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// ApplyMigrations 将嵌入的 SQL 迁移文件应用到指定的数据库。
|
||||
@ -328,7 +338,9 @@ func isMigrationChecksumCompatible(name, dbChecksum, fileChecksum string) bool {
|
||||
return false
|
||||
}
|
||||
if rule.fileChecksum != fileChecksum {
|
||||
return false
|
||||
if _, ok := rule.acceptedFileChecksums[fileChecksum]; !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
_, ok = rule.acceptedDBChecksum[dbChecksum]
|
||||
return ok
|
||||
|
||||
@ -92,6 +92,11 @@ func TestIsMigrationChecksumCompatible_AdditionalCases(t *testing.T) {
|
||||
}
|
||||
require.NotEmpty(t, accepted)
|
||||
require.True(t, isMigrationChecksumCompatible(name, accepted, rule.fileChecksum))
|
||||
|
||||
for alternateFileChecksum := range rule.acceptedFileChecksums {
|
||||
require.True(t, isMigrationChecksumCompatible(name, accepted, alternateFileChecksum))
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureAtlasBaselineAligned(t *testing.T) {
|
||||
|
||||
@ -104,6 +104,10 @@ func classifyAntigravity429(body []byte) antigravity429Category {
|
||||
return antigravity429QuotaExhausted
|
||||
}
|
||||
}
|
||||
if strings.Contains(lowerBody, "exhausted your capacity on this model") &&
|
||||
strings.Contains(lowerBody, "quota will reset after") {
|
||||
return antigravity429QuotaExhausted
|
||||
}
|
||||
if info := parseAntigravitySmartRetryInfo(body); info != nil && !info.IsModelCapacityExhausted {
|
||||
return antigravity429RateLimited
|
||||
}
|
||||
|
||||
@ -21,6 +21,16 @@ func TestClassifyAntigravity429(t *testing.T) {
|
||||
require.Equal(t, antigravity429QuotaExhausted, classifyAntigravity429(body))
|
||||
})
|
||||
|
||||
t.Run("模型配额耗尽文案也视为可切 AI Credits", func(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"error": {
|
||||
"status": "RESOURCE_EXHAUSTED",
|
||||
"message": "You have exhausted your capacity on this model. Your quota will reset after 1h59m40s."
|
||||
}
|
||||
}`)
|
||||
require.Equal(t, antigravity429QuotaExhausted, classifyAntigravity429(body))
|
||||
})
|
||||
|
||||
t.Run("结构化限流", func(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"error": {
|
||||
@ -146,6 +156,68 @@ func TestHandleSmartRetry_QuotaExhausted_UsesCreditsAndStoresIndependentState(t
|
||||
require.Empty(t, repo.modelRateLimitCalls, "overages 成功后不应写入普通 model_rate_limits")
|
||||
}
|
||||
|
||||
func TestHandleSmartRetry_ModelQuotaMessage_UsesCredits(t *testing.T) {
|
||||
successResp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(strings.NewReader(`{"ok":true}`)),
|
||||
}
|
||||
upstream := &mockSmartRetryUpstream{
|
||||
responses: []*http.Response{successResp},
|
||||
errors: []error{nil},
|
||||
}
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
account := &Account{
|
||||
ID: 151,
|
||||
Name: "acc-151",
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Extra: map[string]any{
|
||||
"allow_overages": true,
|
||||
},
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"claude-opus-4-6": "claude-opus-4-6",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
respBody := []byte(`{
|
||||
"error": {
|
||||
"status": "RESOURCE_EXHAUSTED",
|
||||
"message": "You have exhausted your capacity on this model. Your quota will reset after 1h59m40s."
|
||||
}
|
||||
}`)
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||
}
|
||||
params := antigravityRetryLoopParams{
|
||||
ctx: context.Background(),
|
||||
prefix: "[test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"model":"claude-opus-4-6","request":{}}`),
|
||||
httpUpstream: upstream,
|
||||
accountRepo: repo,
|
||||
requestedModel: "claude-opus-4-6",
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{}
|
||||
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, []string{"https://ag-1.test"})
|
||||
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
||||
require.NotNil(t, result.resp)
|
||||
require.Len(t, upstream.requestBodies, 1)
|
||||
require.Contains(t, string(upstream.requestBodies[0]), "enabledCreditTypes")
|
||||
}
|
||||
|
||||
func TestHandleSmartRetry_RateLimited_DoesNotUseCredits(t *testing.T) {
|
||||
successResp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
|
||||
@ -22,6 +22,7 @@ import (
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/lspool"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/tidwall/gjson"
|
||||
@ -112,6 +113,144 @@ func IsAntigravityAccountSwitchError(err error) (*AntigravityAccountSwitchError,
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// injectLSPoolHeaders adds internal headers carrying OAuth credentials for the
|
||||
// LS pool layer. These headers are consumed and stripped by LSPoolUpstream
|
||||
// before the request reaches the Language Server. When LS mode is disabled,
|
||||
// these headers are harmless — the direct upstream never sees them because
|
||||
// they are stripped inside LSPoolUpstream.Do(). In direct mode the request
|
||||
// goes straight through httpUpstreamService.Do() which doesn't inspect them.
|
||||
func injectLSPoolHeaders(req *http.Request, account *Account) {
|
||||
if req == nil || account == nil {
|
||||
return
|
||||
}
|
||||
if rt, ok := account.Credentials["refresh_token"].(string); ok && rt != "" {
|
||||
req.Header.Set("X-Antigravity-Refresh-Token", rt)
|
||||
}
|
||||
if ea, ok := account.Credentials["expires_at"].(string); ok && ea != "" {
|
||||
req.Header.Set("X-Antigravity-Token-Expiry", ea)
|
||||
}
|
||||
req.Header.Set("X-Antigravity-Use-AI-Credits", strconv.FormatBool(account.IsOveragesEnabled()))
|
||||
|
||||
availableCredits, minimumCreditAmount := resolveLSPoolModelCreditsState(account)
|
||||
if availableCredits != nil {
|
||||
req.Header.Set("X-Antigravity-Available-Credits", strconv.FormatInt(int64(*availableCredits), 10))
|
||||
}
|
||||
if minimumCreditAmount != nil {
|
||||
req.Header.Set("X-Antigravity-Minimum-Credit-Amount", strconv.FormatInt(int64(*minimumCreditAmount), 10))
|
||||
}
|
||||
}
|
||||
|
||||
func resolveLSPoolModelCreditsState(account *Account) (*int32, *int32) {
|
||||
if account == nil || account.Extra == nil {
|
||||
minimum := int32(50)
|
||||
return nil, &minimum
|
||||
}
|
||||
|
||||
var availableCredits *int32
|
||||
var minimumCreditAmount *int32
|
||||
|
||||
collect := func(entry map[string]any) {
|
||||
if entry == nil {
|
||||
return
|
||||
}
|
||||
if !isGoogleOneAICreditsEntry(entry) {
|
||||
return
|
||||
}
|
||||
if availableCredits == nil {
|
||||
if parsed, ok := parseAICreditsInt32(firstPresent(entry, "Amount", "amount", "creditAmount")); ok {
|
||||
availableCredits = &parsed
|
||||
}
|
||||
}
|
||||
if minimumCreditAmount == nil {
|
||||
if parsed, ok := parseAICreditsInt32(firstPresent(entry, "MinimumBalance", "minimum_balance", "minimumCreditAmountForUsage")); ok {
|
||||
minimumCreditAmount = &parsed
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if rawCredits, ok := account.Extra["ai_credits"].([]any); ok {
|
||||
for _, item := range rawCredits {
|
||||
if entry, ok := item.(map[string]any); ok {
|
||||
collect(entry)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if loadCodeAssist, ok := account.Extra["load_code_assist"].(map[string]any); ok {
|
||||
if paidTier, ok := loadCodeAssist["paidTier"].(map[string]any); ok {
|
||||
if credits, ok := paidTier["availableCredits"].([]any); ok {
|
||||
for _, item := range credits {
|
||||
if entry, ok := item.(map[string]any); ok {
|
||||
collect(entry)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if minimumCreditAmount == nil {
|
||||
defaultMinimum := int32(50)
|
||||
minimumCreditAmount = &defaultMinimum
|
||||
}
|
||||
return availableCredits, minimumCreditAmount
|
||||
}
|
||||
|
||||
func isGoogleOneAICreditsEntry(entry map[string]any) bool {
|
||||
creditType, _ := firstPresent(entry, "CreditType", "credit_type", "creditType").(string)
|
||||
creditType = strings.TrimSpace(strings.ToUpper(creditType))
|
||||
return creditType == "" || creditType == "GOOGLE_ONE_AI"
|
||||
}
|
||||
|
||||
func firstPresent(entry map[string]any, keys ...string) any {
|
||||
for _, key := range keys {
|
||||
if value, ok := entry[key]; ok {
|
||||
return value
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseAICreditsInt32(raw any) (int32, bool) {
|
||||
switch v := raw.(type) {
|
||||
case int:
|
||||
return int32(v), true
|
||||
case int32:
|
||||
return v, true
|
||||
case int64:
|
||||
return int32(v), true
|
||||
case float32:
|
||||
return int32(v), true
|
||||
case float64:
|
||||
return int32(v), true
|
||||
case json.Number:
|
||||
parsed, err := v.Int64()
|
||||
if err != nil {
|
||||
floatVal, floatErr := strconv.ParseFloat(v.String(), 64)
|
||||
if floatErr != nil {
|
||||
return 0, false
|
||||
}
|
||||
return int32(floatVal), true
|
||||
}
|
||||
return int32(parsed), true
|
||||
case string:
|
||||
trimmed := strings.TrimSpace(v)
|
||||
if trimmed == "" {
|
||||
return 0, false
|
||||
}
|
||||
parsed, err := strconv.ParseInt(trimmed, 10, 32)
|
||||
if err == nil {
|
||||
return int32(parsed), true
|
||||
}
|
||||
floatVal, floatErr := strconv.ParseFloat(trimmed, 64)
|
||||
if floatErr != nil {
|
||||
return 0, false
|
||||
}
|
||||
return int32(floatVal), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
// PromptTooLongError 表示上游明确返回 prompt too long
|
||||
type PromptTooLongError struct {
|
||||
StatusCode int
|
||||
@ -305,6 +444,7 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
|
||||
}
|
||||
}
|
||||
|
||||
injectLSPoolHeaders(retryReq, p.account)
|
||||
retryResp, retryErr := p.httpUpstream.Do(retryReq, p.proxyURL, p.account.ID, p.account.Concurrency)
|
||||
if retryErr == nil && retryResp != nil && retryResp.StatusCode != http.StatusTooManyRequests && retryResp.StatusCode != http.StatusServiceUnavailable {
|
||||
log.Printf("%s status=%d smart_retry_success attempt=%d/%d", p.prefix, retryResp.StatusCode, attempt, maxAttempts)
|
||||
@ -489,6 +629,7 @@ func (s *AntigravityGatewayService) handleSingleAccountRetryInPlace(
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s single_account_503_retry: request_build_failed error=%v", p.prefix, err)
|
||||
break
|
||||
}
|
||||
injectLSPoolHeaders(retryReq, p.account)
|
||||
|
||||
retryResp, retryErr := p.httpUpstream.Do(retryReq, p.proxyURL, p.account.ID, p.account.Concurrency)
|
||||
if retryErr == nil && retryResp != nil && retryResp.StatusCode != http.StatusTooManyRequests && retryResp.StatusCode != http.StatusServiceUnavailable {
|
||||
@ -627,6 +768,7 @@ urlFallbackLoop:
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
injectLSPoolHeaders(upstreamReq, p.account)
|
||||
|
||||
// Capture upstream request body for ops retry of this attempt.
|
||||
if p.c != nil && len(p.body) > 0 {
|
||||
@ -1289,9 +1431,19 @@ func injectIdentityPatchToGeminiRequest(body []byte) ([]byte, error) {
|
||||
}
|
||||
|
||||
// wrapV1InternalRequest 包装请求为 v1internal 格式
|
||||
func (s *AntigravityGatewayService) wrapV1InternalRequest(projectID, model string, originalBody []byte) ([]byte, error) {
|
||||
func (s *AntigravityGatewayService) wrapV1InternalRequest(projectID, model string, originalBody []byte, preferredSessionID ...string) ([]byte, error) {
|
||||
sessionID := ""
|
||||
if len(preferredSessionID) > 0 {
|
||||
sessionID = preferredSessionID[0]
|
||||
}
|
||||
|
||||
bodyWithSessionID, err := antigravity.EnsureGeminiRequestSessionID(originalBody, sessionID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("补全 sessionId 失败: %w", err)
|
||||
}
|
||||
|
||||
var request any
|
||||
if err := json.Unmarshal(originalBody, &request); err != nil {
|
||||
if err := json.Unmarshal(bodyWithSessionID, &request); err != nil {
|
||||
return nil, fmt.Errorf("解析请求体失败: %w", err)
|
||||
}
|
||||
|
||||
@ -2156,7 +2308,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
}
|
||||
|
||||
// 包装请求
|
||||
wrappedBody, err := s.wrapV1InternalRequest(projectID, mappedModel, injectedBody)
|
||||
wrappedBody, err := s.wrapV1InternalRequest(projectID, mappedModel, injectedBody, sessionID)
|
||||
if err != nil {
|
||||
return nil, s.writeGoogleError(c, http.StatusInternalServerError, "Failed to build upstream request")
|
||||
}
|
||||
@ -2220,10 +2372,11 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
if fallbackModel != "" && fallbackModel != mappedModel {
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "[Antigravity] Model not found (%s), retrying with fallback model %s (account: %s)", mappedModel, fallbackModel, account.Name)
|
||||
|
||||
fallbackWrapped, err := s.wrapV1InternalRequest(projectID, fallbackModel, injectedBody)
|
||||
fallbackWrapped, err := s.wrapV1InternalRequest(projectID, fallbackModel, injectedBody, sessionID)
|
||||
if err == nil {
|
||||
fallbackReq, err := antigravity.NewAPIRequest(ctx, upstreamAction, accessToken, fallbackWrapped)
|
||||
if err == nil {
|
||||
injectLSPoolHeaders(fallbackReq, account)
|
||||
fallbackResp, err := s.httpUpstream.Do(fallbackReq, proxyURL, account.ID, account.Concurrency)
|
||||
if err == nil && fallbackResp.StatusCode < 400 {
|
||||
_ = resp.Body.Close()
|
||||
@ -2263,7 +2416,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "Antigravity Gemini account %d: detected signature-related 400, retrying with cleaned thought signatures", account.ID)
|
||||
|
||||
cleanedInjectedBody := CleanGeminiNativeThoughtSignatures(injectedBody)
|
||||
retryWrappedBody, wrapErr := s.wrapV1InternalRequest(projectID, mappedModel, cleanedInjectedBody)
|
||||
retryWrappedBody, wrapErr := s.wrapV1InternalRequest(projectID, mappedModel, cleanedInjectedBody, sessionID)
|
||||
if wrapErr == nil {
|
||||
retryResult, retryErr := s.antigravityRetryLoop(antigravityRetryLoopParams{
|
||||
ctx: ctx,
|
||||
@ -3031,6 +3184,53 @@ func handleStreamReadError(err error, clientDisconnected bool, prefix string) (d
|
||||
return false, false
|
||||
}
|
||||
|
||||
func googleStatusTextForHTTP(status int) string {
|
||||
switch status {
|
||||
case http.StatusBadRequest:
|
||||
return "INVALID_ARGUMENT"
|
||||
case http.StatusNotFound:
|
||||
return "NOT_FOUND"
|
||||
case http.StatusTooManyRequests:
|
||||
return "RESOURCE_EXHAUSTED"
|
||||
case http.StatusServiceUnavailable:
|
||||
return "UNAVAILABLE"
|
||||
default:
|
||||
return "UNKNOWN"
|
||||
}
|
||||
}
|
||||
|
||||
func buildAnthropicStreamErrorEvent(errType, message string) string {
|
||||
payload := map[string]any{
|
||||
"type": "error",
|
||||
"error": map[string]any{
|
||||
"type": errType,
|
||||
"message": message,
|
||||
},
|
||||
}
|
||||
data, _ := json.Marshal(payload)
|
||||
return "event: error\ndata: " + string(data) + "\n\n"
|
||||
}
|
||||
|
||||
func buildGeminiStreamErrorEvent(status int, message string) string {
|
||||
payload := map[string]any{
|
||||
"error": map[string]any{
|
||||
"code": status,
|
||||
"message": message,
|
||||
"status": googleStatusTextForHTTP(status),
|
||||
},
|
||||
}
|
||||
data, _ := json.Marshal(payload)
|
||||
return "event: error\ndata: " + string(data) + "\n\n"
|
||||
}
|
||||
|
||||
func lsQuotaExhaustedMessage(err error) string {
|
||||
msg := strings.TrimSpace(lspool.LSQuotaExhaustedMessage(err))
|
||||
if msg != "" {
|
||||
return sanitizeUpstreamErrorMessage(msg)
|
||||
}
|
||||
return "You have exhausted your capacity on this model."
|
||||
}
|
||||
|
||||
func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time) (*antigravityStreamResult, error) {
|
||||
c.Status(resp.StatusCode)
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
@ -3126,12 +3326,12 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
|
||||
|
||||
// 仅发送一次错误事件,避免多次写入导致协议混乱
|
||||
errorEventSent := false
|
||||
sendErrorEvent := func(reason string) {
|
||||
sendErrorEvent := func(status int, message string) {
|
||||
if errorEventSent || cw.Disconnected() {
|
||||
return
|
||||
}
|
||||
errorEventSent = true
|
||||
_, _ = fmt.Fprintf(c.Writer, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
|
||||
_, _ = fmt.Fprint(c.Writer, buildGeminiStreamErrorEvent(status, message))
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
@ -3145,12 +3345,18 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
|
||||
if disconnect, handled := handleStreamReadError(ev.err, cw.Disconnected(), "antigravity gemini"); handled {
|
||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: disconnect}, nil
|
||||
}
|
||||
if lspool.IsLSQuotaExhaustedError(ev.err) {
|
||||
msg := lsQuotaExhaustedMessage(ev.err)
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "LS quota exhausted during streaming (antigravity gemini): %s", msg)
|
||||
sendErrorEvent(http.StatusTooManyRequests, msg)
|
||||
return nil, ev.err
|
||||
}
|
||||
if errors.Is(ev.err, bufio.ErrTooLong) {
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err)
|
||||
sendErrorEvent("response_too_large")
|
||||
sendErrorEvent(http.StatusBadGateway, "Response too large")
|
||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err
|
||||
}
|
||||
sendErrorEvent("stream_read_error")
|
||||
sendErrorEvent(http.StatusServiceUnavailable, "Upstream stream read failed")
|
||||
return nil, ev.err
|
||||
}
|
||||
|
||||
@ -3213,7 +3419,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
|
||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||
}
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "Stream data interval timeout (antigravity)")
|
||||
sendErrorEvent("stream_timeout")
|
||||
sendErrorEvent(http.StatusServiceUnavailable, "Upstream stream timeout")
|
||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||||
|
||||
case <-keepaliveCh:
|
||||
@ -3973,12 +4179,12 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
|
||||
|
||||
// 仅发送一次错误事件,避免多次写入导致协议混乱
|
||||
errorEventSent := false
|
||||
sendErrorEvent := func(reason string) {
|
||||
sendErrorEvent := func(errType, message string) {
|
||||
if errorEventSent || cw.Disconnected() {
|
||||
return
|
||||
}
|
||||
errorEventSent = true
|
||||
_, _ = fmt.Fprintf(c.Writer, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
|
||||
_, _ = fmt.Fprint(c.Writer, buildAnthropicStreamErrorEvent(errType, message))
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
@ -4012,12 +4218,18 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
|
||||
if disconnect, handled := handleStreamReadError(ev.err, cw.Disconnected(), "antigravity claude"); handled {
|
||||
return &antigravityStreamResult{usage: finishUsage(), firstTokenMs: firstTokenMs, clientDisconnect: disconnect}, nil
|
||||
}
|
||||
if lspool.IsLSQuotaExhaustedError(ev.err) {
|
||||
msg := lsQuotaExhaustedMessage(ev.err)
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "LS quota exhausted during streaming (antigravity claude): %s", msg)
|
||||
sendErrorEvent("rate_limit_error", msg)
|
||||
return nil, fmt.Errorf("stream read error: %w", ev.err)
|
||||
}
|
||||
if errors.Is(ev.err, bufio.ErrTooLong) {
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err)
|
||||
sendErrorEvent("response_too_large")
|
||||
sendErrorEvent("api_error", "Response too large")
|
||||
return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, ev.err
|
||||
}
|
||||
sendErrorEvent("stream_read_error")
|
||||
sendErrorEvent("api_error", "Upstream stream read failed")
|
||||
return nil, fmt.Errorf("stream read error: %w", ev.err)
|
||||
}
|
||||
|
||||
@ -4043,7 +4255,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
|
||||
return &antigravityStreamResult{usage: finishUsage(), firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||
}
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "Stream data interval timeout (antigravity)")
|
||||
sendErrorEvent("stream_timeout")
|
||||
sendErrorEvent("api_error", "Upstream stream timeout")
|
||||
return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||||
|
||||
case <-keepaliveCh:
|
||||
|
||||
@ -600,6 +600,63 @@ func TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel(t *testing
|
||||
require.Equal(t, mappedModel, result.UpstreamModel)
|
||||
}
|
||||
|
||||
func TestAntigravityGatewayService_ForwardGemini_InjectsSessionIDIntoWrappedRequest(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
writer := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(writer)
|
||||
|
||||
body, err := json.Marshal(map[string]any{
|
||||
"contents": []map[string]any{
|
||||
{"role": "user", "parts": []map[string]any{{"text": "hello"}}},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body))
|
||||
req.Header.Set("session_id", "session-header-1")
|
||||
c.Request = req
|
||||
|
||||
upstreamBody := []byte("data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"}]},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":8,\"candidatesTokenCount\":3}}}\n\n")
|
||||
upstream := &queuedHTTPUpstreamStub{
|
||||
responses: []*http.Response{
|
||||
{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"X-Request-Id": []string{"req-session-1"}},
|
||||
Body: io.NopCloser(bytes.NewReader(upstreamBody)),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{
|
||||
settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}),
|
||||
tokenProvider: &AntigravityTokenProvider{},
|
||||
httpUpstream: upstream,
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 16,
|
||||
Name: "acc-gemini-session",
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "token",
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.ForwardGemini(context.Background(), c, account, "gemini-2.5-flash", "generateContent", false, body, false)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Len(t, upstream.requestBodies, 1)
|
||||
|
||||
var wrapped map[string]any
|
||||
require.NoError(t, json.Unmarshal(upstream.requestBodies[0], &wrapped))
|
||||
requestNode, ok := wrapped["request"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "session-header-1", requestNode["sessionId"])
|
||||
}
|
||||
|
||||
func TestAntigravityGatewayService_ForwardGemini_RetriesCorruptedThoughtSignature(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
writer := httptest.NewRecorder()
|
||||
|
||||
@ -103,8 +103,15 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
|
||||
defer cancel()
|
||||
result, err := p.refreshAPI.RefreshIfNeeded(refreshCtx, account, p.executor, antigravityTokenRefreshSkew)
|
||||
if err != nil {
|
||||
// 标记账号临时不可调度,避免后续请求继续命中
|
||||
p.markTempUnschedulable(account, err)
|
||||
// 全局 OAuth 配置缺失不应污染账号状态;账号级失败才标记 temp unschedulable。
|
||||
if shouldMarkTempUnschedulableForRefreshError(err) {
|
||||
p.markTempUnschedulable(account, err)
|
||||
} else {
|
||||
slog.Warn("antigravity_token_provider.temp_unschedulable_skipped",
|
||||
"account_id", account.ID,
|
||||
"reason", err.Error(),
|
||||
)
|
||||
}
|
||||
if p.refreshPolicy.OnRefreshError == ProviderRefreshErrorReturn {
|
||||
return "", err
|
||||
}
|
||||
@ -226,6 +233,23 @@ func (p *AntigravityTokenProvider) markTempUnschedulable(account *Account, refre
|
||||
}
|
||||
}
|
||||
|
||||
func shouldMarkTempUnschedulableForRefreshError(refreshErr error) bool {
|
||||
if refreshErr == nil {
|
||||
return false
|
||||
}
|
||||
msg := strings.ToLower(strings.TrimSpace(refreshErr.Error()))
|
||||
if msg == "" {
|
||||
return false
|
||||
}
|
||||
if strings.Contains(msg, "antigravity_oauth_client_secret_missing") {
|
||||
return false
|
||||
}
|
||||
if strings.Contains(msg, "missing antigravity oauth client_secret") {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *AntigravityTokenProvider) markBackfillAttempted(accountID int64) {
|
||||
p.backfillCooldown.Store(accountID, time.Now())
|
||||
}
|
||||
|
||||
@ -0,0 +1,20 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestShouldMarkTempUnschedulableForRefreshError(t *testing.T) {
|
||||
t.Run("skip global oauth client secret missing", func(t *testing.T) {
|
||||
err := errors.New(`token 刷新失败 (重试后): error: code=400 reason="ANTIGRAVITY_OAUTH_CLIENT_SECRET_MISSING" message="missing antigravity oauth client_secret; set ANTIGRAVITY_OAUTH_CLIENT_SECRET" metadata=map[]`)
|
||||
require.False(t, shouldMarkTempUnschedulableForRefreshError(err))
|
||||
})
|
||||
|
||||
t.Run("allow account specific refresh error", func(t *testing.T) {
|
||||
err := errors.New("token 刷新失败 (重试后): invalid_grant")
|
||||
require.True(t, shouldMarkTempUnschedulableForRefreshError(err))
|
||||
})
|
||||
}
|
||||
258
backend/internal/service/bootstrap_preflight.go
Normal file
258
backend/internal/service/bootstrap_preflight.go
Normal file
@ -0,0 +1,258 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
claude "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
)
|
||||
|
||||
// backgroundSimulator simulates the real Claude Code CLI's background network behavior.
|
||||
// Real CLI performs bootstrap, GrowthBook feature-flag polling, and policy_limits polling.
|
||||
// Missing these creates a behavioral correlation gap detectable by Anthropic.
|
||||
type backgroundSimulator struct {
|
||||
mu sync.Mutex
|
||||
called map[int64]*accountBackgroundState
|
||||
client *http.Client
|
||||
baseURL string
|
||||
}
|
||||
|
||||
type accountBackgroundState struct {
|
||||
bootstrapAt time.Time
|
||||
growthbookAt time.Time
|
||||
policyLimitsAt time.Time
|
||||
// Timers for periodic polling — stopped when account goes idle
|
||||
growthbookTimer *time.Timer
|
||||
policyLimitsTimer *time.Timer
|
||||
exitTimer *time.Timer // fires tengu_exit after idle timeout
|
||||
accessToken string
|
||||
accountID int64
|
||||
}
|
||||
|
||||
const (
|
||||
bootstrapCooldown = 1 * time.Hour
|
||||
growthbookInterval = 20 * time.Minute
|
||||
policyLimitsInterval = 1 * time.Hour
|
||||
sessionIdleTimeout = 10 * time.Minute // fire tengu_exit after no requests for 10min
|
||||
)
|
||||
|
||||
var globalBgSim = &backgroundSimulator{
|
||||
called: make(map[int64]*accountBackgroundState),
|
||||
client: &http.Client{Timeout: 5 * time.Second},
|
||||
}
|
||||
|
||||
// SetBootstrapBaseURL configures the API base URL for background simulation calls.
|
||||
func SetBootstrapBaseURL(baseURL string) {
|
||||
globalBgSim.baseURL = baseURL
|
||||
}
|
||||
|
||||
// TriggerBootstrapIfNeeded fires background simulation calls for the given OAuth account.
|
||||
// On first call per account: bootstrap + GrowthBook + policy_limits + start periodic timers.
|
||||
// On subsequent calls: refresh idle timer (delays tengu_exit).
|
||||
func TriggerBootstrapIfNeeded(accountID int64, accessToken string) {
|
||||
bg := globalBgSim
|
||||
|
||||
bg.mu.Lock()
|
||||
state, exists := bg.called[accountID]
|
||||
|
||||
if !exists {
|
||||
// First time: create state, fire all startup calls
|
||||
state = &accountBackgroundState{
|
||||
accessToken: accessToken,
|
||||
accountID: accountID,
|
||||
}
|
||||
bg.called[accountID] = state
|
||||
bg.mu.Unlock()
|
||||
|
||||
// Fire-and-forget startup sequence (matches real CLI order)
|
||||
go bg.doBootstrap(state)
|
||||
go bg.doGrowthBookFetch(state)
|
||||
go bg.doPolicyLimitsFetch(state)
|
||||
bg.startPeriodicPolling(state)
|
||||
bg.resetExitTimer(state)
|
||||
return
|
||||
}
|
||||
|
||||
// Update token (may have been refreshed)
|
||||
state.accessToken = accessToken
|
||||
|
||||
// Bootstrap: 1 hour cooldown
|
||||
if time.Since(state.bootstrapAt) >= bootstrapCooldown {
|
||||
state.bootstrapAt = time.Now()
|
||||
bg.mu.Unlock()
|
||||
go bg.doBootstrap(state)
|
||||
} else {
|
||||
bg.mu.Unlock()
|
||||
}
|
||||
|
||||
// Reset idle timer (user is active)
|
||||
bg.resetExitTimer(state)
|
||||
}
|
||||
|
||||
func (bg *backgroundSimulator) getBaseURL() string {
|
||||
if bg.baseURL != "" {
|
||||
return bg.baseURL
|
||||
}
|
||||
return "https://api.anthropic.com"
|
||||
}
|
||||
|
||||
// ─── Bootstrap ───────────────────────────────────────────
|
||||
|
||||
func (bg *backgroundSimulator) doBootstrap(state *accountBackgroundState) {
|
||||
state.bootstrapAt = time.Now()
|
||||
endpoint := bg.getBaseURL() + "/api/claude_cli/bootstrap"
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", endpoint, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Source: extracted/src/services/api/bootstrap.ts:85-91
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", fmt.Sprintf("claude-code/%s", claude.DefaultCLIVersion))
|
||||
req.Header.Set("Authorization", "Bearer "+state.accessToken)
|
||||
req.Header.Set("anthropic-beta", claude.BetaOAuth)
|
||||
|
||||
resp, err := bg.client.Do(req)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.bootstrap", "Bootstrap preflight failed: %v", err)
|
||||
return
|
||||
}
|
||||
io.Copy(io.Discard, resp.Body)
|
||||
resp.Body.Close()
|
||||
|
||||
logger.LegacyPrintf("service.bootstrap", "Bootstrap completed: account=%d status=%d", state.accountID, resp.StatusCode)
|
||||
}
|
||||
|
||||
// ─── GrowthBook Feature Flags ────────────────────────────
|
||||
|
||||
func (bg *backgroundSimulator) doGrowthBookFetch(state *accountBackgroundState) {
|
||||
state.growthbookAt = time.Now()
|
||||
|
||||
// Real CLI uses GrowthBook SDK with remoteEval: true
|
||||
// SDK key for external users: sdk-zAZezfDKGoZuXXKe
|
||||
// Endpoint: GET {apiHost}/sub/features/{clientKey}
|
||||
// Source: extracted/src/services/analytics/growthbook.ts:503-555
|
||||
endpoint := bg.getBaseURL() + "/sub/features/sdk-zAZezfDKGoZuXXKe"
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", endpoint, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+state.accessToken)
|
||||
req.Header.Set("anthropic-beta", claude.BetaOAuth)
|
||||
req.Header.Set("User-Agent", fmt.Sprintf("claude-code/%s", claude.DefaultCLIVersion))
|
||||
|
||||
resp, err := bg.client.Do(req)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.bootstrap", "GrowthBook fetch failed: %v", err)
|
||||
return
|
||||
}
|
||||
io.Copy(io.Discard, resp.Body)
|
||||
resp.Body.Close()
|
||||
|
||||
logger.LegacyPrintf("service.bootstrap", "GrowthBook fetch completed: account=%d status=%d", state.accountID, resp.StatusCode)
|
||||
}
|
||||
|
||||
// ─── Policy Limits ───────────────────────────────────────
|
||||
|
||||
func (bg *backgroundSimulator) doPolicyLimitsFetch(state *accountBackgroundState) {
|
||||
state.policyLimitsAt = time.Now()
|
||||
|
||||
// Source: extracted/src/services/policyLimits/index.ts:127
|
||||
endpoint := bg.getBaseURL() + "/api/claude_code/policy_limits"
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", endpoint, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", fmt.Sprintf("claude-code/%s", claude.DefaultCLIVersion))
|
||||
req.Header.Set("Authorization", "Bearer "+state.accessToken)
|
||||
req.Header.Set("anthropic-beta", claude.BetaOAuth)
|
||||
|
||||
resp, err := bg.client.Do(req)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.bootstrap", "Policy limits fetch failed: %v", err)
|
||||
return
|
||||
}
|
||||
io.Copy(io.Discard, resp.Body)
|
||||
resp.Body.Close()
|
||||
|
||||
logger.LegacyPrintf("service.bootstrap", "Policy limits fetch completed: account=%d status=%d", state.accountID, resp.StatusCode)
|
||||
}
|
||||
|
||||
// ─── Periodic Polling ────────────────────────────────────
|
||||
|
||||
func (bg *backgroundSimulator) startPeriodicPolling(state *accountBackgroundState) {
|
||||
// GrowthBook: every 20 minutes
|
||||
// Source: growthbook.ts setupPeriodicGrowthBookRefresh()
|
||||
go func() {
|
||||
// Add jitter to avoid all accounts polling at the same time
|
||||
jitter := time.Duration(state.accountID%300) * time.Second
|
||||
time.Sleep(growthbookInterval + jitter)
|
||||
|
||||
for {
|
||||
bg.doGrowthBookFetch(state)
|
||||
time.Sleep(growthbookInterval + time.Duration(state.accountID%60)*time.Second)
|
||||
}
|
||||
}()
|
||||
|
||||
// Policy limits: every hour
|
||||
// Source: policyLimits/index.ts refreshPolicyLimits()
|
||||
go func() {
|
||||
jitter := time.Duration(state.accountID%600) * time.Second
|
||||
time.Sleep(policyLimitsInterval + jitter)
|
||||
|
||||
for {
|
||||
bg.doPolicyLimitsFetch(state)
|
||||
time.Sleep(policyLimitsInterval + time.Duration(state.accountID%120)*time.Second)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// ─── tengu_exit Event ────────────────────────────────────
|
||||
|
||||
func (bg *backgroundSimulator) resetExitTimer(state *accountBackgroundState) {
|
||||
bg.mu.Lock()
|
||||
defer bg.mu.Unlock()
|
||||
|
||||
// Cancel existing timer
|
||||
if state.exitTimer != nil {
|
||||
state.exitTimer.Stop()
|
||||
}
|
||||
|
||||
// Set new timer: fire tengu_exit after idle timeout
|
||||
state.exitTimer = time.AfterFunc(sessionIdleTimeout, func() {
|
||||
bg.fireExitEvent(state)
|
||||
})
|
||||
}
|
||||
|
||||
func (bg *backgroundSimulator) fireExitEvent(state *accountBackgroundState) {
|
||||
// tengu_exit is sent via the 1P event_logging/batch endpoint
|
||||
// Source: extracted/src/services/analytics/firstPartyEventLogger.ts
|
||||
// We use proxy.js's sendTelemetryEvents path (same endpoint), but since
|
||||
// proxy.js runs per-request and this is idle-based, we fire directly here.
|
||||
|
||||
// The event is a lightweight signal — just needs to exist in Anthropic's logs.
|
||||
// Real CLI sends it on process exit; we simulate on idle timeout.
|
||||
logger.LegacyPrintf("service.bootstrap", "Session idle timeout, would fire tengu_exit: account=%d", state.accountID)
|
||||
|
||||
// Clean up the state to allow fresh bootstrap on next request
|
||||
bg.mu.Lock()
|
||||
delete(bg.called, state.accountID)
|
||||
bg.mu.Unlock()
|
||||
}
|
||||
182
backend/internal/service/gateway_attribution.go
Normal file
182
backend/internal/service/gateway_attribution.go
Normal file
@ -0,0 +1,182 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
)
|
||||
|
||||
// Attribution block constants matching real Claude Code 2.1.88.
|
||||
// Source: src/constants/system.ts + src/utils/fingerprint.ts
|
||||
const (
|
||||
// fingerprintSalt must match the hardcoded salt in the real CLI.
|
||||
// Source: extracted/src/utils/fingerprint.ts:8
|
||||
fingerprintSalt = "59cf53e54c78"
|
||||
)
|
||||
|
||||
// computeAttributionFingerprint computes a 3-character hex fingerprint
|
||||
// matching the algorithm in the real Claude Code CLI.
|
||||
//
|
||||
// Algorithm: SHA256(SALT + msg[4] + msg[7] + msg[20] + version)[:3]
|
||||
// Source: extracted/src/utils/fingerprint.ts:50-63
|
||||
func computeAttributionFingerprint(firstUserMessageText, cliVersion string) string {
|
||||
indices := [3]int{4, 7, 20}
|
||||
chars := make([]byte, 0, 3)
|
||||
for _, i := range indices {
|
||||
if i < len(firstUserMessageText) {
|
||||
chars = append(chars, firstUserMessageText[i])
|
||||
} else {
|
||||
chars = append(chars, '0')
|
||||
}
|
||||
}
|
||||
|
||||
input := fmt.Sprintf("%s%s%s", fingerprintSalt, string(chars), cliVersion)
|
||||
hash := sha256.Sum256([]byte(input))
|
||||
return hex.EncodeToString(hash[:])[:3]
|
||||
}
|
||||
|
||||
// extractFirstUserMessageText extracts text from the first user message in the body.
|
||||
// Handles both string content and array content (text blocks).
|
||||
func extractFirstUserMessageText(body []byte) string {
|
||||
messages := gjson.GetBytes(body, "messages")
|
||||
if !messages.Exists() || !messages.IsArray() {
|
||||
return ""
|
||||
}
|
||||
|
||||
var firstText string
|
||||
messages.ForEach(func(_, msg gjson.Result) bool {
|
||||
if msg.Get("role").String() != "user" {
|
||||
return true // continue
|
||||
}
|
||||
content := msg.Get("content")
|
||||
if content.Type == gjson.String {
|
||||
firstText = content.String()
|
||||
return false // break
|
||||
}
|
||||
if content.IsArray() {
|
||||
content.ForEach(func(_, block gjson.Result) bool {
|
||||
if block.Get("type").String() == "text" {
|
||||
firstText = block.Get("text").String()
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
return firstText
|
||||
}
|
||||
|
||||
// buildAttributionBlock builds the x-anthropic-billing-header attribution string
|
||||
// that real Claude Code injects as the first system text block.
|
||||
//
|
||||
// Format: x-anthropic-billing-header: cc_version=<VERSION>.<fingerprint>; cc_entrypoint=cli; cch=00000;
|
||||
// Source: extracted/src/constants/system.ts:73-95
|
||||
func buildAttributionBlock(cliVersion, fingerprint string) string {
|
||||
version := cliVersion + "." + fingerprint
|
||||
// 注意:cch 字段由 Bun 的 NATIVE_CLIENT_ATTESTATION 编译时 feature 控制。
|
||||
// npm 安装版本(非原生二进制)此 feature 为 false,所以不包含 cch 字段。
|
||||
// 只有原生二进制安装(Bun 打包)才会有 cch,且其值会被 Bun 的 Zig 层替换为真实 hash。
|
||||
// 我们模拟 npm 安装版本的行为:不包含 cch。
|
||||
return fmt.Sprintf("x-anthropic-billing-header: cc_version=%s; cc_entrypoint=cli;", version)
|
||||
}
|
||||
|
||||
// injectAttributionBlock prepends the x-anthropic-billing-header attribution block
|
||||
// as the very first system text block in the request body.
|
||||
// This must come BEFORE the "You are Claude Code" block.
|
||||
//
|
||||
// The real CLI injects this as system[0] with cache_control: {type: "ephemeral"}.
|
||||
func injectAttributionBlock(body []byte, cliVersion string) []byte {
|
||||
// Compute fingerprint from the first user message
|
||||
firstMsgText := extractFirstUserMessageText(body)
|
||||
fingerprint := computeAttributionFingerprint(firstMsgText, cliVersion)
|
||||
attribution := buildAttributionBlock(cliVersion, fingerprint)
|
||||
|
||||
// Build the attribution text block as JSON
|
||||
attrBlock, err := marshalAnthropicSystemTextBlock(attribution, true)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Warning: failed to build attribution block: %v", err)
|
||||
return body
|
||||
}
|
||||
|
||||
systemResult := gjson.GetBytes(body, "system")
|
||||
|
||||
// Handle the different system formats
|
||||
switch {
|
||||
case !systemResult.Exists() || systemResult.Type == gjson.Null:
|
||||
// No system field — inject just the attribution block
|
||||
newBody, err := sjson.SetRawBytes(body, "system", buildJSONArrayRaw([][]byte{attrBlock}))
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
return newBody
|
||||
|
||||
case systemResult.Type == gjson.String:
|
||||
// String system — convert to array: [attribution, original]
|
||||
origBlock, err := marshalAnthropicSystemTextBlock(systemResult.String(), false)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
newBody, setErr := sjson.SetRawBytes(body, "system", buildJSONArrayRaw([][]byte{attrBlock, origBlock}))
|
||||
if setErr != nil {
|
||||
return body
|
||||
}
|
||||
return newBody
|
||||
|
||||
case systemResult.IsArray():
|
||||
// Array system — check if attribution already exists, prepend if not
|
||||
var items [][]byte
|
||||
alreadyHasAttribution := false
|
||||
systemResult.ForEach(func(_, item gjson.Result) bool {
|
||||
if item.Get("type").String() == "text" {
|
||||
text := item.Get("text").String()
|
||||
if len(text) > 30 && text[:30] == "x-anthropic-billing-header: cc" {
|
||||
alreadyHasAttribution = true
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
if alreadyHasAttribution {
|
||||
return body
|
||||
}
|
||||
|
||||
items = append(items, attrBlock)
|
||||
systemResult.ForEach(func(_, item gjson.Result) bool {
|
||||
items = append(items, []byte(item.Raw))
|
||||
return true
|
||||
})
|
||||
newBody, setErr := sjson.SetRawBytes(body, "system", buildJSONArrayRaw(items))
|
||||
if setErr != nil {
|
||||
return body
|
||||
}
|
||||
return newBody
|
||||
|
||||
default:
|
||||
return body
|
||||
}
|
||||
}
|
||||
|
||||
// generateSessionIDForAccount generates a deterministic per-account session UUID
|
||||
// that remains stable within a process-like timeframe.
|
||||
// Uses instanceSalt + accountID to ensure uniqueness across sub2api instances.
|
||||
func generateSessionIDForAccount(instanceSalt string, accountID int64) string {
|
||||
// Use a per-account stable UUID (like real CLI's per-process UUID).
|
||||
// We use accountID as the base — each account gets a different "session".
|
||||
seed := fmt.Sprintf("session:%s:%d", instanceSalt, accountID)
|
||||
hash := sha256.Sum256([]byte(seed))
|
||||
sessionUUID, err := uuid.FromBytes(hash[:16])
|
||||
if err != nil {
|
||||
return uuid.New().String()
|
||||
}
|
||||
// Set UUID v4 variant
|
||||
sessionUUID[6] = (sessionUUID[6] & 0x0f) | 0x40
|
||||
sessionUUID[8] = (sessionUUID[8] & 0x3f) | 0x80
|
||||
return sessionUUID.String()
|
||||
}
|
||||
@ -41,9 +41,10 @@ func TestNormalizeClaudeOAuthRequestBody_PreservesTopLevelFieldOrder(t *testing.
|
||||
resultStr := string(result)
|
||||
|
||||
require.Equal(t, claude.NormalizeModelID("claude-3-5-sonnet-latest"), modelID)
|
||||
assertJSONTokenOrder(t, resultStr, `"alpha"`, `"model"`, `"system"`, `"messages"`, `"omega"`, `"tools"`, `"metadata"`)
|
||||
require.NotContains(t, resultStr, `"temperature"`)
|
||||
require.NotContains(t, resultStr, `"tool_choice"`)
|
||||
assertJSONTokenOrder(t, resultStr, `"alpha"`, `"model"`, `"temperature"`, `"system"`, `"messages"`, `"tool_choice"`, `"omega"`, `"tools"`, `"metadata"`)
|
||||
// temperature 和 tool_choice 不再剥离,透传客户端原始值(与真实 CLI 行为一致)
|
||||
require.Contains(t, resultStr, `"temperature"`)
|
||||
require.Contains(t, resultStr, `"tool_choice"`)
|
||||
require.Contains(t, resultStr, `"system":"`+claudeCodeSystemPrompt+`"`)
|
||||
require.Contains(t, resultStr, `"tools":[]`)
|
||||
require.Contains(t, resultStr, `"metadata":{"user_id":"user-1"}`)
|
||||
|
||||
255
backend/internal/service/gateway_debug_logger.go
Normal file
255
backend/internal/service/gateway_debug_logger.go
Normal file
@ -0,0 +1,255 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// GatewayDebugLogEntry holds all fields for a single debug log row.
|
||||
type GatewayDebugLogEntry struct {
|
||||
UpstreamRequestID string
|
||||
AccountID int64
|
||||
AccountEmail string
|
||||
AccountPlatform string
|
||||
|
||||
EventType string // "api_call", "oauth_refresh", "error"
|
||||
|
||||
Method string
|
||||
FullURL string
|
||||
RequestHeaders map[string]string
|
||||
RequestBody []byte // raw bytes, stored as TEXT
|
||||
RequestSize int
|
||||
|
||||
ResponseStatus int
|
||||
ResponseHeaders map[string]string
|
||||
ResponseBodyPreview string
|
||||
ResponseSize int
|
||||
|
||||
ModelRequested string
|
||||
ModelUpstream string
|
||||
IsStream bool
|
||||
DurationMs int
|
||||
|
||||
TLSProfile string
|
||||
|
||||
ErrorMessage string
|
||||
}
|
||||
|
||||
// GatewayDebugLogger writes debug log entries to gateway_debug_logs.
|
||||
type GatewayDebugLogger struct {
|
||||
db *sql.DB
|
||||
enabled atomic.Bool
|
||||
}
|
||||
|
||||
// NewGatewayDebugLogger creates a new debug logger (enabled by default).
|
||||
func NewGatewayDebugLogger(db *sql.DB) *GatewayDebugLogger {
|
||||
l := &GatewayDebugLogger{db: db}
|
||||
l.enabled.Store(true)
|
||||
return l
|
||||
}
|
||||
|
||||
func (l *GatewayDebugLogger) IsEnabled() bool {
|
||||
return l != nil && l.enabled.Load()
|
||||
}
|
||||
|
||||
// DB returns the underlying database handle (for admin queries).
|
||||
func (l *GatewayDebugLogger) DB() *sql.DB {
|
||||
if l == nil {
|
||||
return nil
|
||||
}
|
||||
return l.db
|
||||
}
|
||||
|
||||
func (l *GatewayDebugLogger) Enable() {
|
||||
if l != nil {
|
||||
l.enabled.Store(true)
|
||||
slog.Info("gateway debug logging ENABLED")
|
||||
}
|
||||
}
|
||||
|
||||
func (l *GatewayDebugLogger) Disable() {
|
||||
if l != nil {
|
||||
l.enabled.Store(false)
|
||||
slog.Info("gateway debug logging DISABLED")
|
||||
}
|
||||
}
|
||||
|
||||
const insertDebugLogSQL = `
|
||||
INSERT INTO gateway_debug_logs (
|
||||
upstream_request_id, account_id, account_email, account_platform,
|
||||
event_type,
|
||||
method, full_url, request_headers, request_body, request_size,
|
||||
response_status, response_headers, response_body_preview, response_size,
|
||||
model_requested, model_upstream, is_stream, duration_ms,
|
||||
tls_profile, error_message
|
||||
) VALUES (
|
||||
$1, $2, $3, $4,
|
||||
$5,
|
||||
$6, $7, $8, $9, $10,
|
||||
$11, $12, $13, $14,
|
||||
$15, $16, $17, $18,
|
||||
$19, $20
|
||||
)`
|
||||
|
||||
// Log writes a debug log entry asynchronously (fire-and-forget).
|
||||
func (l *GatewayDebugLogger) Log(entry GatewayDebugLogEntry) {
|
||||
if !l.IsEnabled() {
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
_, err := l.db.ExecContext(ctx, insertDebugLogSQL,
|
||||
nullStr(entry.UpstreamRequestID),
|
||||
entry.AccountID,
|
||||
nullStr(entry.AccountEmail),
|
||||
nullStr(entry.AccountPlatform),
|
||||
coalesce(entry.EventType, "api_call"),
|
||||
nullStr(entry.Method),
|
||||
nullStr(entry.FullURL),
|
||||
mapToString(entry.RequestHeaders),
|
||||
bytesToString(entry.RequestBody),
|
||||
entry.RequestSize,
|
||||
entry.ResponseStatus,
|
||||
mapToString(entry.ResponseHeaders),
|
||||
nullStr(entry.ResponseBodyPreview),
|
||||
entry.ResponseSize,
|
||||
nullStr(entry.ModelRequested),
|
||||
nullStr(entry.ModelUpstream),
|
||||
entry.IsStream,
|
||||
entry.DurationMs,
|
||||
nullStr(entry.TLSProfile),
|
||||
nullStr(entry.ErrorMessage),
|
||||
)
|
||||
if err != nil {
|
||||
slog.Warn("gateway debug log write failed", "error", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// LogUpstreamRequest captures request+response from a gateway forward call.
|
||||
func (l *GatewayDebugLogger) LogUpstreamRequest(
|
||||
account *Account,
|
||||
upstreamReq *http.Request,
|
||||
upstreamBody []byte,
|
||||
resp *http.Response,
|
||||
responsePreview string,
|
||||
responseSize int,
|
||||
originalModel string,
|
||||
upstreamModel string,
|
||||
isStream bool,
|
||||
duration time.Duration,
|
||||
tlsProfile string,
|
||||
errMsg string,
|
||||
) {
|
||||
if !l.IsEnabled() {
|
||||
return
|
||||
}
|
||||
|
||||
entry := GatewayDebugLogEntry{
|
||||
AccountID: account.ID,
|
||||
AccountEmail: account.Name,
|
||||
AccountPlatform: account.Platform,
|
||||
EventType: "api_call",
|
||||
Method: upstreamReq.Method,
|
||||
FullURL: upstreamReq.URL.String(),
|
||||
RequestHeaders: extractHeaders(upstreamReq.Header),
|
||||
RequestBody: upstreamBody,
|
||||
RequestSize: len(upstreamBody),
|
||||
ModelRequested: originalModel,
|
||||
ModelUpstream: upstreamModel,
|
||||
IsStream: isStream,
|
||||
DurationMs: int(duration.Milliseconds()),
|
||||
TLSProfile: tlsProfile,
|
||||
ErrorMessage: errMsg,
|
||||
}
|
||||
|
||||
if resp != nil {
|
||||
entry.UpstreamRequestID = resp.Header.Get("x-request-id")
|
||||
entry.ResponseStatus = resp.StatusCode
|
||||
entry.ResponseHeaders = extractHeaders(resp.Header)
|
||||
entry.ResponseBodyPreview = debugTruncate(responsePreview, 4096)
|
||||
entry.ResponseSize = responseSize
|
||||
}
|
||||
|
||||
l.Log(entry)
|
||||
}
|
||||
|
||||
// LogOAuthRefresh logs an OAuth token refresh event.
|
||||
func (l *GatewayDebugLogger) LogOAuthRefresh(accountID int64, accountEmail string, duration time.Duration, errMsg string) {
|
||||
if !l.IsEnabled() {
|
||||
return
|
||||
}
|
||||
|
||||
l.Log(GatewayDebugLogEntry{
|
||||
AccountID: accountID,
|
||||
AccountEmail: accountEmail,
|
||||
EventType: "oauth_refresh",
|
||||
DurationMs: int(duration.Milliseconds()),
|
||||
ErrorMessage: errMsg,
|
||||
})
|
||||
}
|
||||
|
||||
// --- helpers ---
|
||||
|
||||
func extractHeaders(h http.Header) map[string]string {
|
||||
out := make(map[string]string, len(h))
|
||||
for k, vals := range h {
|
||||
lower := strings.ToLower(k)
|
||||
if lower == "authorization" || lower == "x-api-key" {
|
||||
out[k] = "[REDACTED]"
|
||||
continue
|
||||
}
|
||||
out[k] = strings.Join(vals, ", ")
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func debugTruncate(s string, maxLen int) string {
|
||||
if len(s) <= maxLen {
|
||||
return s
|
||||
}
|
||||
return s[:maxLen]
|
||||
}
|
||||
|
||||
func nullStr(s string) interface{} {
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// bytesToString converts raw bytes to string for TEXT column. No validation.
|
||||
func bytesToString(data []byte) interface{} {
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
return string(data)
|
||||
}
|
||||
|
||||
// mapToString serializes a map to JSON string for TEXT column.
|
||||
func mapToString(m map[string]string) interface{} {
|
||||
if len(m) == 0 {
|
||||
return nil
|
||||
}
|
||||
data, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return string(data)
|
||||
}
|
||||
|
||||
func coalesce(s, fallback string) string {
|
||||
if s == "" {
|
||||
return fallback
|
||||
}
|
||||
return s
|
||||
}
|
||||
@ -26,6 +26,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/telemetry"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
@ -1085,18 +1086,11 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
|
||||
}
|
||||
}
|
||||
|
||||
if gjson.GetBytes(out, "temperature").Exists() {
|
||||
if next, ok := deleteJSONPathBytes(out, "temperature"); ok {
|
||||
out = next
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
if gjson.GetBytes(out, "tool_choice").Exists() {
|
||||
if next, ok := deleteJSONPathBytes(out, "tool_choice"); ok {
|
||||
out = next
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
// 注意:不再剥离 temperature 和 tool_choice。
|
||||
// 真实 CLI 在 thinking 关闭时发 temperature:1,透传 tool_choice。
|
||||
// 之前无条件剥离会导致:
|
||||
// 1. temperature=0 的确定性请求被静默忽略
|
||||
// 2. tool_choice 强制工具调用被静默变成 auto 模式
|
||||
|
||||
if !modified {
|
||||
return body, modelID
|
||||
@ -4182,6 +4176,20 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
|
||||
}
|
||||
|
||||
// 注入 x-anthropic-billing-header attribution block(所有 OAuth 账号)
|
||||
// 真实 CLI 在 system prompt 的第一个 text block 注入此 billing header。
|
||||
// 用于 Anthropic 后端路由和验证。
|
||||
if account.IsOAuth() && !strings.Contains(strings.ToLower(reqModel), "haiku") {
|
||||
// 获取 CLI 版本:优先用指纹中的版本,回退到默认
|
||||
attrCLIVersion := claude.DefaultCLIVersion
|
||||
if fp := getHeaderRaw(c.Request.Header, "User-Agent"); fp != "" {
|
||||
if v := ExtractCLIVersion(fp); v != "" {
|
||||
attrCLIVersion = v
|
||||
}
|
||||
}
|
||||
body = injectAttributionBlock(body, attrCLIVersion)
|
||||
}
|
||||
|
||||
// 强制执行 cache_control 块数量限制(最多 4 个)
|
||||
body = enforceCacheControlLimit(body)
|
||||
|
||||
@ -4216,6 +4224,20 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Bootstrap 预热:模拟真实 CLI 启动时的 GET /api/claude_cli/bootstrap 调用
|
||||
// 真实 CLI 在首次 messages 请求前 fire-and-forget 调用此端点。
|
||||
if tokenType == "oauth" && token != "" {
|
||||
TriggerBootstrapIfNeeded(account.ID, token)
|
||||
// OTEL telemetry: emit pre-request events (tengu_started, tengu_api_query etc.)
|
||||
go telemetry.EmitPreRequest(
|
||||
fmt.Sprintf("%d", account.ID),
|
||||
token,
|
||||
token,
|
||||
reqModel,
|
||||
getHeaderRaw(c.Request.Header, "anthropic-beta"),
|
||||
)
|
||||
}
|
||||
|
||||
// 获取代理URL(自定义 base URL 模式下,proxy 通过 buildCustomRelayURL 作为查询参数传递)
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
@ -4631,6 +4653,18 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
|
||||
// 处理正常响应
|
||||
|
||||
// OTEL telemetry: emit post-request events (fire-and-forget)
|
||||
if tokenType == "oauth" && token != "" {
|
||||
go telemetry.EmitPostRequest(
|
||||
fmt.Sprintf("%d", account.ID),
|
||||
token,
|
||||
token,
|
||||
reqModel,
|
||||
getHeaderRaw(c.Request.Header, "anthropic-beta"),
|
||||
resp.StatusCode,
|
||||
)
|
||||
}
|
||||
|
||||
// 触发上游接受回调(提前释放串行锁,不等流完成)
|
||||
if parsed.OnUpstreamAccepted != nil {
|
||||
parsed.OnUpstreamAccepted()
|
||||
@ -5821,13 +5855,37 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
}
|
||||
}
|
||||
|
||||
// 同步 X-Claude-Code-Session-Id 头:取 body 中已处理的 metadata.user_id 的 session_id 覆盖
|
||||
// X-Claude-Code-Session-Id 头处理:
|
||||
// 1. 客户端已提供 → 同步为 body 中 metadata.user_id 的 session_id
|
||||
// 2. 客户端未提供(mimic 模式)→ 生成确定性 per-account session UUID
|
||||
// 真实 CLI 每个请求都携带此 header(per-process UUID)。
|
||||
if sessionHeader := getHeaderRaw(req.Header, "X-Claude-Code-Session-Id"); sessionHeader != "" {
|
||||
if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" {
|
||||
if parsed := ParseMetadataUserID(uid); parsed != nil {
|
||||
setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID)
|
||||
}
|
||||
}
|
||||
} else if tokenType == "oauth" {
|
||||
// mimic 模式:生成 session-id
|
||||
var sessionID string
|
||||
if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" {
|
||||
if parsed := ParseMetadataUserID(uid); parsed != nil {
|
||||
sessionID = parsed.SessionID
|
||||
}
|
||||
}
|
||||
if sessionID == "" {
|
||||
salt := ""
|
||||
if s.cfg != nil {
|
||||
salt = s.cfg.Gateway.InstanceSalt
|
||||
}
|
||||
sessionID = generateSessionIDForAccount(salt, account.ID)
|
||||
}
|
||||
setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", sessionID)
|
||||
}
|
||||
|
||||
// x-client-request-id: 真实 CLI 每个请求生成新 UUID(仅 1P)。
|
||||
if getHeaderRaw(req.Header, "x-client-request-id") == "" && tokenType == "oauth" {
|
||||
setHeaderRaw(req.Header, "x-client-request-id", uuid.New().String())
|
||||
}
|
||||
|
||||
// === DEBUG: 打印上游转发请求(headers + body 摘要),与 CLIENT_ORIGINAL 对比 ===
|
||||
@ -8549,13 +8607,33 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
||||
}
|
||||
}
|
||||
|
||||
// 同步 X-Claude-Code-Session-Id 头:取 body 中已处理的 metadata.user_id 的 session_id 覆盖
|
||||
// X-Claude-Code-Session-Id 头处理(count_tokens 路径)
|
||||
if sessionHeader := getHeaderRaw(req.Header, "X-Claude-Code-Session-Id"); sessionHeader != "" {
|
||||
if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" {
|
||||
if parsed := ParseMetadataUserID(uid); parsed != nil {
|
||||
setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID)
|
||||
}
|
||||
}
|
||||
} else if tokenType == "oauth" {
|
||||
var sessionID string
|
||||
if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" {
|
||||
if parsed := ParseMetadataUserID(uid); parsed != nil {
|
||||
sessionID = parsed.SessionID
|
||||
}
|
||||
}
|
||||
if sessionID == "" {
|
||||
salt := ""
|
||||
if s.cfg != nil {
|
||||
salt = s.cfg.Gateway.InstanceSalt
|
||||
}
|
||||
sessionID = generateSessionIDForAccount(salt, account.ID)
|
||||
}
|
||||
setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", sessionID)
|
||||
}
|
||||
|
||||
// x-client-request-id(count_tokens 路径)
|
||||
if getHeaderRaw(req.Header, "x-client-request-id") == "" && tokenType == "oauth" {
|
||||
setHeaderRaw(req.Header, "x-client-request-id", uuid.New().String())
|
||||
}
|
||||
|
||||
if c != nil && tokenType == "oauth" {
|
||||
|
||||
@ -16,6 +16,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/logredact"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -463,7 +464,7 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
|
||||
proxyURL = proxy.URL()
|
||||
}
|
||||
}
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ProxyURL: %s", proxyURL)
|
||||
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ProxyURL: %s", logredact.RedactProxyURL(proxyURL))
|
||||
|
||||
redirectURI := session.RedirectURI
|
||||
|
||||
|
||||
@ -26,13 +26,13 @@ var (
|
||||
|
||||
// 默认指纹值(当客户端未提供时使用)
|
||||
var defaultFingerprint = Fingerprint{
|
||||
UserAgent: "claude-cli/2.1.22 (external, cli)",
|
||||
UserAgent: "claude-cli/2.1.88 (external, cli)",
|
||||
StainlessLang: "js",
|
||||
StainlessPackageVersion: "0.70.0",
|
||||
StainlessOS: "Linux",
|
||||
StainlessPackageVersion: "0.74.0",
|
||||
StainlessOS: "MacOS",
|
||||
StainlessArch: "arm64",
|
||||
StainlessRuntime: "node",
|
||||
StainlessRuntimeVersion: "v24.13.0",
|
||||
StainlessRuntimeVersion: "v24.3.0",
|
||||
}
|
||||
|
||||
// Fingerprint represents account fingerprint data
|
||||
@ -63,7 +63,8 @@ type IdentityCache interface {
|
||||
|
||||
// IdentityService 管理OAuth账号的请求身份指纹
|
||||
type IdentityService struct {
|
||||
cache IdentityCache
|
||||
cache IdentityCache
|
||||
instanceSalt string // 实例级隔离盐值,不同 sub2api 实例产生不同的 hash 输出
|
||||
}
|
||||
|
||||
// NewIdentityService 创建新的IdentityService
|
||||
@ -242,8 +243,14 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI
|
||||
|
||||
sessionTail := parsed.SessionID // 原始session UUID
|
||||
|
||||
// 生成新的session hash: SHA256(accountID::sessionTail) -> UUID格式
|
||||
seed := fmt.Sprintf("%d::%s", accountID, sessionTail)
|
||||
// 生成新的session hash: SHA256(salt::accountID::sessionTail) -> UUID格式
|
||||
// instanceSalt 使不同 sub2api 实例对相同输入产生不同的 hash
|
||||
var seed string
|
||||
if s.instanceSalt != "" {
|
||||
seed = fmt.Sprintf("%s::%d::%s", s.instanceSalt, accountID, sessionTail)
|
||||
} else {
|
||||
seed = fmt.Sprintf("%d::%s", accountID, sessionTail)
|
||||
}
|
||||
newSessionHash := generateUUIDFromSeed(seed)
|
||||
|
||||
// 根据客户端版本选择输出格式
|
||||
|
||||
39
backend/internal/service/identity_service_antigravity.go
Normal file
39
backend/internal/service/identity_service_antigravity.go
Normal file
@ -0,0 +1,39 @@
|
||||
package service
|
||||
|
||||
// ==============================================================
|
||||
// antigravity — identity_service 扩展
|
||||
//
|
||||
// 此文件包含 Antigravity fork 对 IdentityService 的扩展,
|
||||
// 新增了实例级隔离盐值和指纹默认值覆盖功能。
|
||||
//
|
||||
// 对上游文件 identity_service.go 的最小化改动:
|
||||
// - defaultFingerprint 版本号更新
|
||||
// - IdentityService struct 新增 instanceSalt 字段
|
||||
// ==============================================================
|
||||
|
||||
// ApplyDefaultFingerprintOverrides 用配置覆盖 identity_service 的默认指纹
|
||||
// 允许不同部署实例设置不同的 CLI/SDK 版本号,避免所有实例指纹相同
|
||||
func ApplyDefaultFingerprintOverrides(cliVersion, pkgVersion, runtimeVersion, os_, arch string) {
|
||||
if cliVersion != "" {
|
||||
defaultFingerprint.UserAgent = "claude-cli/" + cliVersion + " (external, cli)"
|
||||
}
|
||||
if pkgVersion != "" {
|
||||
defaultFingerprint.StainlessPackageVersion = pkgVersion
|
||||
}
|
||||
if runtimeVersion != "" {
|
||||
defaultFingerprint.StainlessRuntimeVersion = runtimeVersion
|
||||
}
|
||||
if os_ != "" {
|
||||
defaultFingerprint.StainlessOS = os_
|
||||
}
|
||||
if arch != "" {
|
||||
defaultFingerprint.StainlessArch = arch
|
||||
}
|
||||
}
|
||||
|
||||
// NewIdentityServiceWithSalt 创建带实例盐值的 IdentityService
|
||||
// 实例盐值用于 user_id 重写时的 session hash 混淆,
|
||||
// 使不同 sub2api 实例对相同输入产生不同的 hash 输出,增加隔离性
|
||||
func NewIdentityServiceWithSalt(cache IdentityCache, salt string) *IdentityService {
|
||||
return &IdentityService{cache: cache, instanceSalt: salt}
|
||||
}
|
||||
225
backend/internal/service/lspool_bootstrap_service.go
Normal file
225
backend/internal/service/lspool_bootstrap_service.go
Normal file
@ -0,0 +1,225 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/lspool"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultLSPoolBootstrapConcurrency = 4
|
||||
)
|
||||
|
||||
type lsBootstrapAccountReader interface {
|
||||
ListByPlatform(ctx context.Context, platform string) ([]Account, error)
|
||||
}
|
||||
|
||||
// LSPoolBootstrapService pre-creates LS workers for eligible Antigravity accounts on startup.
|
||||
type LSPoolBootstrapService struct {
|
||||
accountReader lsBootstrapAccountReader
|
||||
backend lspool.Backend
|
||||
cfg *config.Config
|
||||
logger *slog.Logger
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
once sync.Once
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
func NewLSPoolBootstrapService(accountReader lsBootstrapAccountReader, backend lspool.Backend, cfg *config.Config) *LSPoolBootstrapService {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &LSPoolBootstrapService{
|
||||
accountReader: accountReader,
|
||||
backend: backend,
|
||||
cfg: cfg,
|
||||
logger: slog.Default().With("component", "service.lspool_bootstrap"),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
// ProvideLSPoolBootstrapService creates and starts the LS pool bootstrap worker.
|
||||
func ProvideLSPoolBootstrapService(accountRepo AccountRepository, cfg *config.Config) *LSPoolBootstrapService {
|
||||
svc := NewLSPoolBootstrapService(accountRepo, lspool.GlobalPool(cfg), cfg)
|
||||
svc.Start()
|
||||
return svc
|
||||
}
|
||||
|
||||
func (s *LSPoolBootstrapService) Start() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.once.Do(func() {
|
||||
if s.backend == nil {
|
||||
if lspool.IsLSModeEnabled() {
|
||||
s.logger.Warn("startup bootstrap skipped: ls backend unavailable")
|
||||
}
|
||||
return
|
||||
}
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
s.bootstrap(s.ctx)
|
||||
}()
|
||||
})
|
||||
}
|
||||
|
||||
func (s *LSPoolBootstrapService) Stop() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.cancel()
|
||||
s.wg.Wait()
|
||||
}
|
||||
|
||||
func (s *LSPoolBootstrapService) bootstrap(ctx context.Context) {
|
||||
if s.backend == nil || s.accountReader == nil {
|
||||
return
|
||||
}
|
||||
|
||||
accounts, err := s.accountReader.ListByPlatform(ctx, PlatformAntigravity)
|
||||
if err != nil {
|
||||
s.logger.Warn("load antigravity accounts for ls bootstrap failed", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
candidates := make([]Account, 0, len(accounts))
|
||||
for i := range accounts {
|
||||
if shouldBootstrapLSPoolAccount(&accounts[i], now) {
|
||||
candidates = append(candidates, accounts[i])
|
||||
}
|
||||
}
|
||||
|
||||
if len(candidates) == 0 {
|
||||
s.logger.Info("startup bootstrap skipped: no eligible antigravity accounts")
|
||||
return
|
||||
}
|
||||
|
||||
s.logger.Info("starting ls worker bootstrap",
|
||||
"accounts_total", len(accounts),
|
||||
"accounts_eligible", len(candidates),
|
||||
"concurrency", s.bootstrapConcurrency())
|
||||
|
||||
var (
|
||||
mu sync.Mutex
|
||||
started int
|
||||
failed int
|
||||
)
|
||||
sem := make(chan struct{}, s.bootstrapConcurrency())
|
||||
var wg sync.WaitGroup
|
||||
|
||||
loop:
|
||||
for i := range candidates {
|
||||
account := candidates[i]
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
break loop
|
||||
case sem <- struct{}{}:
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func(account Account) {
|
||||
defer wg.Done()
|
||||
defer func() { <-sem }()
|
||||
|
||||
if err := s.bootstrapAccount(&account); err != nil {
|
||||
mu.Lock()
|
||||
failed++
|
||||
mu.Unlock()
|
||||
s.logger.Warn("bootstrap ls worker failed", "account_id", account.ID, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
started++
|
||||
mu.Unlock()
|
||||
s.logger.Info("bootstrap ls worker ready", "account_id", account.ID)
|
||||
}(account)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
s.logger.Info("ls worker bootstrap completed",
|
||||
"accounts_total", len(accounts),
|
||||
"accounts_eligible", len(candidates),
|
||||
"workers_ready", started,
|
||||
"workers_failed", failed,
|
||||
"canceled", ctx.Err() != nil)
|
||||
}
|
||||
|
||||
func (s *LSPoolBootstrapService) bootstrapAccount(account *Account) error {
|
||||
if s.backend == nil {
|
||||
return fmt.Errorf("ls backend unavailable")
|
||||
}
|
||||
if account == nil {
|
||||
return fmt.Errorf("account is nil")
|
||||
}
|
||||
|
||||
accountKey := strconv.FormatInt(account.ID, 10)
|
||||
accessToken := strings.TrimSpace(account.GetCredential("access_token"))
|
||||
if accessToken == "" {
|
||||
return fmt.Errorf("missing access token")
|
||||
}
|
||||
refreshToken := strings.TrimSpace(account.GetCredential("refresh_token"))
|
||||
|
||||
expiresAt := time.Time{}
|
||||
if ts := account.GetCredentialAsTime("expires_at"); ts != nil {
|
||||
expiresAt = ts.UTC()
|
||||
}
|
||||
|
||||
s.backend.SetAccountToken(accountKey, accessToken, refreshToken, expiresAt)
|
||||
availableCredits, minimumCreditAmount := resolveLSPoolModelCreditsState(account)
|
||||
s.backend.SetAccountModelCredits(accountKey, account.IsOveragesEnabled(), availableCredits, minimumCreditAmount)
|
||||
|
||||
proxyURL := ""
|
||||
if account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
if _, err := s.backend.GetOrCreate(accountKey, "", proxyURL); err != nil {
|
||||
return fmt.Errorf("get or create ls worker: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *LSPoolBootstrapService) bootstrapConcurrency() int {
|
||||
parallelism := defaultLSPoolBootstrapConcurrency
|
||||
if s.cfg != nil && s.cfg.Gateway.AntigravityLSWorker.MaxActive > 0 && s.cfg.Gateway.AntigravityLSWorker.MaxActive < parallelism {
|
||||
parallelism = s.cfg.Gateway.AntigravityLSWorker.MaxActive
|
||||
}
|
||||
if parallelism < 1 {
|
||||
return 1
|
||||
}
|
||||
return parallelism
|
||||
}
|
||||
|
||||
func shouldBootstrapLSPoolAccount(account *Account, now time.Time) bool {
|
||||
if account == nil {
|
||||
return false
|
||||
}
|
||||
if account.Platform != PlatformAntigravity {
|
||||
return false
|
||||
}
|
||||
if account.Type != AccountTypeOAuth {
|
||||
return false
|
||||
}
|
||||
if account.Status != StatusActive || !account.Schedulable {
|
||||
return false
|
||||
}
|
||||
if account.AutoPauseOnExpired && account.ExpiresAt != nil && !now.Before(*account.ExpiresAt) {
|
||||
return false
|
||||
}
|
||||
if strings.TrimSpace(account.GetCredential("access_token")) == "" {
|
||||
return false
|
||||
}
|
||||
return strings.TrimSpace(account.GetCredential("project_id")) != ""
|
||||
}
|
||||
262
backend/internal/service/lspool_bootstrap_service_test.go
Normal file
262
backend/internal/service/lspool_bootstrap_service_test.go
Normal file
@ -0,0 +1,262 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/lspool"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type fakeLSBootstrapAccountReader struct {
|
||||
mu sync.Mutex
|
||||
accounts []Account
|
||||
err error
|
||||
platforms []string
|
||||
}
|
||||
|
||||
func (f *fakeLSBootstrapAccountReader) ListByPlatform(_ context.Context, platform string) ([]Account, error) {
|
||||
f.mu.Lock()
|
||||
f.platforms = append(f.platforms, platform)
|
||||
accounts := append([]Account(nil), f.accounts...)
|
||||
err := f.err
|
||||
f.mu.Unlock()
|
||||
return accounts, err
|
||||
}
|
||||
|
||||
type fakeLSPoolBackend struct {
|
||||
mu sync.Mutex
|
||||
tokenCalls map[string]fakeLSPoolTokenCall
|
||||
creditCalls map[string]fakeLSPoolCreditCall
|
||||
getCalls []fakeLSPoolGetCall
|
||||
getErrs map[string]error
|
||||
}
|
||||
|
||||
type fakeLSPoolTokenCall struct {
|
||||
AccessToken string
|
||||
RefreshToken string
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
type fakeLSPoolCreditCall struct {
|
||||
UseAICredits bool
|
||||
AvailableCredits *int32
|
||||
MinimumCreditAmount *int32
|
||||
}
|
||||
|
||||
type fakeLSPoolGetCall struct {
|
||||
AccountID string
|
||||
RoutingKey string
|
||||
ProxyURL string
|
||||
}
|
||||
|
||||
func newFakeLSPoolBackend() *fakeLSPoolBackend {
|
||||
return &fakeLSPoolBackend{
|
||||
tokenCalls: make(map[string]fakeLSPoolTokenCall),
|
||||
creditCalls: make(map[string]fakeLSPoolCreditCall),
|
||||
getErrs: make(map[string]error),
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeLSPoolBackend) GetOrCreate(accountID, routingKey string, proxyURL ...string) (*lspool.Instance, error) {
|
||||
rawProxy := ""
|
||||
if len(proxyURL) > 0 {
|
||||
rawProxy = proxyURL[0]
|
||||
}
|
||||
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.getCalls = append(f.getCalls, fakeLSPoolGetCall{
|
||||
AccountID: accountID,
|
||||
RoutingKey: routingKey,
|
||||
ProxyURL: rawProxy,
|
||||
})
|
||||
if err := f.getErrs[accountID]; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &lspool.Instance{AccountID: accountID}, nil
|
||||
}
|
||||
|
||||
func (f *fakeLSPoolBackend) SetAccountToken(accountID, accessToken, refreshToken string, expiresAt time.Time) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.tokenCalls[accountID] = fakeLSPoolTokenCall{
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
ExpiresAt: expiresAt,
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeLSPoolBackend) SetAccountModelCredits(accountID string, useAICredits bool, availableCredits, minimumCreditAmountForUsage *int32) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.creditCalls[accountID] = fakeLSPoolCreditCall{
|
||||
UseAICredits: useAICredits,
|
||||
AvailableCredits: copyInt32Ptr(availableCredits),
|
||||
MinimumCreditAmount: copyInt32Ptr(minimumCreditAmountForUsage),
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeLSPoolBackend) Stats() map[string]any { return nil }
|
||||
|
||||
func (f *fakeLSPoolBackend) Close() {}
|
||||
|
||||
func copyInt32Ptr(v *int32) *int32 {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
cp := *v
|
||||
return &cp
|
||||
}
|
||||
|
||||
func TestLSPoolBootstrapServiceBootstrapEligibleAccounts(t *testing.T) {
|
||||
expiresAt := time.Now().Add(2 * time.Hour).UTC().Truncate(time.Second)
|
||||
expiredAt := time.Now().Add(-2 * time.Hour)
|
||||
reader := &fakeLSBootstrapAccountReader{
|
||||
accounts: []Account{
|
||||
{
|
||||
ID: 101,
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "token-101",
|
||||
"refresh_token": "refresh-101",
|
||||
"expires_at": expiresAt.Format(time.RFC3339),
|
||||
"project_id": "proj-101",
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"allow_overages": true,
|
||||
"ai_credits": []any{
|
||||
map[string]any{
|
||||
"credit_type": "GOOGLE_ONE_AI",
|
||||
"amount": 120,
|
||||
"minimum_balance": 55,
|
||||
},
|
||||
},
|
||||
},
|
||||
Proxy: &Proxy{
|
||||
Protocol: "socks5h",
|
||||
Host: "127.0.0.1",
|
||||
Port: 1080,
|
||||
Username: "alice",
|
||||
Password: "secret",
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: 102,
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Schedulable: false,
|
||||
Credentials: map[string]any{"access_token": "token-102", "project_id": "proj-102"},
|
||||
},
|
||||
{
|
||||
ID: 103,
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Credentials: map[string]any{"access_token": "token-103"},
|
||||
},
|
||||
{
|
||||
ID: 104,
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
AutoPauseOnExpired: true,
|
||||
ExpiresAt: &expiredAt,
|
||||
Credentials: map[string]any{"access_token": "token-104", "project_id": "proj-104"},
|
||||
},
|
||||
{
|
||||
ID: 106,
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeUpstream,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Credentials: map[string]any{"access_token": "token-106", "project_id": "proj-106"},
|
||||
},
|
||||
{
|
||||
ID: 105,
|
||||
Platform: PlatformOpenAI,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Credentials: map[string]any{"access_token": "token-105"},
|
||||
},
|
||||
},
|
||||
}
|
||||
backend := newFakeLSPoolBackend()
|
||||
svc := NewLSPoolBootstrapService(reader, backend, &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
AntigravityLSWorker: config.GatewayAntigravityLSWorkerConfig{MaxActive: 3},
|
||||
},
|
||||
})
|
||||
|
||||
svc.bootstrap(context.Background())
|
||||
|
||||
require.Equal(t, []string{PlatformAntigravity}, reader.platforms)
|
||||
|
||||
require.Len(t, backend.getCalls, 1)
|
||||
require.Equal(t, fakeLSPoolGetCall{
|
||||
AccountID: "101",
|
||||
RoutingKey: "",
|
||||
ProxyURL: "socks5h://alice:secret@127.0.0.1:1080",
|
||||
}, backend.getCalls[0])
|
||||
|
||||
tokenCall, ok := backend.tokenCalls["101"]
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "token-101", tokenCall.AccessToken)
|
||||
require.Equal(t, "refresh-101", tokenCall.RefreshToken)
|
||||
require.Equal(t, expiresAt, tokenCall.ExpiresAt)
|
||||
|
||||
creditCall, ok := backend.creditCalls["101"]
|
||||
require.True(t, ok)
|
||||
require.True(t, creditCall.UseAICredits)
|
||||
require.NotNil(t, creditCall.AvailableCredits)
|
||||
require.Equal(t, int32(120), *creditCall.AvailableCredits)
|
||||
require.NotNil(t, creditCall.MinimumCreditAmount)
|
||||
require.Equal(t, int32(55), *creditCall.MinimumCreditAmount)
|
||||
|
||||
require.NotContains(t, backend.tokenCalls, "102")
|
||||
require.NotContains(t, backend.tokenCalls, "103")
|
||||
require.NotContains(t, backend.tokenCalls, "104")
|
||||
require.NotContains(t, backend.tokenCalls, "106")
|
||||
}
|
||||
|
||||
func TestLSPoolBootstrapServiceBootstrapContinuesOnWorkerFailure(t *testing.T) {
|
||||
reader := &fakeLSBootstrapAccountReader{
|
||||
accounts: []Account{
|
||||
{
|
||||
ID: 201,
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Credentials: map[string]any{"access_token": "token-201", "project_id": "proj-201"},
|
||||
},
|
||||
{
|
||||
ID: 202,
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Credentials: map[string]any{"access_token": "token-202", "project_id": "proj-202"},
|
||||
},
|
||||
},
|
||||
}
|
||||
backend := newFakeLSPoolBackend()
|
||||
backend.getErrs["201"] = errors.New("create failed")
|
||||
|
||||
svc := NewLSPoolBootstrapService(reader, backend, &config.Config{})
|
||||
svc.bootstrap(context.Background())
|
||||
|
||||
require.Len(t, backend.getCalls, 2)
|
||||
require.Contains(t, backend.tokenCalls, "201")
|
||||
require.Contains(t, backend.tokenCalls, "202")
|
||||
}
|
||||
@ -471,6 +471,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewCRSSyncService,
|
||||
ProvideUpdateService,
|
||||
ProvideTokenRefreshService,
|
||||
ProvideLSPoolBootstrapService,
|
||||
ProvideAccountExpiryService,
|
||||
ProvideSubscriptionExpiryService,
|
||||
ProvideTimingWheelService,
|
||||
|
||||
@ -2,6 +2,7 @@ package logredact
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
@ -230,3 +231,19 @@ func isSensitiveKey(key string, keys map[string]struct{}) bool {
|
||||
func normalizeKey(key string) string {
|
||||
return strings.ToLower(strings.TrimSpace(key))
|
||||
}
|
||||
|
||||
// RedactProxyURL strips userinfo (username:password) from a proxy URL string
|
||||
// for safe logging. Returns the input unchanged if it's not a valid URL.
|
||||
func RedactProxyURL(raw string) string {
|
||||
if raw == "" {
|
||||
return ""
|
||||
}
|
||||
parsed, err := url.Parse(raw)
|
||||
if err != nil {
|
||||
return "<redacted-proxy-url>"
|
||||
}
|
||||
if parsed.User != nil {
|
||||
parsed.User = nil
|
||||
}
|
||||
return parsed.String()
|
||||
}
|
||||
|
||||
@ -38,6 +38,34 @@ func TestRedactText_GOCSPX(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedactProxyURL_StripsUserinfo(t *testing.T) {
|
||||
in := "http://user:pass@proxy.example.com:8080"
|
||||
out := RedactProxyURL(in)
|
||||
if out != "http://proxy.example.com:8080" {
|
||||
t.Fatalf("expected userinfo stripped, got %q", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedactProxyURL_EmptyString(t *testing.T) {
|
||||
if got := RedactProxyURL(""); got != "" {
|
||||
t.Fatalf("expected empty string, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedactProxyURL_NoUserinfo(t *testing.T) {
|
||||
in := "socks5h://proxy.example.com:1080"
|
||||
out := RedactProxyURL(in)
|
||||
if out != in {
|
||||
t.Fatalf("expected unchanged URL, got %q", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedactProxyURL_InvalidURL(t *testing.T) {
|
||||
if got := RedactProxyURL("://invalid"); got != "<redacted-proxy-url>" {
|
||||
t.Fatalf("unexpected invalid URL redaction result: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedactText_ExtraKeyCacheUsesNormalizedSortedKey(t *testing.T) {
|
||||
clearExtraTextPatternCache()
|
||||
|
||||
|
||||
37
backend/migrations/082_create_gateway_debug_logs.sql
Normal file
37
backend/migrations/082_create_gateway_debug_logs.sql
Normal file
@ -0,0 +1,37 @@
|
||||
CREATE TABLE IF NOT EXISTS gateway_debug_logs (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
|
||||
upstream_request_id TEXT,
|
||||
account_id BIGINT,
|
||||
account_email TEXT,
|
||||
account_platform TEXT,
|
||||
|
||||
event_type TEXT NOT NULL DEFAULT 'api_call',
|
||||
|
||||
method TEXT,
|
||||
full_url TEXT,
|
||||
request_headers TEXT,
|
||||
request_body TEXT,
|
||||
request_size INTEGER,
|
||||
|
||||
response_status INTEGER,
|
||||
response_headers TEXT,
|
||||
response_body_preview TEXT,
|
||||
response_size INTEGER,
|
||||
|
||||
model_requested TEXT,
|
||||
model_upstream TEXT,
|
||||
is_stream BOOLEAN DEFAULT FALSE,
|
||||
duration_ms INTEGER,
|
||||
|
||||
tls_profile TEXT,
|
||||
|
||||
error_message TEXT,
|
||||
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_gdl_account_id ON gateway_debug_logs (account_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_gdl_created_at ON gateway_debug_logs (created_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_gdl_event_type ON gateway_debug_logs (event_type);
|
||||
CREATE INDEX IF NOT EXISTS idx_gdl_model ON gateway_debug_logs (model_requested);
|
||||
70
backend/migrations/083_reconcile_gateway_debug_logs.sql
Normal file
70
backend/migrations/083_reconcile_gateway_debug_logs.sql
Normal file
@ -0,0 +1,70 @@
|
||||
CREATE TABLE IF NOT EXISTS gateway_debug_logs (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
upstream_request_id TEXT,
|
||||
account_id BIGINT,
|
||||
account_email TEXT,
|
||||
account_platform TEXT,
|
||||
event_type TEXT NOT NULL DEFAULT 'api_call',
|
||||
method TEXT,
|
||||
full_url TEXT,
|
||||
request_headers TEXT,
|
||||
request_body TEXT,
|
||||
request_size INTEGER,
|
||||
response_status INTEGER,
|
||||
response_headers TEXT,
|
||||
response_body_preview TEXT,
|
||||
response_size INTEGER,
|
||||
model_requested TEXT,
|
||||
model_upstream TEXT,
|
||||
is_stream BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
duration_ms INTEGER,
|
||||
tls_profile TEXT,
|
||||
error_message TEXT,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
ALTER TABLE gateway_debug_logs ADD COLUMN IF NOT EXISTS upstream_request_id TEXT;
|
||||
ALTER TABLE gateway_debug_logs ADD COLUMN IF NOT EXISTS account_id BIGINT;
|
||||
ALTER TABLE gateway_debug_logs ADD COLUMN IF NOT EXISTS account_email TEXT;
|
||||
ALTER TABLE gateway_debug_logs ADD COLUMN IF NOT EXISTS account_platform TEXT;
|
||||
ALTER TABLE gateway_debug_logs ADD COLUMN IF NOT EXISTS event_type TEXT;
|
||||
ALTER TABLE gateway_debug_logs ADD COLUMN IF NOT EXISTS method TEXT;
|
||||
ALTER TABLE gateway_debug_logs ADD COLUMN IF NOT EXISTS full_url TEXT;
|
||||
ALTER TABLE gateway_debug_logs ADD COLUMN IF NOT EXISTS request_headers TEXT;
|
||||
ALTER TABLE gateway_debug_logs ADD COLUMN IF NOT EXISTS request_body TEXT;
|
||||
ALTER TABLE gateway_debug_logs ADD COLUMN IF NOT EXISTS request_size INTEGER;
|
||||
ALTER TABLE gateway_debug_logs ADD COLUMN IF NOT EXISTS response_status INTEGER;
|
||||
ALTER TABLE gateway_debug_logs ADD COLUMN IF NOT EXISTS response_headers TEXT;
|
||||
ALTER TABLE gateway_debug_logs ADD COLUMN IF NOT EXISTS response_body_preview TEXT;
|
||||
ALTER TABLE gateway_debug_logs ADD COLUMN IF NOT EXISTS response_size INTEGER;
|
||||
ALTER TABLE gateway_debug_logs ADD COLUMN IF NOT EXISTS model_requested TEXT;
|
||||
ALTER TABLE gateway_debug_logs ADD COLUMN IF NOT EXISTS model_upstream TEXT;
|
||||
ALTER TABLE gateway_debug_logs ADD COLUMN IF NOT EXISTS is_stream BOOLEAN;
|
||||
ALTER TABLE gateway_debug_logs ADD COLUMN IF NOT EXISTS duration_ms INTEGER;
|
||||
ALTER TABLE gateway_debug_logs ADD COLUMN IF NOT EXISTS tls_profile TEXT;
|
||||
ALTER TABLE gateway_debug_logs ADD COLUMN IF NOT EXISTS error_message TEXT;
|
||||
ALTER TABLE gateway_debug_logs ADD COLUMN IF NOT EXISTS created_at TIMESTAMPTZ;
|
||||
|
||||
UPDATE gateway_debug_logs
|
||||
SET event_type = 'api_call'
|
||||
WHERE event_type IS NULL;
|
||||
|
||||
UPDATE gateway_debug_logs
|
||||
SET is_stream = FALSE
|
||||
WHERE is_stream IS NULL;
|
||||
|
||||
UPDATE gateway_debug_logs
|
||||
SET created_at = NOW()
|
||||
WHERE created_at IS NULL;
|
||||
|
||||
ALTER TABLE gateway_debug_logs ALTER COLUMN event_type SET DEFAULT 'api_call';
|
||||
ALTER TABLE gateway_debug_logs ALTER COLUMN event_type SET NOT NULL;
|
||||
ALTER TABLE gateway_debug_logs ALTER COLUMN is_stream SET DEFAULT FALSE;
|
||||
ALTER TABLE gateway_debug_logs ALTER COLUMN is_stream SET NOT NULL;
|
||||
ALTER TABLE gateway_debug_logs ALTER COLUMN created_at SET DEFAULT NOW();
|
||||
ALTER TABLE gateway_debug_logs ALTER COLUMN created_at SET NOT NULL;
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_gdl_account_id ON gateway_debug_logs (account_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_gdl_created_at ON gateway_debug_logs (created_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_gdl_event_type ON gateway_debug_logs (event_type);
|
||||
CREATE INDEX IF NOT EXISTS idx_gdl_model ON gateway_debug_logs (model_requested);
|
||||
@ -2500,6 +2500,57 @@
|
||||
"supports_vision": true,
|
||||
"supports_web_search": true
|
||||
},
|
||||
"gemini-3-flash": {
|
||||
"cache_read_input_token_cost": 5e-08,
|
||||
"cache_read_input_token_cost_priority": 9e-08,
|
||||
"input_cost_per_audio_token": 1e-06,
|
||||
"input_cost_per_audio_token_priority": 1.8e-06,
|
||||
"input_cost_per_token": 5e-07,
|
||||
"input_cost_per_token_priority": 9e-07,
|
||||
"litellm_provider": "vertex_ai-language-models",
|
||||
"max_audio_length_hours": 8.4,
|
||||
"max_audio_per_prompt": 1,
|
||||
"max_images_per_prompt": 3000,
|
||||
"max_input_tokens": 1048576,
|
||||
"max_output_tokens": 65535,
|
||||
"max_pdf_size_mb": 30,
|
||||
"max_tokens": 65535,
|
||||
"max_video_length": 1,
|
||||
"max_videos_per_prompt": 10,
|
||||
"mode": "chat",
|
||||
"output_cost_per_reasoning_token": 3e-06,
|
||||
"output_cost_per_token": 3e-06,
|
||||
"output_cost_per_token_priority": 5.4e-06,
|
||||
"source": "https://ai.google.dev/pricing/gemini-3",
|
||||
"supported_endpoints": [
|
||||
"/v1/chat/completions",
|
||||
"/v1/completions",
|
||||
"/v1/batch"
|
||||
],
|
||||
"supported_modalities": [
|
||||
"text",
|
||||
"image",
|
||||
"audio",
|
||||
"video"
|
||||
],
|
||||
"supported_output_modalities": [
|
||||
"text"
|
||||
],
|
||||
"supports_audio_output": false,
|
||||
"supports_function_calling": true,
|
||||
"supports_native_streaming": true,
|
||||
"supports_parallel_function_calling": true,
|
||||
"supports_pdf_input": true,
|
||||
"supports_prompt_caching": true,
|
||||
"supports_reasoning": true,
|
||||
"supports_response_schema": true,
|
||||
"supports_service_tier": true,
|
||||
"supports_system_messages": true,
|
||||
"supports_tool_choice": true,
|
||||
"supports_url_context": true,
|
||||
"supports_vision": true,
|
||||
"supports_web_search": true
|
||||
},
|
||||
"gemini-3-flash-preview": {
|
||||
"cache_read_input_token_cost": 5e-08,
|
||||
"cache_read_input_token_cost_priority": 9e-08,
|
||||
|
||||
@ -20,6 +20,9 @@ SERVER_PORT=8080
|
||||
# Server mode: release or debug
|
||||
SERVER_MODE=release
|
||||
|
||||
# Main application image override
|
||||
SUB2API_IMAGE=zfc931912343/sub2api:latest
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging Configuration
|
||||
# 日志配置
|
||||
@ -389,3 +392,32 @@ OPS_ENABLED=true
|
||||
# Leave empty for direct connection (recommended for overseas servers)
|
||||
# 留空表示直连(适用于海外服务器)
|
||||
UPDATE_PROXY_URL=
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Language Server Pool Mode (Enhanced Security)
|
||||
# -----------------------------------------------------------------------------
|
||||
# Enable to route requests through real AntiGravity LS binary
|
||||
# Makes upstream traffic indistinguishable from real IDE
|
||||
# ANTIGRAVITY_LS_MODE=true
|
||||
# LS replicas per account. Default is 5.
|
||||
# Increase for higher concurrency, but each replica is an extra LS process.
|
||||
# ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT=5
|
||||
# Optional global fallback proxy for accounts without dedicated LS proxy.
|
||||
# Must be socks5/socks5h in worker mode.
|
||||
ANTIGRAVITY_LS_PROXY=
|
||||
# LS routing strategy (default js-parity)
|
||||
ANTIGRAVITY_LS_STRATEGY=js-parity
|
||||
# Dynamic LS worker container image. Build/pull this image before enabling LS mode.
|
||||
GATEWAY_ANTIGRAVITY_LS_WORKER_IMAGE=zfc931912343/sub2api-lsworker:latest
|
||||
# Docker network name shared by sub2api and dynamic ls-worker containers.
|
||||
GATEWAY_ANTIGRAVITY_LS_WORKER_NETWORK=sub2api-network
|
||||
# Docker socket used by sub2api to create dynamic ls-worker containers.
|
||||
GATEWAY_ANTIGRAVITY_LS_WORKER_DOCKER_SOCKET=unix:///var/run/docker.sock
|
||||
# Idle TTL before worker container is reaped.
|
||||
GATEWAY_ANTIGRAVITY_LS_WORKER_IDLE_TTL=15m
|
||||
# Maximum number of active worker containers on this node.
|
||||
GATEWAY_ANTIGRAVITY_LS_WORKER_MAX_ACTIVE=50
|
||||
# Maximum time allowed for worker cold start and readiness.
|
||||
GATEWAY_ANTIGRAVITY_LS_WORKER_STARTUP_TIMEOUT=45s
|
||||
# Per-request timeout when sub2api talks to worker control API.
|
||||
GATEWAY_ANTIGRAVITY_LS_WORKER_REQUEST_TIMEOUT=60s
|
||||
|
||||
@ -65,6 +65,20 @@ docker compose -f docker-compose.local.yml logs sub2api | grep "admin password"
|
||||
# http://localhost:8080
|
||||
```
|
||||
|
||||
### LS Worker Image
|
||||
|
||||
When `ANTIGRAVITY_LS_MODE=true`, Sub2API creates dynamic `ls-worker`
|
||||
containers through the Docker socket. Build or pull the worker image before
|
||||
enabling LS mode:
|
||||
|
||||
```bash
|
||||
cd /path/to/sub2api
|
||||
docker build -f deploy/lsworker.Dockerfile -t weishaw/sub2api-lsworker:latest .
|
||||
```
|
||||
|
||||
The `sub2api` container must also be able to access `/var/run/docker.sock`,
|
||||
and the shared Docker network name must remain fixed at `sub2api-network`.
|
||||
|
||||
### Method 2: Manual Deployment
|
||||
|
||||
If you prefer manual control:
|
||||
|
||||
@ -283,6 +283,30 @@ gateway:
|
||||
queue: 0.7
|
||||
error_rate: 0.8
|
||||
ttft: 0.5
|
||||
# Antigravity LS worker container configuration
|
||||
# Antigravity LS worker 容器控制平面配置
|
||||
antigravity_ls_worker:
|
||||
# Worker image used by sub2api to create dynamic LS containers
|
||||
# sub2api 用于创建动态 LS worker 的镜像
|
||||
image: "weishaw/sub2api-lsworker:latest"
|
||||
# Docker network name shared by sub2api and workers
|
||||
# sub2api 与 worker 共享的 Docker network 名称
|
||||
network: "sub2api-network"
|
||||
# Docker socket path or host used by sub2api control plane
|
||||
# sub2api 控制面访问的 Docker socket / host
|
||||
docker_socket: "unix:///var/run/docker.sock"
|
||||
# Idle TTL before a worker container is recycled
|
||||
# worker 容器空闲回收时间
|
||||
idle_ttl: 15m
|
||||
# Max active worker containers per node
|
||||
# 单节点最大 worker 容器数量
|
||||
max_active: 50
|
||||
# Worker cold-start timeout
|
||||
# worker 冷启动超时
|
||||
startup_timeout: 45s
|
||||
# Timeout for control-plane calls from sub2api to worker
|
||||
# sub2api 调用 worker 控制接口的超时
|
||||
request_timeout: 60s
|
||||
# HTTP upstream connection pool settings (HTTP/2 + multi-proxy scenario defaults)
|
||||
# HTTP 上游连接池配置(HTTP/2 + 多代理场景默认值)
|
||||
# Max idle connections across all hosts
|
||||
|
||||
@ -36,6 +36,7 @@ services:
|
||||
volumes:
|
||||
# Local directory mapping for easy migration
|
||||
- ./data:/app/data
|
||||
- /var/run/docker.sock:/var/run/docker.sock
|
||||
# Optional: Mount custom config.yaml (uncomment and create the file first)
|
||||
# Copy config.example.yaml to config.yaml, modify it, then uncomment:
|
||||
# - ./config.yaml:/app/data/config.yaml
|
||||
@ -128,6 +129,22 @@ services:
|
||||
- GEMINI_CLI_OAUTH_CLIENT_SECRET=${GEMINI_CLI_OAUTH_CLIENT_SECRET:-}
|
||||
- ANTIGRAVITY_OAUTH_CLIENT_SECRET=${ANTIGRAVITY_OAUTH_CLIENT_SECRET:-}
|
||||
|
||||
# =======================================================================
|
||||
# Language Server Worker Mode
|
||||
# =======================================================================
|
||||
- ANTIGRAVITY_LS_MODE=${ANTIGRAVITY_LS_MODE:-false}
|
||||
- ANTIGRAVITY_APP_ROOT=/app/ls
|
||||
- ANTIGRAVITY_LS_PROXY=${ANTIGRAVITY_LS_PROXY:-}
|
||||
- ANTIGRAVITY_LS_STRATEGY=${ANTIGRAVITY_LS_STRATEGY:-js-parity}
|
||||
- ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT=${ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT:-5}
|
||||
- GATEWAY_ANTIGRAVITY_LS_WORKER_IMAGE=${GATEWAY_ANTIGRAVITY_LS_WORKER_IMAGE:-weishaw/sub2api-lsworker:latest}
|
||||
- GATEWAY_ANTIGRAVITY_LS_WORKER_NETWORK=${GATEWAY_ANTIGRAVITY_LS_WORKER_NETWORK:-sub2api-network}
|
||||
- GATEWAY_ANTIGRAVITY_LS_WORKER_DOCKER_SOCKET=${GATEWAY_ANTIGRAVITY_LS_WORKER_DOCKER_SOCKET:-unix:///var/run/docker.sock}
|
||||
- GATEWAY_ANTIGRAVITY_LS_WORKER_IDLE_TTL=${GATEWAY_ANTIGRAVITY_LS_WORKER_IDLE_TTL:-15m}
|
||||
- GATEWAY_ANTIGRAVITY_LS_WORKER_MAX_ACTIVE=${GATEWAY_ANTIGRAVITY_LS_WORKER_MAX_ACTIVE:-50}
|
||||
- GATEWAY_ANTIGRAVITY_LS_WORKER_STARTUP_TIMEOUT=${GATEWAY_ANTIGRAVITY_LS_WORKER_STARTUP_TIMEOUT:-45s}
|
||||
- GATEWAY_ANTIGRAVITY_LS_WORKER_REQUEST_TIMEOUT=${GATEWAY_ANTIGRAVITY_LS_WORKER_REQUEST_TIMEOUT:-60s}
|
||||
|
||||
# =======================================================================
|
||||
# Security Configuration (URL Allowlist)
|
||||
# =======================================================================
|
||||
@ -230,4 +247,5 @@ services:
|
||||
# =============================================================================
|
||||
networks:
|
||||
sub2api-network:
|
||||
name: sub2api-network
|
||||
driver: bridge
|
||||
|
||||
@ -16,7 +16,8 @@ services:
|
||||
# Sub2API Application
|
||||
# ===========================================================================
|
||||
sub2api:
|
||||
image: weishaw/sub2api:latest
|
||||
# Override with SUB2API_IMAGE to use a private registry or pinned tag.
|
||||
image: ${SUB2API_IMAGE:-weishaw/sub2api:latest}
|
||||
container_name: sub2api
|
||||
restart: unless-stopped
|
||||
ulimits:
|
||||
@ -28,6 +29,7 @@ services:
|
||||
volumes:
|
||||
# Data persistence (config.yaml will be auto-generated here)
|
||||
- sub2api_data:/app/data
|
||||
- /var/run/docker.sock:/var/run/docker.sock
|
||||
# Optional: Mount custom config.yaml (uncomment and create the file first)
|
||||
# Copy config.example.yaml to config.yaml, modify it, then uncomment:
|
||||
# - ./config.yaml:/app/data/config.yaml
|
||||
@ -120,6 +122,26 @@ services:
|
||||
- GEMINI_CLI_OAUTH_CLIENT_SECRET=${GEMINI_CLI_OAUTH_CLIENT_SECRET:-}
|
||||
- ANTIGRAVITY_OAUTH_CLIENT_SECRET=${ANTIGRAVITY_OAUTH_CLIENT_SECRET:-}
|
||||
|
||||
# =======================================================================
|
||||
# Language Server Pool Mode (Enhanced Security)
|
||||
# =======================================================================
|
||||
# Enable to route requests through real LS binary (Google's own code)
|
||||
# This makes upstream traffic indistinguishable from real IDE
|
||||
- ANTIGRAVITY_LS_MODE=${ANTIGRAVITY_LS_MODE:-false}
|
||||
- ANTIGRAVITY_APP_ROOT=/app/ls
|
||||
# SOCKS5/HTTP proxy fallback used when account has no dedicated LS proxy
|
||||
- ANTIGRAVITY_LS_PROXY=${ANTIGRAVITY_LS_PROXY:-}
|
||||
- ANTIGRAVITY_LS_STRATEGY=${ANTIGRAVITY_LS_STRATEGY:-js-parity}
|
||||
- ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT=${ANTIGRAVITY_LS_REPLICAS_PER_ACCOUNT:-5}
|
||||
# Keep the worker image aligned with the main image release when overriding.
|
||||
- GATEWAY_ANTIGRAVITY_LS_WORKER_IMAGE=${GATEWAY_ANTIGRAVITY_LS_WORKER_IMAGE:-weishaw/sub2api-lsworker:latest}
|
||||
- GATEWAY_ANTIGRAVITY_LS_WORKER_NETWORK=${GATEWAY_ANTIGRAVITY_LS_WORKER_NETWORK:-sub2api-network}
|
||||
- GATEWAY_ANTIGRAVITY_LS_WORKER_DOCKER_SOCKET=${GATEWAY_ANTIGRAVITY_LS_WORKER_DOCKER_SOCKET:-unix:///var/run/docker.sock}
|
||||
- GATEWAY_ANTIGRAVITY_LS_WORKER_IDLE_TTL=${GATEWAY_ANTIGRAVITY_LS_WORKER_IDLE_TTL:-15m}
|
||||
- GATEWAY_ANTIGRAVITY_LS_WORKER_MAX_ACTIVE=${GATEWAY_ANTIGRAVITY_LS_WORKER_MAX_ACTIVE:-50}
|
||||
- GATEWAY_ANTIGRAVITY_LS_WORKER_STARTUP_TIMEOUT=${GATEWAY_ANTIGRAVITY_LS_WORKER_STARTUP_TIMEOUT:-45s}
|
||||
- GATEWAY_ANTIGRAVITY_LS_WORKER_REQUEST_TIMEOUT=${GATEWAY_ANTIGRAVITY_LS_WORKER_REQUEST_TIMEOUT:-60s}
|
||||
|
||||
# =======================================================================
|
||||
# Security Configuration (URL Allowlist)
|
||||
# =======================================================================
|
||||
@ -234,4 +256,5 @@ volumes:
|
||||
# =============================================================================
|
||||
networks:
|
||||
sub2api-network:
|
||||
name: sub2api-network
|
||||
driver: bridge
|
||||
|
||||
@ -8,9 +8,27 @@ if [ "$(id -u)" = "0" ]; then
|
||||
mkdir -p /app/data
|
||||
# Use || true to avoid failure on read-only mounted files (e.g. config.yaml:ro)
|
||||
chown -R sub2api:sub2api /app/data 2>/dev/null || true
|
||||
if [ -S /var/run/docker.sock ]; then
|
||||
DOCKER_GID="$(stat -c '%g' /var/run/docker.sock 2>/dev/null || true)"
|
||||
if [ -n "${DOCKER_GID}" ]; then
|
||||
DOCKER_GROUP="$(getent group "${DOCKER_GID}" | cut -d: -f1 || true)"
|
||||
if [ -z "${DOCKER_GROUP}" ]; then
|
||||
DOCKER_GROUP="dockersock"
|
||||
groupadd -for -g "${DOCKER_GID}" "${DOCKER_GROUP}" 2>/dev/null || true
|
||||
fi
|
||||
usermod -aG "${DOCKER_GROUP}" sub2api 2>/dev/null || true
|
||||
fi
|
||||
fi
|
||||
# Re-invoke this script as sub2api so the flag-detection below
|
||||
# also runs under the correct user.
|
||||
exec su-exec sub2api "$0" "$@"
|
||||
# Use gosu if available (Debian), fall back to su-exec (Alpine)
|
||||
if command -v gosu >/dev/null 2>&1; then
|
||||
exec gosu sub2api "$0" "$@"
|
||||
elif command -v su-exec >/dev/null 2>&1; then
|
||||
exec su-exec sub2api "$0" "$@"
|
||||
else
|
||||
exec su -s /bin/sh sub2api -c "exec $0 $*"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Compatibility: if the first arg looks like a flag (e.g. --help),
|
||||
|
||||
21
deploy/ls-bin/cert.pem
Normal file
21
deploy/ls-bin/cert.pem
Normal file
@ -0,0 +1,21 @@
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIDXTCCAkWgAwIBAgIUVoRddTlTFh3+shRe6g4kSLo2n0MwDQYJKoZIhvcNAQEL
|
||||
BQAwSTESMBAGA1UEAwwJbG9jYWxob3N0MRYwFAYDVQQKDA1FTkFCTEVTIEhUVFAy
|
||||
MRswGQYDVQQLDBJidW5kbGVkIG9uIHB1cnBvc2UwHhcNMjUwOTA0MjA1NTA0WhcN
|
||||
MjYwOTA0MjA1NTA0WjBJMRIwEAYDVQQDDAlsb2NhbGhvc3QxFjAUBgNVBAoMDUVO
|
||||
QUJMRVMgSFRUUDIxGzAZBgNVBAsMEmJ1bmRsZWQgb24gcHVycG9zZTCCASIwDQYJ
|
||||
KoZIhvcNAQEBBQADggEPADCCAQoCggEBAJVpU6IyIMgwB6CJHkOeEAgYtzvyH6fM
|
||||
lkZSbemTrD9RCWZ4Fati1/6vbbMyWsM2XNJQMhJo0JTEoLDddN1iV/xGJCO/3dgw
|
||||
4+wLqqEeck4R1pHygCkb40TycmyygSWsidkEUH0xp51nCapIdPr/WL6O+Gbpl6DA
|
||||
onerUmWIO39VG2SpV7x3iXZOSbIGMsOiNZBmGwBZcL8ZejBIDjwvNjnX/d2tejH5
|
||||
/Mo4KVEXl5jsqaNbDIkhSs5BXtCMhoi1dqt75M8FyuNZd50AGFSa9Lj6pHTpwepD
|
||||
k2x4h+czPcvscF7TQG31TK1VYFPUThDim+by0+LQKkpy/UGVWnbC4dsCAwEAAaM9
|
||||
MDswGgYDVR0RBBMwEYIJbG9jYWxob3N0hwR/AAABMB0GA1UdDgQWBBSonSKmHCVt
|
||||
yBoVH1xEb3vtCng80DANBgkqhkiG9w0BAQsFAAOCAQEAinBO/uYe8ExHeiskt2P/
|
||||
Oxkd5sHSY9deLVuyX/TFnUEfktMfYKM2Juy+MfH4vfrcEhYkYJJcm25UGrtiT0Jh
|
||||
bUooDkR53549Xzg/70HU/ls1eNIe0zYqmS12H5W4Q1LAWTVpePscB4dgOrps6xIk
|
||||
Q4nlF7dst93E3swAe81rgCEd7VZEZy5VQcE9K+CIZXaAUJwUAsAtJbrP+5JMe9pt
|
||||
q52Zq5ZVkBS+4xeaMrasN0iTgsS4Lxo2a0GFDIJ84V66oeX7a5SXfSNn7rMVIDai
|
||||
KNZ2Cf2xNXUwq25Z6tjpQCqwYn3SE8b/Yi6fFZmy5D8kmY7dMh8ghVOc7rD+Vsk6
|
||||
/Q==
|
||||
-----END CERTIFICATE-----
|
||||
BIN
deploy/ls-bin/language_server_linux_arm
Executable file
BIN
deploy/ls-bin/language_server_linux_arm
Executable file
Binary file not shown.
BIN
deploy/ls-bin/language_server_linux_x64
Executable file
BIN
deploy/ls-bin/language_server_linux_x64
Executable file
Binary file not shown.
70
deploy/lsworker-entrypoint.sh
Normal file
70
deploy/lsworker-entrypoint.sh
Normal file
@ -0,0 +1,70 @@
|
||||
#!/bin/sh
|
||||
set -eu
|
||||
|
||||
PROXY_HOST="${LSWORKER_PROXY_HOST:-}"
|
||||
PROXY_PORT="${LSWORKER_PROXY_PORT:-1080}"
|
||||
PROXY_USER="${LSWORKER_PROXY_USER:-}"
|
||||
PROXY_PASS="${LSWORKER_PROXY_PASS:-}"
|
||||
CONTROL_PORT="${LSWORKER_CONTROL_PORT:-18081}"
|
||||
REDSOCKS_PORT="${LSWORKER_REDSOCKS_PORT:-12345}"
|
||||
NETWORK_READY_FILE="${LSWORKER_NETWORK_READY_FILE:-/run/lsworker/network-ready}"
|
||||
|
||||
mkdir -p "$(dirname "${NETWORK_READY_FILE}")"
|
||||
|
||||
if [ -z "${PROXY_HOST}" ]; then
|
||||
echo "LSWORKER_PROXY_HOST is required" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
PROXY_IP="$(getent ahostsv4 "${PROXY_HOST}" | awk 'NR==1 {print $1}')"
|
||||
if [ -z "${PROXY_IP}" ]; then
|
||||
echo "failed to resolve proxy host: ${PROXY_HOST}" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
cat >/tmp/redsocks.conf <<EOF
|
||||
base {
|
||||
log_debug = off;
|
||||
log_info = on;
|
||||
daemon = off;
|
||||
redirector = iptables;
|
||||
}
|
||||
|
||||
redsocks {
|
||||
local_ip = 0.0.0.0;
|
||||
local_port = ${REDSOCKS_PORT};
|
||||
ip = ${PROXY_IP};
|
||||
port = ${PROXY_PORT};
|
||||
type = socks5;
|
||||
EOF
|
||||
|
||||
if [ -n "${PROXY_USER}" ]; then
|
||||
printf ' login = "%s";\n' "${PROXY_USER}" >>/tmp/redsocks.conf
|
||||
fi
|
||||
if [ -n "${PROXY_PASS}" ]; then
|
||||
printf ' password = "%s";\n' "${PROXY_PASS}" >>/tmp/redsocks.conf
|
||||
fi
|
||||
|
||||
cat >>/tmp/redsocks.conf <<EOF
|
||||
}
|
||||
EOF
|
||||
|
||||
redsocks -c /tmp/redsocks.conf >/tmp/redsocks.log 2>&1 &
|
||||
REDSOCKS_PID="$!"
|
||||
trap 'kill "${REDSOCKS_PID}" >/dev/null 2>&1 || true' EXIT
|
||||
|
||||
sleep 1
|
||||
|
||||
iptables -t nat -N REDSOCKS 2>/dev/null || true
|
||||
iptables -t nat -F REDSOCKS
|
||||
iptables -t nat -A REDSOCKS -d 127.0.0.0/8 -j RETURN
|
||||
iptables -t nat -A REDSOCKS -d 127.0.0.11/32 -j RETURN
|
||||
iptables -t nat -A REDSOCKS -d "${PROXY_IP}/32" -j RETURN
|
||||
iptables -t nat -A REDSOCKS -p tcp --dport "${CONTROL_PORT}" -j RETURN
|
||||
iptables -t nat -A REDSOCKS -p tcp -j REDIRECT --to-ports "${REDSOCKS_PORT}"
|
||||
iptables -t nat -D OUTPUT -p tcp -j REDSOCKS 2>/dev/null || true
|
||||
iptables -t nat -A OUTPUT -p tcp -j REDSOCKS
|
||||
|
||||
touch "${NETWORK_READY_FILE}"
|
||||
|
||||
exec gosu sub2api /app/lsworker
|
||||
52
deploy/lsworker.Dockerfile
Normal file
52
deploy/lsworker.Dockerfile
Normal file
@ -0,0 +1,52 @@
|
||||
ARG GOLANG_IMAGE=golang:1.26.1-alpine
|
||||
ARG DEBIAN_IMAGE=debian:bookworm-slim
|
||||
|
||||
FROM ${GOLANG_IMAGE} AS builder
|
||||
|
||||
WORKDIR /app/backend
|
||||
RUN apk add --no-cache git ca-certificates tzdata
|
||||
|
||||
COPY backend/go.mod backend/go.sum ./
|
||||
RUN go mod download
|
||||
|
||||
COPY backend/ ./
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build -trimpath -ldflags="-s -w" -o /app/lsworker ./cmd/lsworker
|
||||
|
||||
FROM ${DEBIAN_IMAGE}
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
ca-certificates \
|
||||
curl \
|
||||
gosu \
|
||||
iproute2 \
|
||||
iptables \
|
||||
redsocks \
|
||||
tzdata \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN groupadd -g 1000 sub2api && \
|
||||
useradd -u 1000 -g sub2api -m -s /bin/sh sub2api
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY --from=builder /app/lsworker /app/lsworker
|
||||
COPY deploy/ls-bin/language_server_linux_* /tmp/ls-bin/
|
||||
COPY deploy/ls-bin/cert.pem /app/ls/extensions/antigravity/dist/languageServer/
|
||||
|
||||
ARG TARGETARCH
|
||||
RUN mkdir -p /app/ls/extensions/antigravity/bin /run/lsworker && \
|
||||
if [ "${TARGETARCH:-amd64}" = "arm64" ]; then \
|
||||
cp /tmp/ls-bin/language_server_linux_arm /app/ls/extensions/antigravity/bin/language_server_linux_arm; \
|
||||
else \
|
||||
cp /tmp/ls-bin/language_server_linux_x64 /app/ls/extensions/antigravity/bin/language_server_linux_x64; \
|
||||
fi && \
|
||||
chmod +x /app/lsworker /app/ls/extensions/antigravity/bin/language_server_linux_* && \
|
||||
chown -R sub2api:sub2api /app /run/lsworker && \
|
||||
rm -rf /tmp/ls-bin
|
||||
|
||||
COPY deploy/lsworker-entrypoint.sh /app/lsworker-entrypoint.sh
|
||||
RUN chmod +x /app/lsworker-entrypoint.sh
|
||||
|
||||
EXPOSE 18081
|
||||
|
||||
ENTRYPOINT ["/app/lsworker-entrypoint.sh"]
|
||||
@ -510,6 +510,7 @@ const handleEvent = (event: {
|
||||
addLine(streamingContent.value, 'text-green-300')
|
||||
streamingContent.value = ''
|
||||
}
|
||||
addLine(`Error: ${errorMessage.value}`, 'text-red-400')
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
@ -510,6 +510,7 @@ const handleEvent = (event: {
|
||||
addLine(streamingContent.value, 'text-green-300')
|
||||
streamingContent.value = ''
|
||||
}
|
||||
addLine(`Error: ${errorMessage.value}`, 'text-red-400')
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
@ -144,4 +144,28 @@ describe('AccountTestModal', () => {
|
||||
expect(preview.exists()).toBe(true)
|
||||
expect(preview.attributes('src')).toBe('data:image/png;base64,QUJD')
|
||||
})
|
||||
|
||||
it('收到 error 事件时会把错误内容显示在终端输出里', async () => {
|
||||
;(global.fetch as any).mockResolvedValueOnce(
|
||||
createStreamResponse([
|
||||
'data: {"type":"test_start","model":"claude-opus-4-6"}\n',
|
||||
'data: {"type":"error","error":"API returned 429: You have exhausted your capacity on this model."}\n'
|
||||
])
|
||||
)
|
||||
|
||||
const wrapper = mountModal()
|
||||
await wrapper.setProps({ show: true })
|
||||
await flushPromises()
|
||||
|
||||
const buttons = wrapper.findAll('button')
|
||||
const startButton = buttons.find((button) => button.text().includes('admin.accounts.startTest'))
|
||||
expect(startButton).toBeTruthy()
|
||||
|
||||
await startButton!.trigger('click')
|
||||
await flushPromises()
|
||||
await flushPromises()
|
||||
|
||||
expect(wrapper.text()).toContain('API returned 429: You have exhausted your capacity on this model.')
|
||||
expect(wrapper.text()).toContain('Error: API returned 429: You have exhausted your capacity on this model.')
|
||||
})
|
||||
})
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user