chore: 删除 Antigravity 订制代码,回退至上游 v0.1.118
Some checks failed
CI / test (push) Failing after 3s
CI / frontend (push) Failing after 4s
CI / golangci-lint (push) Failing after 6s
CI / windsurf-platform (macos-latest) (push) Has been cancelled
CI / windsurf-platform (windows-latest) (push) Has been cancelled
Security Scan / backend-security (push) Failing after 3s
Security Scan / frontend-security (push) Failing after 3s

- 删除自定义文件:gateway_attribution, gateway_claude_runtime_headers,
  identity_service_antigravity, language_server_service, lsrpc_handler,
  antigravity_http handler/routes, 所有 antigravity 专项测试
- 将 antigravity pkg/service 文件回退至上游版本(移除 IsEnterprise、
  claude_code_tool_map、dynamic fingerprint 等定制逻辑)
- 修复 gateway_service.go:移除 NormalizeSystemPromptEnv、
  generateSessionIDForAccount、applyClaudeRuntimeOptionalHeaders 调用,
  使用上游的 session-id 同步逻辑
- 恢复 language_server_pb gen 文件(Windsurf local_ls.go 依赖)
- 保留全部 Windsurf 集成代码不变
This commit is contained in:
win 2026-04-25 22:35:48 +08:00
parent 2064c1a19f
commit 898a65314c
50 changed files with 235 additions and 5728 deletions

View File

@ -8,6 +8,11 @@ package main
import (
"context"
"log"
"net/http"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
@ -18,14 +23,9 @@ import (
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
"log"
"net/http"
"sync"
"time"
)
import (
_ "embed"
_ "github.com/Wei-Shaw/sub2api/ent/runtime"
)
@ -257,9 +257,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService)
apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig)
langServerService := service.ProvideLanguageServerService(httpUpstream, antigravityGatewayService, accountRepository)
lsrpcHandler := service.NewLSRPCHandler(antigravityGatewayService, accountRepository, nil)
engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware, apiKeyService, subscriptionService, opsService, settingService, redisClient, langServerService, lsrpcHandler)
engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware, apiKeyService, subscriptionService, opsService, settingService, redisClient)
httpServer := server.ProvideHTTPServer(configConfig, engine)
opsMetricsCollector := service.ProvideOpsMetricsCollector(opsRepository, settingRepository, accountRepository, concurrencyService, db, redisClient, configConfig)
opsAggregationService := service.ProvideOpsAggregationService(opsRepository, settingRepository, db, redisClient, configConfig)
@ -271,7 +269,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService)
<<<<<<< HEAD
channelMonitorRunner := service.ProvideChannelMonitorRunner(channelMonitorService, settingService)
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService, paymentOrderExpiryService, windsurfRefreshService, channelMonitorRunner, windsurfLSService)
application := &Application{

View File

@ -1,114 +0,0 @@
package main
import (
"context"
"flag"
"fmt"
"log"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
)
func repeatStr(s string, count int) string {
return strings.Repeat(s, count)
}
func main() {
accessToken := flag.String("token", "", "OAuth access token")
projectID := flag.String("project", "", "Project ID")
proxyURL := flag.String("proxy", "", "Proxy URL (optional)")
flag.Parse()
if *accessToken == "" || *projectID == "" {
log.Fatal("missing required flags: -token and -project")
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
client, err := antigravity.NewClient(*proxyURL)
if err != nil {
log.Fatalf("failed to create client: %v", err)
}
fmt.Println(repeatStr("=", 80))
fmt.Println("Antigravity Privacy Setup Diagnostic Test")
fmt.Println(repeatStr("=", 80))
// Step 1: Verify token is valid by fetching user info
fmt.Println("\n[Step 1] Verifying access token...")
userInfo, err := client.GetUserInfo(ctx, *accessToken)
if err != nil {
log.Fatalf("failed to get user info: %v", err)
}
fmt.Printf("✓ Email: %s\n", userInfo.Email)
// Step 2: Call SetUserSettings
fmt.Println("\n[Step 2] Calling SetUserSettings (clear privacy settings)...")
setResp, err := client.SetUserSettings(ctx, *accessToken)
if err != nil {
log.Fatalf("SetUserSettings failed: %v", err)
}
if setResp.IsSuccess() {
fmt.Println("✓ SetUserSettings succeeded")
fmt.Printf(" Response: %+v\n", setResp)
} else {
fmt.Println("✗ SetUserSettings returned non-empty userSettings")
fmt.Printf(" Response: %+v\n", setResp)
fmt.Println("\n ERROR: This indicates privacy settings were NOT cleared!")
fmt.Println(" Possible causes:")
fmt.Println(" 1. Account restrictions on privacy settings")
fmt.Println(" 2. Account still has telemetryEnabled=true")
fmt.Println(" 3. API response indicates settings persist")
}
// Step 3: Verify by calling FetchUserInfo
fmt.Println("\n[Step 3] Calling FetchUserInfo to verify privacy status...")
userInfoResp, err := client.FetchUserInfo(ctx, *accessToken, *projectID)
if err != nil {
log.Fatalf("FetchUserInfo failed: %v", err)
}
if userInfoResp.IsPrivate() {
fmt.Println("✓ Privacy is properly set (userSettings is empty)")
fmt.Printf(" Response: %+v\n", userInfoResp)
} else {
fmt.Println("✗ Privacy is NOT properly set (userSettings contains telemetryEnabled)")
fmt.Printf(" Response: %+v\n", userInfoResp)
fmt.Println("\n ERROR: This explains the 503 errors in gateway!")
fmt.Println(" Reason: Antigravity API rejects requests from accounts with")
fmt.Println(" telemetryEnabled=true to protect user privacy")
}
// Summary
fmt.Println("\n" + repeatStr("=", 80))
fmt.Println("DIAGNOSIS SUMMARY")
fmt.Println(repeatStr("=", 80))
if setResp.IsSuccess() && userInfoResp.IsPrivate() {
fmt.Println("✓ Privacy setup is SUCCESSFUL")
fmt.Println(" This account should NOT experience 503 errors due to privacy")
fmt.Println(" The 503 errors might be due to:")
fmt.Println(" 1. Temporary API outages")
fmt.Println(" 2. Rate limiting on new accounts")
fmt.Println(" 3. Other infrastructure issues")
} else if !setResp.IsSuccess() && !userInfoResp.IsPrivate() {
fmt.Println("✗ Privacy setup FAILED")
fmt.Println(" The account cannot clear privacy settings on Antigravity")
fmt.Println(" This causes the 503 Service Unavailable errors")
fmt.Println("\nSOLUTION:")
fmt.Println(" 1. Check if this is a restricted account type")
fmt.Println(" 2. Try re-authorizing the account")
fmt.Println(" 3. Check Antigravity API rate limiting")
fmt.Println(" 4. Inspect firewall/proxy settings")
} else {
fmt.Println("⚠ INCONSISTENT STATE:")
fmt.Println(" SetUserSettings and FetchUserInfo returned different results")
fmt.Println(" This might indicate a transient API issue or data sync delay")
}
fmt.Println("\n" + repeatStr("=", 80))
}

View File

@ -1,316 +0,0 @@
package main
import (
"context"
"flag"
"fmt"
"log"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
)
// TestScenario 定义一个测试场景
type TestScenario struct {
name string
description string
testFunc func(ctx context.Context, token, projectID string) (bool, string)
}
var scenarios []TestScenario
func init() {
scenarios = []TestScenario{
{
name: "single_request",
description: "单次请求 - 检查是否立即成功",
testFunc: testSingleRequest,
},
{
name: "sequential_requests",
description: "顺序发送 10 个请求 - 找到稳定点",
testFunc: testSequentialRequests,
},
{
name: "concurrent_requests",
description: "并发发送 5 个请求 - 检查并发初始化行为",
testFunc: testConcurrentRequests,
},
{
name: "warmup_then_request",
description: "预热(模型列表请求) + 业务请求 - 验证预热效果",
testFunc: testWarmupThenRequest,
},
{
name: "delayed_request",
description: "延迟 5 秒后请求 - 检查账号初始化时间",
testFunc: testDelayedRequest,
},
}
}
// testSingleRequest 单次请求
func testSingleRequest(ctx context.Context, token, projectID string) (bool, string) {
client, err := antigravity.NewClient("")
if err != nil {
return false, fmt.Sprintf("创建客户端失败: %v", err)
}
start := time.Now()
resp, _, err := client.FetchAvailableModels(ctx, token, projectID)
elapsed := time.Since(start)
if err != nil {
return false, fmt.Sprintf("请求失败 (%v): %v", elapsed, err)
}
if resp == nil {
return false, fmt.Sprintf("响应为空 (%v)", elapsed)
}
return true, fmt.Sprintf("✓ 单次请求成功 - 耗时 %v", elapsed)
}
// testSequentialRequests 顺序发送多个请求
func testSequentialRequests(ctx context.Context, token, projectID string) (bool, string) {
client, err := antigravity.NewClient("")
if err != nil {
return false, fmt.Sprintf("创建客户端失败: %v", err)
}
var firstFailIdx = -1
var firstSuccessIdx = -1
var timings []time.Duration
for i := 0; i < 10; i++ {
start := time.Now()
resp, _, err := client.FetchAvailableModels(ctx, token, projectID)
elapsed := time.Since(start)
timings = append(timings, elapsed)
success := err == nil && resp != nil
fmt.Printf(" [%d] 耗时: %6v, 状态: %v\n", i+1, elapsed, map[bool]string{true: "✓", false: "✗"}[success])
if !success && firstFailIdx == -1 {
firstFailIdx = i
}
if success && firstSuccessIdx == -1 {
firstSuccessIdx = i
}
}
var report string
if firstSuccessIdx == -1 {
report = "✗ 全部失败"
} else if firstSuccessIdx == 0 {
report = fmt.Sprintf("✓ 首次即成功 (耗时 %v)", timings[0])
} else {
report = fmt.Sprintf("⚠ 第 %d 次才成功 (失败 %d 次), 首次耗时 %v",
firstSuccessIdx+1, firstSuccessIdx, timings[firstSuccessIdx])
}
return firstSuccessIdx >= 0, report
}
// testConcurrentRequests 并发请求
func testConcurrentRequests(ctx context.Context, token, projectID string) (bool, string) {
client, err := antigravity.NewClient("")
if err != nil {
return false, fmt.Sprintf("创建客户端失败: %v", err)
}
var wg sync.WaitGroup
results := make([]bool, 5)
timings := make([]time.Duration, 5)
mu := sync.Mutex{}
for i := 0; i < 5; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
start := time.Now()
resp, _, err := client.FetchAvailableModels(ctx, token, projectID)
elapsed := time.Since(start)
mu.Lock()
results[idx] = err == nil && resp != nil
timings[idx] = elapsed
mu.Unlock()
fmt.Printf(" [%d] 耗时: %6v, 状态: %v\n", idx+1, elapsed, map[bool]string{true: "✓", false: "✗"}[results[idx]])
}(i)
}
wg.Wait()
successCount := 0
for _, ok := range results {
if ok {
successCount++
}
}
return successCount > 0, fmt.Sprintf("%d/5 并发请求成功", successCount)
}
// testWarmupThenRequest 预热测试
func testWarmupThenRequest(ctx context.Context, token, projectID string) (bool, string) {
client, err := antigravity.NewClient("")
if err != nil {
return false, fmt.Sprintf("创建客户端失败: %v", err)
}
// 第 1 步:预热 - 调用 LoadCodeAssist获取项目信息
fmt.Println(" [Warmup] 调用 LoadCodeAssist 预热...")
warmupStart := time.Now()
_, _, warmupErr := client.LoadCodeAssist(ctx, token)
warmupElapsed := time.Since(warmupStart)
fmt.Printf(" [Warmup] 耗时 %v, 状态: %v\n", warmupElapsed, map[bool]string{true: "✓", false: "✗"}[warmupErr == nil])
// 第 2 步:实际请求
fmt.Println(" [Request] 发送业务请求...")
reqStart := time.Now()
resp, _, err := client.FetchAvailableModels(ctx, token, projectID)
reqElapsed := time.Since(reqStart)
success := err == nil && resp != nil
fmt.Printf(" [Request] 耗时 %v, 状态: %v\n", reqElapsed, map[bool]string{true: "✓", false: "✗"}[success])
return success, fmt.Sprintf("预热 %v + 请求 %v = 总耗时 %v",
warmupElapsed, reqElapsed, warmupElapsed+reqElapsed)
}
// testDelayedRequest 延迟请求
func testDelayedRequest(ctx context.Context, token, projectID string) (bool, string) {
client, err := antigravity.NewClient("")
if err != nil {
return false, fmt.Sprintf("创建客户端失败: %v", err)
}
fmt.Println(" 等待 5 秒...")
time.Sleep(5 * time.Second)
start := time.Now()
resp, _, err := client.FetchAvailableModels(ctx, token, projectID)
elapsed := time.Since(start)
success := err == nil && resp != nil
return success, fmt.Sprintf("延迟 5s 后请求 - 耗时 %v, 状态: %v", elapsed, map[bool]string{true: "✓", false: "✗"}[success])
}
// testOAuthTokenRefresh OAuth Token 刷新测试
func testOAuthTokenRefresh(ctx context.Context, refreshToken string) (bool, string) {
client, err := antigravity.NewClient("")
if err != nil {
return false, fmt.Sprintf("创建客户端失败: %v", err)
}
start := time.Now()
tokenInfo, err := client.RefreshToken(ctx, refreshToken, false)
elapsed := time.Since(start)
if err != nil {
return false, fmt.Sprintf("Token 刷新失败 (%v): %v", elapsed, err)
}
return true, fmt.Sprintf("✓ Token 刷新成功 - 耗时 %v, 新 Token 有效期: %d 秒",
elapsed, tokenInfo.ExpiresIn)
}
// testAccountInitializationWarmup 账号初始化预热
func testAccountInitializationWarmup(ctx context.Context, token, projectID string) (bool, string) {
client, err := antigravity.NewClient("")
if err != nil {
return false, fmt.Sprintf("创建客户端失败: %v", err)
}
fmt.Println(" 执行完整的账号初始化流程...")
// 1. GetUserInfo
fmt.Println(" 1. GetUserInfo...")
start := time.Now()
_, err1 := client.GetUserInfo(ctx, token)
fmt.Printf(" 耗时: %v\n", time.Since(start))
// 2. LoadCodeAssist
fmt.Println(" 2. LoadCodeAssist...")
start = time.Now()
_, _, err2 := client.LoadCodeAssist(ctx, token)
fmt.Printf(" 耗时: %v\n", time.Since(start))
// 3. FetchAvailableModels
fmt.Println(" 3. FetchAvailableModels...")
start = time.Now()
_, _, err3 := client.FetchAvailableModels(ctx, token, projectID)
elapsed := time.Since(start)
fmt.Printf(" 耗时: %v\n", elapsed)
success := err1 == nil && err2 == nil && err3 == nil
return success, fmt.Sprintf("账号初始化预热 - 状态: %v", map[bool]string{true: "✓", false: "✗"}[success])
}
func main() {
accessToken := flag.String("token", "", "OAuth access token")
projectID := flag.String("project", "", "Project ID")
refreshToken := flag.String("refresh", "", "Refresh token (optional)")
testName := flag.String("test", "all", "测试名称 (all, single_request, sequential_requests, etc.)")
flag.Parse()
if *accessToken == "" || *projectID == "" {
log.Fatal("缺少必需参数: -token 和 -project")
}
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
fmt.Println("\n" + repeatStr("=", 80))
fmt.Println("Antigravity 账号初始化诊断测试套件")
fmt.Println(repeatStr("=", 80) + "\n")
// Token 刷新测试
if *refreshToken != "" {
fmt.Println("[Token 刷新测试]")
_, report := testOAuthTokenRefresh(ctx, *refreshToken)
fmt.Printf("%s\n\n", report)
}
// 账号初始化预热测试
fmt.Println("[账号初始化预热]")
_, report := testAccountInitializationWarmup(ctx, *accessToken, *projectID)
fmt.Printf("%s\n\n", report)
// 运行指定的测试
if *testName == "all" {
for _, scenario := range scenarios {
fmt.Printf("[%s]\n%s\n", scenario.name, scenario.description)
_, report := scenario.testFunc(ctx, *accessToken, *projectID)
fmt.Printf("结果: %s\n\n", report)
}
} else {
found := false
for _, scenario := range scenarios {
if scenario.name == *testName {
found = true
fmt.Printf("[%s]\n%s\n", scenario.name, scenario.description)
_, report := scenario.testFunc(ctx, *accessToken, *projectID)
fmt.Printf("结果: %s\n\n", report)
break
}
}
if !found {
log.Fatalf("未找到测试: %s", *testName)
}
}
fmt.Println(repeatStr("=", 80))
fmt.Println("诊断完成")
fmt.Println(repeatStr("=", 80))
}
func repeatStr(s string, count int) string {
result := ""
for i := 0; i < count; i++ {
result += s
}
return result
}

View File

@ -7,13 +7,14 @@
package language_server_pb
import (
reflect "reflect"
sync "sync"
unsafe "unsafe"
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
_ "google.golang.org/protobuf/types/known/emptypb"
timestamppb "google.golang.org/protobuf/types/known/timestamppb"
reflect "reflect"
sync "sync"
unsafe "unsafe"
)
const (

View File

@ -5,12 +5,13 @@
package language_server_pbconnect
import (
connect "connectrpc.com/connect"
context "context"
errors "errors"
language_server_pb "github.com/Wei-Shaw/sub2api/internal/gen/language_server_pb"
http "net/http"
strings "strings"
connect "connectrpc.com/connect"
language_server_pb "github.com/Wei-Shaw/sub2api/internal/gen/language_server_pb"
)
// This is a compile-time assertion to ensure that this generated file and the connect package are

View File

@ -15,8 +15,7 @@ func NewAntigravityOAuthHandler(antigravityOAuthService *service.AntigravityOAut
}
type AntigravityGenerateAuthURLRequest struct {
ProxyID *int64 `json:"proxy_id"`
IsEnterprise bool `json:"is_enterprise"`
ProxyID *int64 `json:"proxy_id"`
}
// GenerateAuthURL generates Google OAuth authorization URL
@ -28,7 +27,7 @@ func (h *AntigravityOAuthHandler) GenerateAuthURL(c *gin.Context) {
return
}
result, err := h.antigravityOAuthService.GenerateAuthURL(c.Request.Context(), req.ProxyID, req.IsEnterprise)
result, err := h.antigravityOAuthService.GenerateAuthURL(c.Request.Context(), req.ProxyID)
if err != nil {
response.InternalError(c, "生成授权链接失败: "+err.Error())
return
@ -71,7 +70,6 @@ func (h *AntigravityOAuthHandler) ExchangeCode(c *gin.Context) {
type AntigravityRefreshTokenRequest struct {
RefreshToken string `json:"refresh_token" binding:"required"`
ProxyID *int64 `json:"proxy_id"`
IsEnterprise bool `json:"is_enterprise"`
}
// RefreshToken validates an Antigravity refresh token and returns full token info
@ -83,7 +81,7 @@ func (h *AntigravityOAuthHandler) RefreshToken(c *gin.Context) {
return
}
tokenInfo, err := h.antigravityOAuthService.ValidateRefreshToken(c.Request.Context(), req.RefreshToken, req.ProxyID, req.IsEnterprise)
tokenInfo, err := h.antigravityOAuthService.ValidateRefreshToken(c.Request.Context(), req.RefreshToken, req.ProxyID)
if err != nil {
response.ErrorFrom(c, err)
return

View File

@ -1,267 +0,0 @@
package handler
import (
"encoding/json"
"log/slog"
"net/http"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// AntigravityHTTPHandler 处理下游客户端的 HTTP 请求
// 内部调用 LanguageServerService再转发到上游 API
type AntigravityHTTPHandler struct {
langServerService *service.LanguageServerService
logger *slog.Logger
}
func NewAntigravityHTTPHandler(
langServerService *service.LanguageServerService,
logger *slog.Logger,
) *AntigravityHTTPHandler {
return &AntigravityHTTPHandler{
langServerService: langServerService,
logger: logger,
}
}
// ============================================================================
// Cascade 流程 API
// ============================================================================
// StartCascadeRequest HTTP 请求格式
type StartCascadeRequest struct {
Model string `json:"model"` // 模型名称
SystemPrompt string `json:"system_prompt"` // 系统提示
Metadata map[string]string `json:"metadata"` // 设备指纹等伪装信息
}
// StartCascadeResponse HTTP 响应格式
type StartCascadeResponse struct {
CascadeID string `json:"cascade_id"`
}
// POST /api/v1/cascade/start
// 启动新的 Cascade Agent 会话
func (h *AntigravityHTTPHandler) StartCascade(c *gin.Context) {
var req StartCascadeRequest
if err := c.ShouldBindJSON(&req); err != nil {
h.logger.Error("invalid request", "error", err)
c.JSON(http.StatusBadRequest, gin.H{
"error": "invalid request: " + err.Error(),
})
return
}
// 提取 OAuth token
token := c.GetHeader("Authorization")
if token == "" {
c.JSON(http.StatusUnauthorized, gin.H{
"error": "missing authorization header",
})
return
}
// 调用内部 LanguageServerService
cascadeID, err := h.langServerService.StartCascade(
c.Request.Context(),
req.Model,
req.SystemPrompt,
req.Metadata,
token,
)
if err != nil {
h.logger.Error("start cascade failed", "error", err)
c.JSON(http.StatusInternalServerError, gin.H{
"error": err.Error(),
})
return
}
h.logger.Info("cascade started", "cascade_id", cascadeID, "model", req.Model)
c.JSON(http.StatusOK, StartCascadeResponse{
CascadeID: cascadeID,
})
}
// ============================================================================
// SendUserMessageRequest HTTP 请求格式
type SendUserMessageRequest struct {
CascadeID string `json:"cascade_id"` // 会话 ID
Message string `json:"message"` // 用户消息
Context map[string]string `json:"context"` // 上下文(可选)
}
// CascadeUpdate 流式响应格式Server-Sent Events
type CascadeUpdate struct {
Type string `json:"type"` // "message_delta", "tool_call", etc.
Payload string `json:"payload"` // JSON 格式的负载
}
// POST /api/v1/cascade/message (流式)
// 发送用户消息,接收流式更新
func (h *AntigravityHTTPHandler) SendUserMessage(c *gin.Context) {
var req SendUserMessageRequest
if err := c.ShouldBindJSON(&req); err != nil {
h.logger.Error("invalid request", "error", err)
c.JSON(http.StatusBadRequest, gin.H{
"error": "invalid request: " + err.Error(),
})
return
}
// 提取 OAuth token
token := c.GetHeader("Authorization")
if token == "" {
c.JSON(http.StatusUnauthorized, gin.H{
"error": "missing authorization header",
})
return
}
// 设置 Server-Sent Events 响应头
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
// 调用内部 LanguageServerService获取流式响应
updateChan, err := h.langServerService.SendUserMessage(
c.Request.Context(),
req.CascadeID,
req.Message,
token,
)
if err != nil {
h.logger.Error("send user message failed", "error", err)
c.JSON(http.StatusInternalServerError, gin.H{
"error": err.Error(),
})
return
}
// 逐个推送更新到客户端SSE
for update := range updateChan {
data, _ := json.Marshal(update)
c.SSEvent("update", string(data))
c.Writer.Flush()
}
h.logger.Info("cascade message processed", "cascade_id", req.CascadeID)
}
// ============================================================================
// POST /api/v1/cascade/cancel
// 取消 Cascade 调用
func (h *AntigravityHTTPHandler) CancelCascade(c *gin.Context) {
var req struct {
CascadeID string `json:"cascade_id"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": "invalid request",
})
return
}
if err := h.langServerService.CancelCascade(
c.Request.Context(),
req.CascadeID,
); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": err.Error(),
})
return
}
h.logger.Info("cascade cancelled", "cascade_id", req.CascadeID)
c.JSON(http.StatusOK, gin.H{
"success": true,
})
}
// ============================================================================
// 模型配置 API
// ============================================================================
// ModelConfig 模型配置
type ModelConfig struct {
Name string `json:"name"`
DisplayName string `json:"display_name"`
MaxTokens int `json:"max_tokens"`
SupportsThinking bool `json:"supports_thinking"`
ThinkingBudget int `json:"thinking_budget,omitempty"`
SupportsImages bool `json:"supports_images"`
Provider string `json:"provider"` // anthropic, google, openai
}
// GET /api/v1/models
// 获取可用模型列表
func (h *AntigravityHTTPHandler) GetModels(c *gin.Context) {
models, err := h.langServerService.GetAvailableModels(c.Request.Context())
if err != nil {
h.logger.Error("get models failed", "error", err)
c.JSON(http.StatusInternalServerError, gin.H{
"error": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"models": models,
"default_model": "claude-opus-4-7",
})
}
// ============================================================================
// 健康检查 API
// ============================================================================
// GET /api/v1/health
// 健康检查
func (h *AntigravityHTTPHandler) Health(c *gin.Context) {
status, err := h.langServerService.GetStatus(c.Request.Context())
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"status": "unhealthy",
"error": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"status": status,
"version": "1.0.0",
})
}
// ============================================================================
// RegisterRoutes 注册所有 HTTP 路由
func (h *AntigravityHTTPHandler) RegisterRoutes(router *gin.Engine) {
api := router.Group("/api/v1")
// Cascade 流程
api.POST("/cascade/start", h.StartCascade)
api.POST("/cascade/message", h.SendUserMessage)
api.POST("/cascade/cancel", h.CancelCascade)
// 模型列表
api.GET("/models", h.GetModels)
// 健康检查
api.GET("/health", h.Health)
h.logger.Info("antigravity http handler registered",
"routes", []string{
"/api/v1/cascade/start",
"/api/v1/cascade/message",
"/api/v1/cascade/cancel",
"/api/v1/models",
"/api/v1/health",
})
}

View File

@ -270,19 +270,19 @@ func TestUserHandlerGetProfileReturnsLegacyCompatibilityFields(t *testing.T) {
AvatarURL: "https://cdn.example.com/linuxdo.png",
AvatarSource: "remote_url",
},
identities: []service.UserAuthIdentityRecord{
{
ProviderType: "linuxdo",
ProviderKey: "linuxdo",
ProviderSubject: "linuxdo-subject-21",
VerifiedAt: &verifiedAt,
Metadata: map[string]any{
"username": "linuxdo-handle",
"avatar_url": "https://cdn.example.com/linuxdo.png",
},
identities: []service.UserAuthIdentityRecord{
{
ProviderType: "linuxdo",
ProviderKey: "linuxdo",
ProviderSubject: "linuxdo-subject-21",
VerifiedAt: &verifiedAt,
Metadata: map[string]any{
"username": "linuxdo-handle",
"avatar_url": "https://cdn.example.com/linuxdo.png",
},
},
}
},
}
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
recorder := httptest.NewRecorder()

View File

@ -1,70 +0,0 @@
package antigravity
import "strings"
var claudeCodeBuiltinToolNameMap = map[string]string{
"read": "Read",
"read_file": "Read",
"readfile": "Read",
"write": "Write",
"write_file": "Write",
"writefile": "Write",
"edit": "Edit",
"apply_patch": "Edit",
"applypatch": "Edit",
"bash": "Bash",
"execute_bash": "Bash",
"executebash": "Bash",
"exec_bash": "Bash",
"execbash": "Bash",
"glob": "Glob",
"list_files": "Glob",
"listfiles": "Glob",
"grep": "Grep",
"search_files": "Grep",
"searchfiles": "Grep",
"webfetch": "WebFetch",
"web_fetch": "WebFetch",
"fetch": "WebFetch",
"websearch": "WebSearch",
"web_search": "WebSearch",
"agent": "Agent",
"askuserquestion": "AskUserQuestion",
"ask_user_question": "AskUserQuestion",
"enterplanmode": "EnterPlanMode",
"enter_plan_mode": "EnterPlanMode",
"exitplanmode": "ExitPlanMode",
"exit_plan_mode": "ExitPlanMode",
"croncreate": "CronCreate",
"cron_create": "CronCreate",
"crondelete": "CronDelete",
"cron_delete": "CronDelete",
"schedulewakeup": "ScheduleWakeup",
"schedule_wakeup": "ScheduleWakeup",
"sendmessage": "SendMessage",
"send_message": "SendMessage",
"skill": "Skill",
"taskcreate": "TaskCreate",
"task_create": "TaskCreate",
"tasklist": "TaskList",
"task_list": "TaskList",
"taskoutput": "TaskOutput",
"task_output": "TaskOutput",
"taskstop": "TaskStop",
"task_stop": "TaskStop",
"taskupdate": "TaskUpdate",
"task_update": "TaskUpdate",
}
func normalizeClaudeCodeToolName(name string) string {
trimmed := strings.TrimSpace(name)
if trimmed == "" {
return ""
}
if mapped, ok := claudeCodeBuiltinToolNameMap[strings.ToLower(trimmed)]; ok {
return mapped
}
return trimmed
}

View File

@ -1,160 +0,0 @@
package antigravity
import (
"encoding/json"
"strings"
"testing"
"github.com/stretchr/testify/require"
)
func TestNormalizeClaudeCodeToolName(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
expected string
}{
{name: "read alias", input: "read_file", expected: "Read"},
{name: "grep alias", input: "search_files", expected: "Grep"},
{name: "webfetch alias", input: "fetch", expected: "WebFetch"},
{name: "plan alias", input: "enter_plan_mode", expected: "EnterPlanMode"},
{name: "native passthrough", input: "TaskUpdate", expected: "TaskUpdate"},
{name: "mcp passthrough", input: "mcp__github__list_prs", expected: "mcp__github__list_prs"},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
require.Equal(t, tt.expected, normalizeClaudeCodeToolName(tt.input))
})
}
}
func TestBuildPartsNormalizesClaudeCodeToolNames(t *testing.T) {
t.Parallel()
toolIDToName := make(map[string]string)
assistantParts, stripped, err := buildParts(json.RawMessage(`[
{"type":"tool_use","id":"tool-1","name":"read_file","input":{"file_path":"/tmp/demo.txt"}}
]`), toolIDToName, false)
require.NoError(t, err)
require.False(t, stripped)
require.Len(t, assistantParts, 1)
require.NotNil(t, assistantParts[0].FunctionCall)
require.Equal(t, "Read", assistantParts[0].FunctionCall.Name)
require.Equal(t, "Read", toolIDToName["tool-1"])
userParts, stripped, err := buildParts(json.RawMessage(`[
{"type":"tool_result","tool_use_id":"tool-1","content":[{"type":"text","text":"ok"}]}
]`), toolIDToName, false)
require.NoError(t, err)
require.False(t, stripped)
require.Len(t, userParts, 1)
require.NotNil(t, userParts[0].FunctionResponse)
require.Equal(t, "Read", userParts[0].FunctionResponse.Name)
}
func TestBuildToolsNormalizesClaudeCodeBuiltinNamesOnly(t *testing.T) {
t.Parallel()
result := buildTools([]ClaudeTool{
{
Name: "search_files",
Description: "Search the workspace",
InputSchema: map[string]any{
"type": "object",
},
},
{
Name: "mcp__github__list_prs",
Description: "List pull requests",
InputSchema: map[string]any{
"type": "object",
},
},
})
require.Len(t, result, 1)
require.Len(t, result[0].FunctionDeclarations, 2)
require.Equal(t, "Grep", result[0].FunctionDeclarations[0].Name)
require.Equal(t, "mcp__github__list_prs", result[0].FunctionDeclarations[1].Name)
}
func TestNonStreamingProcessorNormalizesClaudeCodeToolName(t *testing.T) {
t.Parallel()
processor := NewNonStreamingProcessor()
response := processor.Process(&GeminiResponse{
Candidates: []GeminiCandidate{
{
Content: &GeminiContent{
Parts: []GeminiPart{
{
FunctionCall: &GeminiFunctionCall{
Name: "web_fetch",
Args: map[string]any{"url": "https://example.com"},
},
},
},
},
FinishReason: "STOP",
},
},
}, "resp-1", "claude-sonnet-4-5")
require.Len(t, response.Content, 1)
require.Equal(t, "tool_use", response.Content[0].Type)
require.Equal(t, "WebFetch", response.Content[0].Name)
require.True(t, strings.HasPrefix(response.Content[0].ID, "WebFetch-"))
require.NotNil(t, response.Content[0].Caller)
require.Equal(t, "direct", response.Content[0].Caller.Type)
require.Equal(t, "tool_use", response.StopReason)
}
func TestStreamingProcessorNormalizesClaudeCodeToolName(t *testing.T) {
t.Parallel()
processor := NewStreamingProcessor("claude-sonnet-4-5")
output := processor.processFunctionCall(&GeminiFunctionCall{
Name: "search_files",
Args: map[string]any{"pattern": "TODO"},
}, "")
events := parseSSEDataEvents(t, output)
require.Len(t, events, 3)
contentBlock, ok := events[0]["content_block"].(map[string]any)
require.True(t, ok)
require.Equal(t, "tool_use", contentBlock["type"])
require.Equal(t, "Grep", contentBlock["name"])
toolID, ok := contentBlock["id"].(string)
require.True(t, ok)
require.True(t, strings.HasPrefix(toolID, "Grep-"))
caller, ok := contentBlock["caller"].(map[string]any)
require.True(t, ok)
require.Equal(t, "direct", caller["type"])
}
func parseSSEDataEvents(t *testing.T, payload []byte) []map[string]any {
t.Helper()
lines := strings.Split(string(payload), "\n")
events := make([]map[string]any, 0)
for _, line := range lines {
if !strings.HasPrefix(line, "data: ") {
continue
}
var event map[string]any
require.NoError(t, json.Unmarshal([]byte(strings.TrimPrefix(line, "data: ")), &event))
events = append(events, event)
}
return events
}

View File

@ -16,7 +16,6 @@ type ClaudeRequest struct {
TopK *int `json:"top_k,omitempty"`
Tools []ClaudeTool `json:"tools,omitempty"`
Thinking *ThinkingConfig `json:"thinking,omitempty"`
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
Metadata *ClaudeMetadata `json:"metadata,omitempty"`
}
@ -73,10 +72,9 @@ type ContentBlock struct {
Thinking string `json:"thinking,omitempty"`
Signature string `json:"signature,omitempty"`
// tool_use
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Input any `json:"input,omitempty"`
Caller *ToolCaller `json:"caller,omitempty"`
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Input any `json:"input,omitempty"`
// tool_result
ToolUseID string `json:"tool_use_id,omitempty"`
Content json.RawMessage `json:"content,omitempty"`
@ -116,15 +114,9 @@ type ClaudeContentItem struct {
Signature string `json:"signature,omitempty"`
// tool_use
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Input any `json:"input,omitempty"`
Caller *ToolCaller `json:"caller,omitempty"`
}
// ToolCaller Claude Code tool_use 调用来源
type ToolCaller struct {
Type string `json:"type"`
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Input any `json:"input,omitempty"`
}
// ClaudeUsage Claude 用量统计

View File

@ -318,17 +318,16 @@ func shouldFallbackToNextURL(err error, statusCode int) bool {
statusCode >= 500
}
// ExchangeCode 用 authorization code 交换 token。
// isEnterprise=true 时使用企业 OAuth client_id/secret。
func (c *Client) ExchangeCode(ctx context.Context, code, codeVerifier string, isEnterprise bool) (*TokenResponse, error) {
creds, err := GetClientCredentials(isEnterprise)
// ExchangeCode 用 authorization code 交换 token
func (c *Client) ExchangeCode(ctx context.Context, code, codeVerifier string) (*TokenResponse, error) {
clientSecret, err := getClientSecret()
if err != nil {
return nil, err
}
params := url.Values{}
params.Set("client_id", creds.ClientID)
params.Set("client_secret", creds.ClientSecret)
params.Set("client_id", ClientID)
params.Set("client_secret", clientSecret)
params.Set("code", code)
params.Set("redirect_uri", RedirectURI)
params.Set("grant_type", "authorization_code")
@ -363,17 +362,16 @@ func (c *Client) ExchangeCode(ctx context.Context, code, codeVerifier string, is
return &tokenResp, nil
}
// RefreshToken 刷新 access_token。
// isEnterprise=true 时使用企业 OAuth client_id/secret。
func (c *Client) RefreshToken(ctx context.Context, refreshToken string, isEnterprise bool) (*TokenResponse, error) {
creds, err := GetClientCredentials(isEnterprise)
// RefreshToken 刷新 access_token
func (c *Client) RefreshToken(ctx context.Context, refreshToken string) (*TokenResponse, error) {
clientSecret, err := getClientSecret()
if err != nil {
return nil, err
}
params := url.Values{}
params.Set("client_id", creds.ClientID)
params.Set("client_secret", creds.ClientSecret)
params.Set("client_id", ClientID)
params.Set("client_secret", clientSecret)
params.Set("refresh_token", refreshToken)
params.Set("grant_type", "refresh_token")
@ -406,39 +404,6 @@ func (c *Client) RefreshToken(ctx context.Context, refreshToken string, isEnterp
return &tokenResp, nil
}
// RefreshTokenAuto 自动判定账号类型。
// 先用个人凭证刷新;若 Google 返回 invalid_client/unauthorized_clientclient 不匹配),
// 再用企业凭证重试。返回 token 和最终判定的 isEnterprise 标志。
//
// 其他错误invalid_grant、网络错误等直接返回不重试。
func (c *Client) RefreshTokenAuto(ctx context.Context, refreshToken string) (*TokenResponse, bool, error) {
tok, err := c.RefreshToken(ctx, refreshToken, false)
if err == nil {
return tok, false, nil
}
if !isClientMismatchError(err) {
return nil, false, err
}
tok, err2 := c.RefreshToken(ctx, refreshToken, true)
if err2 == nil {
return tok, true, nil
}
// 企业也失败:返回合并后的诊断错误
return nil, false, fmt.Errorf("auto-detect refresh failed: personal=%v enterprise=%v", err, err2)
}
// isClientMismatchError 判断是否为 OAuth client 不匹配导致的错误。
// 只有这种错误才会触发"切换账号类型重试"。
func isClientMismatchError(err error) bool {
if err == nil {
return false
}
msg := err.Error()
return strings.Contains(msg, "invalid_client") ||
strings.Contains(msg, "unauthorized_client") ||
strings.Contains(msg, "client_id")
}
// GetUserInfo 获取用户信息
func (c *Client) GetUserInfo(ctx context.Context, accessToken string) (*UserInfo, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, UserInfoURL, nil)

View File

@ -117,8 +117,7 @@ type GeminiToolConfig struct {
// GeminiFunctionCallingConfig 函数调用配置
type GeminiFunctionCallingConfig struct {
Mode string `json:"mode,omitempty"` // VALIDATED, AUTO, NONE, ANY
AllowedFunctionNames []string `json:"allowedFunctionNames,omitempty"`
Mode string `json:"mode,omitempty"` // VALIDATED, AUTO, NONE
}
// GeminiSafetySetting Gemini 安全设置

View File

@ -9,7 +9,6 @@ import (
"net/http"
"net/url"
"os"
"runtime"
"strings"
"sync"
"time"
@ -23,22 +22,16 @@ const (
TokenURL = "https://oauth2.googleapis.com/token"
UserInfoURL = "https://www.googleapis.com/oauth2/v2/userinfo"
// 个人账号 OAuth 凭证isGcpTos=false免费 Gemini Code Assist
// Antigravity OAuth 客户端凭证
ClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
// AntigravityOAuthClientSecretEnv 是个人账号 OAuth client_secret 的环境变量名。
// AntigravityOAuthClientSecretEnv 是 Antigravity OAuth client_secret 的环境变量名。
AntigravityOAuthClientSecretEnv = "ANTIGRAVITY_OAUTH_CLIENT_SECRET"
// 企业账号 OAuth 凭证isGcpTos=trueGoogle Cloud / Workspace 用户)
EnterpriseClientID = "884354919052-36trc1jjb3tguiac32ov6cod268c5blh.apps.googleusercontent.com"
// AntigravityEnterpriseOAuthClientSecretEnv 是企业账号 OAuth client_secret 的环境变量名。
AntigravityEnterpriseOAuthClientSecretEnv = "ANTIGRAVITY_ENTERPRISE_OAUTH_CLIENT_SECRET"
// 固定的 redirect_uri用户需手动复制 code
RedirectURI = "http://localhost:8085/callback"
// OAuth scopes(企业和个人共用)
// OAuth scopes
Scopes = "https://www.googleapis.com/auth/cloud-platform " +
"https://www.googleapis.com/auth/userinfo.email " +
"https://www.googleapis.com/auth/userinfo.profile " +
@ -53,18 +46,15 @@ const (
// Antigravity API 端点
antigravityProdBaseURL = "https://cloudcode-pa.googleapis.com"
antigravityDailyBaseURL = "https://daily-cloudcode-pa.googleapis.com"
antigravityDailyBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com"
)
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.20.6product.json ideVersion
var defaultUserAgentVersion = "1.20.6"
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.20.5
var defaultUserAgentVersion = "1.21.9"
// defaultClientSecret 个人账号 client_secret可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 覆盖
// defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置
var defaultClientSecret = "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"
// defaultEnterpriseClientSecret 企业账号 client_secret可通过环境变量 ANTIGRAVITY_ENTERPRISE_OAUTH_CLIENT_SECRET 覆盖
var defaultEnterpriseClientSecret = "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"
func init() {
// 从环境变量读取版本号,未设置则使用默认值
if version := os.Getenv("ANTIGRAVITY_USER_AGENT_VERSION"); version != "" {
@ -74,58 +64,14 @@ func init() {
if secret := os.Getenv(AntigravityOAuthClientSecretEnv); secret != "" {
defaultClientSecret = secret
}
if secret := os.Getenv(AntigravityEnterpriseOAuthClientSecretEnv); secret != "" {
defaultEnterpriseClientSecret = secret
}
}
// GetUserAgent 返回当前配置的 User-Agent(自动检测平台,匹配真实 IDE 行为)
// GetUserAgent 返回当前配置的 User-Agent
func GetUserAgent() string {
return fmt.Sprintf("antigravity/%s %s/%s", defaultUserAgentVersion, runtime.GOOS, runtime.GOARCH)
}
// ClientCredentials 持有一对 OAuth client_id/secret
type ClientCredentials struct {
ClientID string
ClientSecret string
}
// GetClientCredentials 根据账号类型返回对应的 OAuth 凭证。
// isEnterprise=true 时使用企业凭证isGcpTos=true否则使用个人凭证。
func GetClientCredentials(isEnterprise bool) (ClientCredentials, error) {
if isEnterprise {
secret := strings.TrimSpace(os.Getenv(AntigravityEnterpriseOAuthClientSecretEnv))
if secret == "" {
secret = strings.TrimSpace(defaultEnterpriseClientSecret)
}
if secret == "" {
return ClientCredentials{}, infraerrors.Newf(http.StatusBadRequest,
"ANTIGRAVITY_ENTERPRISE_OAUTH_CLIENT_SECRET_MISSING",
"missing enterprise oauth client_secret; set %s", AntigravityEnterpriseOAuthClientSecretEnv)
}
return ClientCredentials{ClientID: EnterpriseClientID, ClientSecret: secret}, nil
}
secret, err := getClientSecret()
if err != nil {
return ClientCredentials{}, err
}
return ClientCredentials{ClientID: ClientID, ClientSecret: secret}, nil
}
// BaseURLsForAccount 根据 isGcpTos 返回有序 URL 列表。
// 企业账号isGcpTos=true优先走 prod个人账号优先走 daily与真实 IDE 一致)。
func BaseURLsForAccount(isGcpTos bool) []string {
if isGcpTos {
return []string{antigravityProdBaseURL, antigravityDailyBaseURL}
}
return []string{antigravityDailyBaseURL, antigravityProdBaseURL}
return fmt.Sprintf("antigravity/%s windows/amd64", defaultUserAgentVersion)
}
func getClientSecret() (string, error) {
if secret := strings.TrimSpace(os.Getenv(AntigravityOAuthClientSecretEnv)); secret != "" {
defaultClientSecret = secret
return secret, nil
}
if v := strings.TrimSpace(defaultClientSecret); v != "" {
return v, nil
}
@ -265,7 +211,6 @@ type OAuthSession struct {
State string `json:"state"`
CodeVerifier string `json:"code_verifier"`
ProxyURL string `json:"proxy_url,omitempty"`
IsEnterprise bool `json:"is_enterprise,omitempty"`
CreatedAt time.Time `json:"created_at"`
}
@ -380,15 +325,10 @@ func base64URLEncode(data []byte) string {
return strings.TrimRight(base64.URLEncoding.EncodeToString(data), "=")
}
// BuildAuthorizationURL 构建 Google OAuth 授权 URL。
// isEnterprise=true 时使用企业 client_id否则使用个人 client_id。
func BuildAuthorizationURL(state, codeChallenge string, isEnterprise bool) string {
clientID := ClientID
if isEnterprise {
clientID = EnterpriseClientID
}
// BuildAuthorizationURL 构建 Google OAuth 授权 URL
func BuildAuthorizationURL(state, codeChallenge string) string {
params := url.Values{}
params.Set("client_id", clientID)
params.Set("client_id", ClientID)
params.Set("redirect_uri", RedirectURI)
params.Set("response_type", "code")
params.Set("scope", Scopes)

View File

@ -1,19 +0,0 @@
package antigravity
import "testing"
func TestGetClientSecret_ReadsRuntimeEnvironment(t *testing.T) {
old := defaultClientSecret
defaultClientSecret = ""
t.Cleanup(func() { defaultClientSecret = old })
t.Setenv(AntigravityOAuthClientSecretEnv, "runtime-secret")
secret, err := getClientSecret()
if err != nil {
t.Fatalf("getClientSecret returned error: %v", err)
}
if secret != "runtime-secret" {
t.Fatalf("unexpected secret: got %q want %q", secret, "runtime-secret")
}
}

View File

@ -1,14 +1,12 @@
package antigravity
import (
"bytes"
"crypto/sha256"
"encoding/binary"
"encoding/json"
"fmt"
"log"
"math/rand"
"regexp"
"strconv"
"strings"
"sync"
@ -18,16 +16,10 @@ import (
)
var (
sessionRand = rand.New(rand.NewSource(time.Now().UnixNano()))
sessionRandMutex sync.Mutex
legacyMetadataUserIDSessionPattern = regexp.MustCompile(`^user_[a-fA-F0-9]{64}_account_[a-fA-F0-9-]*_session_([a-fA-F0-9-]{36})$`)
plainSessionIDPattern = regexp.MustCompile(`^(session_)?[a-fA-F0-9-]{36}$`)
sessionRand = rand.New(rand.NewSource(time.Now().UnixNano()))
sessionRandMutex sync.Mutex
)
type claudeMetadataUserIDPayload struct {
SessionID string `json:"session_id"`
}
// generateStableSessionID 基于用户消息内容生成稳定的 session ID
func generateStableSessionID(contents []GeminiContent) string {
// 查找第一个 user 消息的文本
@ -47,82 +39,12 @@ func generateStableSessionID(contents []GeminiContent) string {
return "-" + strconv.FormatInt(n, 10)
}
// EnsureGeminiRequestSessionID fills request.sessionId when the caller omitted it.
// preferredSessionID wins; otherwise we derive a stable value from the first user turn.
func EnsureGeminiRequestSessionID(body []byte, preferredSessionID string) ([]byte, error) {
var payload map[string]any
if err := json.Unmarshal(body, &payload); err != nil {
return nil, err
}
if raw, ok := payload["sessionId"].(string); ok && strings.TrimSpace(raw) != "" {
return body, nil
}
sessionID := strings.TrimSpace(preferredSessionID)
if sessionID == "" {
var req GeminiRequest
if err := json.Unmarshal(body, &req); err != nil {
return nil, err
}
sessionID = generateStableSessionID(req.Contents)
}
if sessionID == "" {
return body, nil
}
payload["sessionId"] = sessionID
return json.Marshal(payload)
}
func extractSessionIDFromMetadataUserID(raw string) string {
raw = strings.TrimSpace(raw)
if raw == "" {
return ""
}
if strings.HasPrefix(raw, "{") {
var payload claudeMetadataUserIDPayload
if err := json.Unmarshal([]byte(raw), &payload); err == nil {
return strings.TrimSpace(payload.SessionID)
}
return ""
}
if matches := legacyMetadataUserIDSessionPattern.FindStringSubmatch(raw); len(matches) == 2 {
return strings.TrimSpace(matches[1])
}
if plainSessionIDPattern.MatchString(raw) {
return raw
}
return ""
}
func resolveClaudeRequestSessionID(metadata *ClaudeMetadata, preferredSessionID string, contents []GeminiContent) string {
if metadata != nil {
if sessionID := extractSessionIDFromMetadataUserID(metadata.UserID); sessionID != "" {
return sessionID
}
}
if sessionID := strings.TrimSpace(preferredSessionID); sessionID != "" {
return sessionID
}
return generateStableSessionID(contents)
}
type TransformOptions struct {
EnableIdentityPatch bool
// IdentityPatch 可选:自定义注入到 systemInstruction 开头的身份防护提示词;
// 为空时使用默认模板(包含 [IDENTITY_PATCH] 及 SYSTEM_PROMPT_BEGIN 标记)。
IdentityPatch string
EnableMCPXML bool
// PreferredSessionID 可选:当 metadata.user_id 不可用于恢复真实会话时,
// 允许调用方显式指定 Antigravity 上游 request.sessionId。
PreferredSessionID string
}
func DefaultTransformOptions() TransformOptions {
@ -163,24 +85,12 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st
// TransformClaudeToGeminiWithOptions 将 Claude 请求转换为 v1internal Gemini 格式(可配置身份补丁等行为)
func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, mappedModel string, opts TransformOptions) ([]byte, error) {
normalizedReq, err := normalizeClaudeRequestForAntigravity(claudeReq)
if err != nil {
return nil, fmt.Errorf("normalize messages: %w", err)
}
// 用于存储 tool_use id -> name 映射
toolIDToName := make(map[string]string)
// 检测是否有 web_search 工具
hasWebSearchTool := hasWebSearchTool(normalizedReq.Tools)
// requestType 映射策略:
// - Gemini 模型: "agent"(与 Antigravity 官方客户端一致)
// - Claude 模型: 不设置(避免 Google 后端路由到容量受限的 agent 池,降低 503 率)
// - web_search: "web_search"(触发 Google 搜索增强路由)
hasWebSearchTool := hasWebSearchTool(claudeReq.Tools)
requestType := "agent"
if strings.HasPrefix(mappedModel, "claude-") {
requestType = "" // Claude 模型走默认容量池,避免 agent 池 503
}
targetModel := mappedModel
if hasWebSearchTool {
requestType = "web_search"
@ -190,27 +100,27 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
}
// 检测是否启用 thinking
isThinkingEnabled := normalizedReq.Thinking != nil && (normalizedReq.Thinking.Type == "enabled" || normalizedReq.Thinking.Type == "adaptive")
isThinkingEnabled := claudeReq.Thinking != nil && (claudeReq.Thinking.Type == "enabled" || claudeReq.Thinking.Type == "adaptive")
// 只有 Gemini 模型支持 dummy thought workaround
// Claude 模型通过 Vertex/Google API 需要有效的 thought signatures
allowDummyThought := strings.HasPrefix(targetModel, "gemini-")
// 1. 构建 contents
contents, strippedThinking, err := buildContents(normalizedReq.Messages, toolIDToName, isThinkingEnabled, allowDummyThought)
contents, strippedThinking, err := buildContents(claudeReq.Messages, toolIDToName, isThinkingEnabled, allowDummyThought)
if err != nil {
return nil, fmt.Errorf("build contents: %w", err)
}
// 2. 构建 systemInstruction使用 targetModel 而非原始请求模型,确保身份注入基于最终模型)
systemInstruction := buildSystemInstruction(normalizedReq.System, targetModel, opts, normalizedReq.Tools)
systemInstruction := buildSystemInstruction(claudeReq.System, targetModel, opts, claudeReq.Tools)
// 3. 构建 generationConfig
reqForConfig := normalizedReq
reqForConfig := claudeReq
if strippedThinking {
// If we had to downgrade thinking blocks to plain text due to missing/invalid signatures,
// disable upstream thinking mode to avoid signature/structure validation errors.
reqCopy := *normalizedReq
reqCopy := *claudeReq
reqCopy.Thinking = nil
reqForConfig = &reqCopy
}
@ -222,24 +132,19 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
generationConfig := buildGenerationConfig(reqForConfig)
// 4. 构建 tools
// 对 Claude / Gemini 模型都保留 functionDeclarations
// - Claude 分支如果完全丢掉 tools模型只能看到消息历史中的 tool_use/tool_result
// 但拿不到当前可用工具定义,容易导致“能还原名字但不会继续发工具调用”。
// - Gemini 分支原本就依赖 functionDeclarations 触发 function_call。
isClaudeModel := strings.HasPrefix(targetModel, "claude-")
tools := buildTools(normalizedReq.Tools)
tools := buildTools(claudeReq.Tools)
// 5. 构建内部请求
innerRequest := GeminiRequest{
Contents: contents,
SessionID: resolveClaudeRequestSessionID(normalizedReq.Metadata, opts.PreferredSessionID, contents),
}
// Gemini 分支保持默认 VALIDATED
// Claude 分支仅在声明了工具时附带 toolConfig避免再把工具能力静默丢失。
defaultValidated := !isClaudeModel || len(tools) > 0
if toolConfig := buildToolConfig(normalizedReq.ToolChoice, defaultValidated); toolConfig != nil {
innerRequest.ToolConfig = toolConfig
Contents: contents,
// 总是设置 toolConfig与官方客户端一致
ToolConfig: &GeminiToolConfig{
FunctionCallingConfig: &GeminiFunctionCallingConfig{
Mode: "VALIDATED",
},
},
// 总是生成 sessionId基于用户消息内容
SessionID: generateStableSessionID(contents),
}
if systemInstruction != nil {
@ -252,6 +157,11 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
innerRequest.Tools = tools
}
// 如果提供了 metadata.user_id优先使用
if claudeReq.Metadata != nil && claudeReq.Metadata.UserID != "" {
innerRequest.SessionID = claudeReq.Metadata.UserID
}
// 6. 包装为 v1internal 请求
v1Req := V1InternalRequest{
Project: projectID,
@ -265,319 +175,6 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
return json.Marshal(v1Req)
}
const (
maxAntigravityToolDescriptionChars = 400
maxAntigravitySchemaDescriptionChars = 200
maxAntigravityToolResultChars = 200000
)
func normalizeClaudeRequestForAntigravity(claudeReq *ClaudeRequest) (*ClaudeRequest, error) {
if claudeReq == nil {
return nil, nil
}
reqCopy := *claudeReq
if len(claudeReq.Messages) == 0 {
return &reqCopy, nil
}
normalizedMessages, err := normalizeClaudeMessagesForAntigravity(claudeReq.Messages)
if err != nil {
return nil, err
}
reqCopy.Messages = normalizedMessages
return &reqCopy, nil
}
func normalizeClaudeMessagesForAntigravity(messages []ClaudeMessage) ([]ClaudeMessage, error) {
normalized := make([]ClaudeMessage, 0, len(messages)+1)
pendingToolUseIDs := make([]string, 0)
for _, message := range messages {
blocks, hasBlocks := parseClaudeMessageBlocks(message.Content)
switch message.Role {
case "assistant":
if len(pendingToolUseIDs) > 0 {
synthetic, err := buildSyntheticToolResultMessage(pendingToolUseIDs)
if err != nil {
return nil, err
}
normalized = append(normalized, synthetic)
pendingToolUseIDs = pendingToolUseIDs[:0]
}
if !hasBlocks {
normalized = append(normalized, cloneClaudeMessage(message))
continue
}
stripped := stripNonToolPartsAfterToolUse(reorderAssistantThinkingBlocks(blocks))
pendingToolUseIDs = append(pendingToolUseIDs, collectToolUseIDs(stripped)...)
nextMessage, err := buildClaudeMessageWithBlocks(message.Role, stripped)
if err != nil {
return nil, err
}
normalized = append(normalized, nextMessage)
case "user":
if !hasBlocks {
if len(pendingToolUseIDs) > 0 {
synthetic, err := buildSyntheticToolResultMessage(pendingToolUseIDs)
if err != nil {
return nil, err
}
normalized = append(normalized, synthetic)
pendingToolUseIDs = pendingToolUseIDs[:0]
}
normalized = append(normalized, cloneClaudeMessage(message))
continue
}
parts := cloneJSONBlocks(blocks)
if len(pendingToolUseIDs) > 0 {
toolResults, nonToolResults := partitionToolResultBlocks(parts)
existingIDs := collectToolResultIDs(toolResults)
missingIDs := diffStringSlice(pendingToolUseIDs, existingIDs)
if len(missingIDs) > 0 {
parts = append(append(toolResults, buildSyntheticToolResultBlocks(missingIDs)...), nonToolResults...)
}
pendingToolUseIDs = pendingToolUseIDs[:0]
}
toolResults, nonToolResults := partitionToolResultBlocks(parts)
switch {
case len(toolResults) == 0:
nextMessage, err := buildClaudeMessageWithBlocks(message.Role, parts)
if err != nil {
return nil, err
}
normalized = append(normalized, nextMessage)
case len(nonToolResults) == 0:
nextMessage, err := buildClaudeMessageWithBlocks(message.Role, toolResults)
if err != nil {
return nil, err
}
normalized = append(normalized, nextMessage)
default:
toolResultMessage, err := buildClaudeMessageWithBlocks(message.Role, toolResults)
if err != nil {
return nil, err
}
userTextMessage, err := buildClaudeMessageWithBlocks(message.Role, nonToolResults)
if err != nil {
return nil, err
}
normalized = append(normalized, toolResultMessage, userTextMessage)
}
default:
normalized = append(normalized, cloneClaudeMessage(message))
}
}
if len(pendingToolUseIDs) > 0 {
synthetic, err := buildSyntheticToolResultMessage(pendingToolUseIDs)
if err != nil {
return nil, err
}
normalized = append(normalized, synthetic)
}
return normalized, nil
}
func parseClaudeMessageBlocks(content json.RawMessage) ([]map[string]any, bool) {
var blocks []map[string]any
if err := json.Unmarshal(content, &blocks); err != nil {
return nil, false
}
return blocks, true
}
func cloneClaudeMessage(message ClaudeMessage) ClaudeMessage {
cloned := ClaudeMessage{Role: message.Role}
if len(message.Content) > 0 {
cloned.Content = append(json.RawMessage(nil), message.Content...)
}
return cloned
}
func cloneJSONBlocks(blocks []map[string]any) []map[string]any {
cloned := make([]map[string]any, 0, len(blocks))
for _, block := range blocks {
cloned = append(cloned, cloneJSONMap(block))
}
return cloned
}
func cloneJSONMap(block map[string]any) map[string]any {
if block == nil {
return nil
}
if cloned, ok := deepCopy(block).(map[string]any); ok {
return cloned
}
fallback := make(map[string]any, len(block))
for key, value := range block {
fallback[key] = value
}
return fallback
}
func buildClaudeMessageWithBlocks(role string, blocks []map[string]any) (ClaudeMessage, error) {
payload, err := json.Marshal(blocks)
if err != nil {
return ClaudeMessage{}, fmt.Errorf("marshal %s message blocks: %w", role, err)
}
return ClaudeMessage{Role: role, Content: payload}, nil
}
func buildSyntheticToolResultMessage(toolUseIDs []string) (ClaudeMessage, error) {
return buildClaudeMessageWithBlocks("user", buildSyntheticToolResultBlocks(toolUseIDs))
}
func buildSyntheticToolResultBlocks(toolUseIDs []string) []map[string]any {
blocks := make([]map[string]any, 0, len(toolUseIDs))
for _, toolUseID := range toolUseIDs {
if strings.TrimSpace(toolUseID) == "" {
continue
}
blocks = append(blocks, map[string]any{
"type": "tool_result",
"tool_use_id": toolUseID,
"is_error": true,
"content": []map[string]any{
{
"type": "text",
"text": "[tool_result missing; tool execution interrupted]",
},
},
})
}
return blocks
}
func reorderAssistantThinkingBlocks(blocks []map[string]any) []map[string]any {
thinkingBlocks := make([]map[string]any, 0)
otherBlocks := make([]map[string]any, 0, len(blocks))
for _, block := range blocks {
cloned := cloneJSONMap(block)
blockType, _ := cloned["type"].(string)
if blockType == "thinking" || blockType == "redacted_thinking" {
delete(cloned, "cache_control")
thinkingBlocks = append(thinkingBlocks, cloned)
continue
}
otherBlocks = append(otherBlocks, cloned)
}
if len(thinkingBlocks) == 0 {
return otherBlocks
}
return append(thinkingBlocks, otherBlocks...)
}
func stripNonToolPartsAfterToolUse(blocks []map[string]any) []map[string]any {
cleaned := make([]map[string]any, 0, len(blocks))
seenToolUse := false
for _, block := range blocks {
blockType, _ := block["type"].(string)
if blockType == "tool_use" {
seenToolUse = true
cleaned = append(cleaned, block)
continue
}
if !seenToolUse {
cleaned = append(cleaned, block)
continue
}
if isIgnorableTrailingTextBlock(block) {
continue
}
}
return cleaned
}
func isIgnorableTrailingTextBlock(block map[string]any) bool {
blockType, _ := block["type"].(string)
if blockType != "text" {
return false
}
text, _ := block["text"].(string)
trimmed := strings.TrimSpace(text)
return trimmed == "" || trimmed == "(no content)"
}
func collectToolUseIDs(blocks []map[string]any) []string {
ids := make([]string, 0)
for _, block := range blocks {
blockType, _ := block["type"].(string)
if blockType != "tool_use" {
continue
}
id, _ := block["id"].(string)
if strings.TrimSpace(id) != "" {
ids = append(ids, id)
}
}
return ids
}
func collectToolResultIDs(blocks []map[string]any) []string {
ids := make([]string, 0, len(blocks))
for _, block := range blocks {
id, _ := block["tool_use_id"].(string)
if strings.TrimSpace(id) != "" {
ids = append(ids, id)
}
}
return ids
}
func diffStringSlice(left, right []string) []string {
if len(left) == 0 {
return nil
}
seen := make(map[string]struct{}, len(right))
for _, value := range right {
if strings.TrimSpace(value) != "" {
seen[value] = struct{}{}
}
}
diff := make([]string, 0, len(left))
for _, value := range left {
value = strings.TrimSpace(value)
if value == "" {
continue
}
if _, ok := seen[value]; ok {
continue
}
diff = append(diff, value)
}
return diff
}
func partitionToolResultBlocks(blocks []map[string]any) (toolResults []map[string]any, nonToolResults []map[string]any) {
toolResults = make([]map[string]any, 0)
nonToolResults = make([]map[string]any, 0)
for _, block := range blocks {
blockType, _ := block["type"].(string)
if blockType == "tool_result" {
toolResults = append(toolResults, block)
continue
}
nonToolResults = append(nonToolResults, block)
}
return toolResults, nonToolResults
}
// antigravityIdentity Antigravity identity 提示词
const antigravityIdentity = `<identity>
You are Antigravity, a powerful agentic AI coding assistant designed by the Google Deepmind team working on Advanced Agentic Coding.
@ -877,14 +474,13 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
case "tool_use":
// 存储 id -> name 映射
toolName := normalizeClaudeCodeToolName(block.Name)
if block.ID != "" && toolName != "" {
toolIDToName[block.ID] = toolName
if block.ID != "" && block.Name != "" {
toolIDToName[block.ID] = block.Name
}
part := GeminiPart{
FunctionCall: &GeminiFunctionCall{
Name: toolName,
Name: block.Name,
Args: block.Input,
ID: block.ID,
},
@ -901,12 +497,12 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
case "tool_result":
// 获取函数名
funcName := normalizeClaudeCodeToolName(block.Name)
funcName := block.Name
if funcName == "" {
if name, ok := toolIDToName[block.ToolUseID]; ok {
funcName = name
} else {
funcName = normalizeClaudeCodeToolName(block.ToolUseID)
funcName = block.ToolUseID
}
}
@ -929,84 +525,47 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
}
// parseToolResultContent 解析 tool_result 的 content
func parseToolResultContent(content json.RawMessage, isError bool) any {
func parseToolResultContent(content json.RawMessage, isError bool) string {
if len(content) == 0 {
return defaultToolResultContent(isError)
if isError {
return "Tool execution failed with no output."
}
return "Command executed successfully."
}
// 尝试解析为字符串
var str string
if err := json.Unmarshal(content, &str); err == nil {
if strings.TrimSpace(str) == "" {
return defaultToolResultContent(isError)
if isError {
return "Tool execution failed with no output."
}
return "Command executed successfully."
}
return truncateInlineText(str, maxAntigravityToolResultChars)
return str
}
// 优先保留结构化 tool_result避免上游把内容视为无效的纯文本降级。
// 尝试解析为数组
var arr []map[string]any
if err := json.Unmarshal(content, &arr); err == nil {
sanitized := sanitizeToolResultBlocksForAntigravity(arr)
if len(sanitized) == 0 {
return defaultToolResultContent(isError)
var texts []string
for _, item := range arr {
if text, ok := item["text"].(string); ok {
texts = append(texts, text)
}
}
return sanitized
}
var obj map[string]any
if err := json.Unmarshal(content, &obj); err == nil {
sanitized := sanitizeToolResultObjectForAntigravity(obj)
if len(sanitized) == 0 {
return defaultToolResultContent(isError)
result := strings.Join(texts, "\n")
if strings.TrimSpace(result) == "" {
if isError {
return "Tool execution failed with no output."
}
return "Command executed successfully."
}
return sanitized
return result
}
// 返回原始 JSON
return truncateInlineText(string(content), maxAntigravityToolResultChars)
}
func defaultToolResultContent(isError bool) string {
if isError {
return "Tool execution failed with no output."
}
return "Command executed successfully."
}
func sanitizeToolResultBlocksForAntigravity(blocks []map[string]any) []map[string]any {
sanitized := make([]map[string]any, 0, len(blocks))
for _, block := range blocks {
if isBase64ImageToolResultBlock(block) {
continue
}
cloned := cloneJSONMap(block)
if text, ok := cloned["text"].(string); ok {
cloned["text"] = truncateInlineText(text, maxAntigravityToolResultChars)
}
sanitized = append(sanitized, cloned)
}
return sanitized
}
func sanitizeToolResultObjectForAntigravity(block map[string]any) map[string]any {
if isBase64ImageToolResultBlock(block) {
return nil
}
cloned := cloneJSONMap(block)
if text, ok := cloned["text"].(string); ok {
cloned["text"] = truncateInlineText(text, maxAntigravityToolResultChars)
}
return cloned
}
func isBase64ImageToolResultBlock(block map[string]any) bool {
blockType, _ := block["type"].(string)
if blockType != "image" {
return false
}
source, _ := block["source"].(map[string]any)
sourceType, _ := source["type"].(string)
return sourceType == "base64"
return string(content)
}
// buildGenerationConfig 构建 generationConfig
@ -1074,15 +633,6 @@ func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig {
}
}
config.ThinkingConfig.ThinkingBudget = budget
} else if strings.HasSuffix(req.Model, "-thinking") || strings.HasPrefix(req.Model, "claude-sonnet-4-6") {
// 自动注入 thinkingConfig 的两种情形(客户端未显式开启 thinking
// 1. 模型名以 -thinking 结尾(如 claude-opus-4-6-thinkingGoogle 要求此后缀模型必须携带 thinkingConfig。
// 2. claude-sonnet-4-6无 -thinking 变体404但模型本身要求携带 thinkingConfigbudget 必须为 -1动态
// 注:固定 budget如 1024在 max_tokens 较小时会触发 400max_tokens 必须大于 budget
config.ThinkingConfig = &GeminiThinkingConfig{
IncludeThoughts: true,
ThinkingBudget: -1, // 动态预算,避免 max_tokens vs budget 冲突
}
}
if config.MaxOutputTokens > maxLimit {
@ -1126,65 +676,6 @@ func isWebSearchTool(tool ClaudeTool) bool {
}
}
func buildToolConfig(toolChoice json.RawMessage, defaultValidated bool) *GeminiToolConfig {
raw := bytes.TrimSpace(toolChoice)
if len(raw) == 0 {
if !defaultValidated {
return nil
}
return &GeminiToolConfig{
FunctionCallingConfig: &GeminiFunctionCallingConfig{
Mode: "VALIDATED",
},
}
}
choiceType := ""
toolName := ""
if len(raw) > 0 && raw[0] == '"' {
var choice string
if err := json.Unmarshal(raw, &choice); err == nil {
choiceType = strings.TrimSpace(choice)
}
} else {
var choice map[string]any
if err := json.Unmarshal(raw, &choice); err == nil {
if value, ok := choice["type"].(string); ok {
choiceType = strings.TrimSpace(value)
}
if value, ok := choice["name"].(string); ok {
toolName = normalizeClaudeCodeToolName(value)
}
}
}
mode := ""
switch strings.ToLower(choiceType) {
case "auto":
mode = "AUTO"
case "none":
mode = "NONE"
case "any", "required":
mode = "ANY"
case "tool":
mode = "ANY"
case "validated":
mode = "VALIDATED"
default:
if !defaultValidated {
return nil
}
mode = "VALIDATED"
}
cfg := &GeminiFunctionCallingConfig{Mode: mode}
if toolName != "" && mode == "ANY" {
cfg.AllowedFunctionNames = []string{toolName}
}
return &GeminiToolConfig{FunctionCallingConfig: cfg}
}
// buildTools 构建 tools
func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
if len(tools) == 0 {
@ -1215,12 +706,12 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
continue
}
description = tool.Custom.Description
inputSchema = cloneStringAnyMap(tool.Custom.InputSchema)
inputSchema = tool.Custom.InputSchema
} else {
// 标准格式: 从顶层字段获取
description = tool.Description
inputSchema = cloneStringAnyMap(tool.InputSchema)
inputSchema = tool.InputSchema
}
// 清理 JSON Schema
@ -1235,11 +726,9 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
"properties": map[string]any{},
}
}
description = compactToolDescriptionForAntigravity(description)
params = compactSchemaDescriptionsForAntigravity(params)
funcDecls = append(funcDecls, GeminiFunctionDecl{
Name: normalizeClaudeCodeToolName(tool.Name),
Name: tool.Name,
Description: description,
Parameters: params,
})
@ -1268,64 +757,3 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
return declarations
}
func cloneStringAnyMap(input map[string]any) map[string]any {
if input == nil {
return nil
}
if cloned, ok := deepCopy(input).(map[string]any); ok {
return cloned
}
fallback := make(map[string]any, len(input))
for key, value := range input {
fallback[key] = value
}
return fallback
}
func compactToolDescriptionForAntigravity(description string) string {
if strings.TrimSpace(description) == "" {
return ""
}
lines := strings.Split(strings.ReplaceAll(description, "\r\n", "\n"), "\n")
compacted := make([]string, 0, len(lines))
for _, line := range lines {
line = strings.TrimSpace(line)
if line == "" {
continue
}
compacted = append(compacted, line)
if len(compacted) == 6 {
break
}
}
return truncateInlineText(strings.Join(compacted, " "), maxAntigravityToolDescriptionChars)
}
func compactSchemaDescriptionsForAntigravity(schema map[string]any) map[string]any {
for key, value := range schema {
switch typed := value.(type) {
case string:
if key == "description" {
schema[key] = truncateInlineText(strings.Join(strings.Fields(typed), " "), maxAntigravitySchemaDescriptionChars)
}
case map[string]any:
schema[key] = compactSchemaDescriptionsForAntigravity(typed)
case []any:
for i, item := range typed {
if nested, ok := item.(map[string]any); ok {
typed[i] = compactSchemaDescriptionsForAntigravity(nested)
}
}
schema[key] = typed
}
}
return schema
}
func truncateInlineText(text string, maxChars int) string {
if maxChars <= 0 || len(text) <= maxChars {
return text
}
return text[:maxChars] + "...[truncated " + strconv.Itoa(len(text)-maxChars) + " chars]"
}

View File

@ -8,112 +8,6 @@ import (
"github.com/stretchr/testify/require"
)
func TestEnsureGeminiRequestSessionID(t *testing.T) {
t.Run("prefers provided session id", func(t *testing.T) {
body := []byte(`{"contents":[{"role":"user","parts":[{"text":"hello"}]}]}`)
updated, err := EnsureGeminiRequestSessionID(body, "session-from-header")
require.NoError(t, err)
var payload map[string]any
require.NoError(t, json.Unmarshal(updated, &payload))
require.Equal(t, "session-from-header", payload["sessionId"])
})
t.Run("keeps existing session id", func(t *testing.T) {
body := []byte(`{"sessionId":"session-in-body","contents":[{"role":"user","parts":[{"text":"hello"}]}]}`)
updated, err := EnsureGeminiRequestSessionID(body, "session-from-header")
require.NoError(t, err)
var payload map[string]any
require.NoError(t, json.Unmarshal(updated, &payload))
require.Equal(t, "session-in-body", payload["sessionId"])
})
t.Run("derives stable fallback from contents", func(t *testing.T) {
body := []byte(`{"contents":[{"role":"user","parts":[{"text":"hello"}]}]}`)
first, err := EnsureGeminiRequestSessionID(body, "")
require.NoError(t, err)
second, err := EnsureGeminiRequestSessionID(body, "")
require.NoError(t, err)
var firstPayload map[string]any
var secondPayload map[string]any
require.NoError(t, json.Unmarshal(first, &firstPayload))
require.NoError(t, json.Unmarshal(second, &secondPayload))
require.NotEmpty(t, firstPayload["sessionId"])
require.Equal(t, firstPayload["sessionId"], secondPayload["sessionId"])
})
}
func TestTransformClaudeToGeminiWithOptions_UsesMetadataSessionIDJSON(t *testing.T) {
claudeReq := &ClaudeRequest{
Model: "claude-sonnet-4-5",
Messages: []ClaudeMessage{
{
Role: "user",
Content: json.RawMessage(`[{"type":"text","text":"hello"}]`),
},
},
Metadata: &ClaudeMetadata{
UserID: `{"device_id":"d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677","account_uuid":"acc-uuid","session_id":"c72554f2-1234-5678-abcd-123456789abc"}`,
},
}
body, err := TransformClaudeToGeminiWithOptions(claudeReq, "project-1", "claude-sonnet-4-5", DefaultTransformOptions())
require.NoError(t, err)
var req V1InternalRequest
require.NoError(t, json.Unmarshal(body, &req))
require.Equal(t, "c72554f2-1234-5678-abcd-123456789abc", req.Request.SessionID)
}
func TestTransformClaudeToGeminiWithOptions_UsesMetadataSessionIDLegacy(t *testing.T) {
claudeReq := &ClaudeRequest{
Model: "claude-sonnet-4-5",
Messages: []ClaudeMessage{
{
Role: "user",
Content: json.RawMessage(`[{"type":"text","text":"hello"}]`),
},
},
Metadata: &ClaudeMetadata{
UserID: "user_a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2_account_550e8400-e29b-41d4-a716-446655440000_session_123e4567-e89b-12d3-a456-426614174000",
},
}
body, err := TransformClaudeToGeminiWithOptions(claudeReq, "project-1", "claude-sonnet-4-5", DefaultTransformOptions())
require.NoError(t, err)
var req V1InternalRequest
require.NoError(t, json.Unmarshal(body, &req))
require.Equal(t, "123e4567-e89b-12d3-a456-426614174000", req.Request.SessionID)
}
func TestTransformClaudeToGeminiWithOptions_PrefersExplicitSessionWhenMetadataIsNotSessionPayload(t *testing.T) {
opts := DefaultTransformOptions()
opts.PreferredSessionID = "session-header-1"
claudeReq := &ClaudeRequest{
Model: "claude-sonnet-4-5",
Messages: []ClaudeMessage{
{
Role: "user",
Content: json.RawMessage(`[{"type":"text","text":"hello"}]`),
},
},
Metadata: &ClaudeMetadata{
UserID: "custom-user-42",
},
}
body, err := TransformClaudeToGeminiWithOptions(claudeReq, "project-1", "claude-sonnet-4-5", opts)
require.NoError(t, err)
var req V1InternalRequest
require.NoError(t, json.Unmarshal(body, &req))
require.Equal(t, "session-header-1", req.Request.SessionID)
}
// TestBuildParts_ThinkingBlockWithoutSignature 测试thinking block无signature时的处理
func TestBuildParts_ThinkingBlockWithoutSignature(t *testing.T) {
tests := []struct {
@ -436,36 +330,16 @@ func TestBuildGenerationConfig_ThinkingDynamicBudget(t *testing.T) {
wantPresent: true,
},
{
// Google v1internal 要求 -thinking 模型必须携带 thinkingConfig即使客户端明确 disabled。
// 不携带会导致 Google 立即返回错误(在生产中表现为快速 503
name: "disabled on -thinking model auto-injects thinkingConfig (Google requires it)",
name: "disabled does not emit thinkingConfig",
model: "claude-opus-4-6-thinking",
thinking: &ThinkingConfig{Type: "disabled", BudgetTokens: 1024},
wantBudget: -1, // auto-injected dynamic budget
wantPresent: true,
wantBudget: 0,
wantPresent: false,
},
{
// Google v1internal 要求 -thinking 模型必须携带 thinkingConfignil 时自动注入。
name: "nil thinking on -thinking model auto-injects thinkingConfig (Google requires it)",
name: "nil thinking does not emit thinkingConfig",
model: "claude-opus-4-6-thinking",
thinking: nil,
wantBudget: -1, // auto-injected dynamic budget
wantPresent: true,
},
{
// claude-sonnet-4-6 需要 thinkingConfig无 -thinking 变体budget 必须为 -1动态
// 经测试claude-sonnet-4-6-thinking → 404claude-sonnet-4-6 + budget=-1 → 200 OK
name: "nil thinking on claude-sonnet-4-6 auto-injects thinkingConfig (no -thinking variant exists)",
model: "claude-sonnet-4-6",
thinking: nil,
wantBudget: -1,
wantPresent: true,
},
{
// 非 -thinking 普通模型(如 claude-opus-4-6服务层已转为 -thinking此处测试原始名
name: "nil thinking on plain non-thinking model does not emit thinkingConfig",
model: "claude-opus-4-6",
thinking: nil,
wantBudget: 0,
wantPresent: false,
},
@ -582,214 +456,3 @@ func TestTransformClaudeToGeminiWithOptions_PreservesWebSearchAlongsideFunctions
require.Equal(t, "get_weather", req.Request.Tools[0].FunctionDeclarations[0].Name)
require.NotNil(t, req.Request.Tools[1].GoogleSearch)
}
func TestTransformClaudeToGeminiWithOptions_ClaudeModelKeepsToolsAndValidatedToolConfig(t *testing.T) {
claudeReq := &ClaudeRequest{
Model: "claude-sonnet-4-5",
Messages: []ClaudeMessage{
{
Role: "user",
Content: json.RawMessage(`[{"type":"text","text":"read the file"}]`),
},
},
Tools: []ClaudeTool{
{
Name: "read_file",
Description: "Read a file",
InputSchema: map[string]any{
"type": "object",
"properties": map[string]any{
"file_path": map[string]any{"type": "string"},
},
},
},
},
}
body, err := TransformClaudeToGeminiWithOptions(claudeReq, "project-1", "claude-sonnet-4-5", DefaultTransformOptions())
require.NoError(t, err)
var req V1InternalRequest
require.NoError(t, json.Unmarshal(body, &req))
require.Len(t, req.Request.Tools, 1)
require.Len(t, req.Request.Tools[0].FunctionDeclarations, 1)
require.Equal(t, "Read", req.Request.Tools[0].FunctionDeclarations[0].Name)
require.NotNil(t, req.Request.ToolConfig)
require.NotNil(t, req.Request.ToolConfig.FunctionCallingConfig)
require.Equal(t, "VALIDATED", req.Request.ToolConfig.FunctionCallingConfig.Mode)
}
func TestTransformClaudeToGeminiWithOptions_ClaudeModelToolChoiceSpecificTool(t *testing.T) {
claudeReq := &ClaudeRequest{
Model: "claude-sonnet-4-5",
ToolChoice: json.RawMessage(`{"type":"tool","name":"search_files"}`),
Messages: []ClaudeMessage{
{
Role: "user",
Content: json.RawMessage(`[{"type":"text","text":"find todo"}]`),
},
},
Tools: []ClaudeTool{
{
Name: "search_files",
Description: "Search files",
InputSchema: map[string]any{
"type": "object",
"properties": map[string]any{
"pattern": map[string]any{"type": "string"},
},
},
},
},
}
body, err := TransformClaudeToGeminiWithOptions(claudeReq, "project-1", "claude-sonnet-4-5", DefaultTransformOptions())
require.NoError(t, err)
var req V1InternalRequest
require.NoError(t, json.Unmarshal(body, &req))
require.NotNil(t, req.Request.ToolConfig)
require.NotNil(t, req.Request.ToolConfig.FunctionCallingConfig)
require.Equal(t, "ANY", req.Request.ToolConfig.FunctionCallingConfig.Mode)
require.Equal(t, []string{"Grep"}, req.Request.ToolConfig.FunctionCallingConfig.AllowedFunctionNames)
}
func TestTransformClaudeToGeminiWithOptions_NormalizesInterruptedToolHistory(t *testing.T) {
claudeReq := &ClaudeRequest{
Model: "claude-sonnet-4-5",
Messages: []ClaudeMessage{
{
Role: "assistant",
Content: json.RawMessage(`[
{"type":"tool_use","id":"tool-1","name":"Bash","input":{"command":"pwd"}},
{"type":"text","text":"(no content)"}
]`),
},
{
Role: "user",
Content: json.RawMessage(`[{"type":"text","text":"继续"}]`),
},
},
}
body, err := TransformClaudeToGeminiWithOptions(claudeReq, "project-1", "claude-sonnet-4-5", DefaultTransformOptions())
require.NoError(t, err)
var req V1InternalRequest
require.NoError(t, json.Unmarshal(body, &req))
require.Len(t, req.Request.Contents, 3)
first := req.Request.Contents[0]
require.Equal(t, "model", first.Role)
require.Len(t, first.Parts, 1)
require.NotNil(t, first.Parts[0].FunctionCall)
require.Equal(t, "tool-1", first.Parts[0].FunctionCall.ID)
second := req.Request.Contents[1]
require.Equal(t, "user", second.Role)
require.Len(t, second.Parts, 1)
require.NotNil(t, second.Parts[0].FunctionResponse)
require.Equal(t, "tool-1", second.Parts[0].FunctionResponse.ID)
resultBlocks, ok := second.Parts[0].FunctionResponse.Response["result"].([]any)
require.True(t, ok)
require.Len(t, resultBlocks, 1)
resultBlock, ok := resultBlocks[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "text", resultBlock["type"])
require.Equal(t, "[tool_result missing; tool execution interrupted]", resultBlock["text"])
third := req.Request.Contents[2]
require.Equal(t, "user", third.Role)
require.Len(t, third.Parts, 1)
require.Equal(t, "继续", third.Parts[0].Text)
}
func TestNormalizeClaudeMessagesForAntigravity_ReordersThinkingAndSplitsToolResult(t *testing.T) {
messages := []ClaudeMessage{
{
Role: "assistant",
Content: json.RawMessage(`[
{"type":"text","text":"before"},
{"type":"thinking","thinking":"deep thought","signature":"sig-1"},
{"type":"tool_use","id":"tool-2","name":"Bash","input":{"command":"ls"}},
{"type":"text","text":"(no content)"}
]`),
},
{
Role: "user",
Content: json.RawMessage(`[
{"type":"tool_result","tool_use_id":"tool-2","content":[{"type":"text","text":"ok"}]},
{"type":"text","text":"下一步"}
]`),
},
}
normalized, err := normalizeClaudeMessagesForAntigravity(messages)
require.NoError(t, err)
require.Len(t, normalized, 3)
var assistantBlocks []map[string]any
require.NoError(t, json.Unmarshal(normalized[0].Content, &assistantBlocks))
require.Len(t, assistantBlocks, 3)
require.Equal(t, "thinking", assistantBlocks[0]["type"])
require.Equal(t, "text", assistantBlocks[1]["type"])
require.Equal(t, "tool_use", assistantBlocks[2]["type"])
var toolResultBlocks []map[string]any
require.NoError(t, json.Unmarshal(normalized[1].Content, &toolResultBlocks))
require.Len(t, toolResultBlocks, 1)
require.Equal(t, "tool_result", toolResultBlocks[0]["type"])
var userTextBlocks []map[string]any
require.NoError(t, json.Unmarshal(normalized[2].Content, &userTextBlocks))
require.Len(t, userTextBlocks, 1)
require.Equal(t, "text", userTextBlocks[0]["type"])
require.Equal(t, "下一步", userTextBlocks[0]["text"])
}
func TestParseToolResultContent_PreservesStructuredBlocks(t *testing.T) {
content := json.RawMessage(`[
{"type":"text","text":"hello"},
{"type":"image","source":{"type":"base64","media_type":"image/png","data":"AAAA"}}
]`)
result := parseToolResultContent(content, false)
blocks, ok := result.([]map[string]any)
require.True(t, ok)
require.Len(t, blocks, 1)
require.Equal(t, "text", blocks[0]["type"])
require.Equal(t, "hello", blocks[0]["text"])
}
func TestBuildTools_CompactsDescriptions(t *testing.T) {
longLine := strings.Repeat("schema detail ", 40)
result := buildTools([]ClaudeTool{
{
Name: "describe",
Description: strings.Repeat("tool description\n", 20),
InputSchema: map[string]any{
"type": "object",
"properties": map[string]any{
"query": map[string]any{
"type": "string",
"description": longLine,
},
},
},
},
})
require.Len(t, result, 1)
require.Len(t, result[0].FunctionDeclarations, 1)
decl := result[0].FunctionDeclarations[0]
require.LessOrEqual(t, len(decl.Description), maxAntigravityToolDescriptionChars+32)
props, ok := decl.Parameters["properties"].(map[string]any)
require.True(t, ok)
query, ok := props["query"].(map[string]any)
require.True(t, ok)
description, ok := query["description"].(string)
require.True(t, ok)
require.LessOrEqual(t, len(description), maxAntigravitySchemaDescriptionChars+32)
}

View File

@ -121,20 +121,17 @@ func (p *NonStreamingProcessor) processPart(part *GeminiPart) {
p.hasToolCall = true
toolName := normalizeClaudeCodeToolName(part.FunctionCall.Name)
// 生成 tool_use id
toolID := part.FunctionCall.ID
if toolID == "" {
toolID = fmt.Sprintf("%s-%s", toolName, generateRandomID())
toolID = fmt.Sprintf("%s-%s", part.FunctionCall.Name, generateRandomID())
}
item := ClaudeContentItem{
Type: "tool_use",
ID: toolID,
Name: toolName,
Input: part.FunctionCall.Args,
Caller: &ToolCaller{Type: "direct"},
Type: "tool_use",
ID: toolID,
Name: part.FunctionCall.Name,
Input: part.FunctionCall.Args,
}
if signature != "" {

View File

@ -362,21 +362,17 @@ func (p *StreamingProcessor) processFunctionCall(fc *GeminiFunctionCall, signatu
var result bytes.Buffer
p.usedTool = true
toolName := normalizeClaudeCodeToolName(fc.Name)
toolID := fc.ID
if toolID == "" {
toolID = fmt.Sprintf("%s-%s", toolName, generateRandomID())
toolID = fmt.Sprintf("%s-%s", fc.Name, generateRandomID())
}
toolUse := map[string]any{
"type": "tool_use",
"id": toolID,
"name": toolName,
"name": fc.Name,
"input": map[string]any{},
"caller": map[string]any{
"type": "direct",
},
}
if signature != "" {

View File

@ -27,7 +27,7 @@ func stubUserHome(t *testing.T, home string) {
func TestDiscoverBinary_EnvOverrideWins(t *testing.T) {
stubStatFn(t, map[string]bool{
"/tmp/my-ls": true,
"/tmp/my-ls": true,
"/opt/windsurf/language_server_linux_x64": true, // should not be picked
})
got, err := discoverBinaryFor(Platform{"linux", "amd64"}, "/tmp/my-ls", "/opt/windsurf/language_server_linux_x64")

View File

@ -39,8 +39,6 @@ func ProvideRouter(
opsService *service.OpsService,
settingService *service.SettingService,
redisClient *redis.Client,
langServerService *service.LanguageServerService,
lsrpcHandler *service.LSRPCHandler,
) *gin.Engine {
if cfg.Server.Mode == "release" {
gin.SetMode(gin.ReleaseMode)
@ -97,7 +95,7 @@ func ProvideRouter(
service.SetWebSearchManager(websearch.NewManager(configs, redisClient))
})
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient, langServerService, lsrpcHandler)
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient)
}
// ProvideHTTPServer 提供 HTTP 服务器

View File

@ -7,7 +7,6 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/gen/language_server_pbconnect"
"github.com/Wei-Shaw/sub2api/internal/handler"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/server/routes"
@ -33,8 +32,6 @@ func SetupRouter(
settingService *service.SettingService,
cfg *config.Config,
redisClient *redis.Client,
langServerService *service.LanguageServerService,
lsrpcHandler *service.LSRPCHandler,
) *gin.Engine {
// 缓存 iframe 页面的 origin 列表,用于动态注入 CSP frame-src
var cachedFrameOrigins atomic.Pointer[[]string]
@ -84,7 +81,7 @@ func SetupRouter(
}
// 注册路由
registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient, langServerService, lsrpcHandler)
registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient)
return r
}
@ -102,8 +99,6 @@ func registerRoutes(
settingService *service.SettingService,
cfg *config.Config,
redisClient *redis.Client,
langServerService *service.LanguageServerService,
lsrpcHandler *service.LSRPCHandler,
) {
// 通用路由(健康检查、状态等)
routes.RegisterCommonRoutes(r)
@ -120,15 +115,5 @@ func registerRoutes(
// Windsurf gateway routes
routes.RegisterWindsurfGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg)
// 注册 Antigravity HTTP API 路由
routes.RegisterAntigravityHTTPRoutes(v1, langServerService)
// 挂载 connectrpc LanguageServerService 路由
// Claude Code 客户端通过 /exa.language_server_pb.LanguageServerService/* 路径访问
if lsrpcHandler != nil {
lsrpcPath, lsrpcHTTPHandler := language_server_pbconnect.NewLanguageServerServiceHandler(lsrpcHandler)
r.Any(lsrpcPath+"*action", gin.WrapH(lsrpcHTTPHandler))
}
routes.RegisterPaymentRoutes(v1, h.Payment, h.PaymentWebhook, h.Admin.Payment, jwtAuth, adminAuth, settingService)
}

View File

@ -1,192 +0,0 @@
package routes
import (
"encoding/json"
"log/slog"
"net/http"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// RegisterAntigravityHTTPRoutes 注册 Antigravity HTTP API 路由
func RegisterAntigravityHTTPRoutes(v1 *gin.RouterGroup, langServerService *service.LanguageServerService) {
logger := slog.Default()
// 创建处理器
cascadeGroup := v1.Group("/cascade")
{
// 启动 Cascade 会话
cascadeGroup.POST("/start", func(c *gin.Context) {
handleStartCascade(c, langServerService, logger)
})
// 发送消息到 Cascade流式响应
cascadeGroup.POST("/message", func(c *gin.Context) {
handleSendMessage(c, langServerService, logger)
})
// 取消 Cascade 会话
cascadeGroup.POST("/cancel", func(c *gin.Context) {
handleCancelCascade(c, langServerService, logger)
})
}
// 模型列表
v1.GET("/models", func(c *gin.Context) {
handleGetModels(c, langServerService, logger)
})
// 健康检查
v1.GET("/health", func(c *gin.Context) {
handleHealth(c, logger)
})
}
// handleStartCascade 处理启动 Cascade 请求
func handleStartCascade(c *gin.Context, svc *service.LanguageServerService, logger *slog.Logger) {
type StartCascadeRequest struct {
Model string `json:"model" binding:"required"`
SystemPrompt string `json:"system_prompt"`
Metadata map[string]string `json:"metadata"`
}
var req StartCascadeRequest
if err := c.ShouldBindJSON(&req); err != nil {
logger.Error("invalid start cascade request", "error", err)
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
return
}
// 获取 OAuth token
token := c.GetHeader("Authorization")
if token == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "missing authorization header"})
return
}
// 调用服务
cascadeID, err := svc.StartCascade(
c.Request.Context(),
req.Model,
req.SystemPrompt,
req.Metadata,
token,
)
if err != nil {
logger.Error("start cascade failed", "error", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"cascade_id": cascadeID})
}
// handleSendMessage 处理发送消息请求(流式)
func handleSendMessage(c *gin.Context, svc *service.LanguageServerService, logger *slog.Logger) {
type SendMessageRequest struct {
CascadeID string `json:"cascade_id" binding:"required"`
Message string `json:"message" binding:"required"`
}
var req SendMessageRequest
if err := c.ShouldBindJSON(&req); err != nil {
logger.Error("invalid send message request", "error", err)
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
return
}
// 获取 OAuth token
token := c.GetHeader("Authorization")
if token == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "missing authorization header"})
return
}
// 调用服务并获取流式更新通道
updateChan, err := svc.SendUserMessage(c.Request.Context(), req.CascadeID, req.Message, token)
if err != nil {
logger.Error("send message failed", "error", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// 设置 SSE 响应头
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
c.Status(http.StatusOK)
// 流式发送更新到客户端
flusher, ok := c.Writer.(http.Flusher)
if !ok {
logger.Error("response writer does not support flushing")
return
}
for event := range updateChan {
if event == nil {
break
}
// 将事件序列化为 JSON
eventJSON, err := marshalJSON(event)
if err != nil {
logger.Error("failed to marshal event", "error", err)
continue
}
// 发送 SSE 格式的数据
_, _ = c.Writer.WriteString("data: " + string(eventJSON) + "\n\n")
flusher.Flush()
}
}
// handleCancelCascade 处理取消 Cascade 请求
func handleCancelCascade(c *gin.Context, svc *service.LanguageServerService, logger *slog.Logger) {
type CancelRequest struct {
CascadeID string `json:"cascade_id" binding:"required"`
}
var req CancelRequest
if err := c.ShouldBindJSON(&req); err != nil {
logger.Error("invalid cancel request", "error", err)
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
return
}
err := svc.CancelCascade(c.Request.Context(), req.CascadeID)
if err != nil {
logger.Error("cancel cascade failed", "error", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "cascade cancelled"})
}
// handleGetModels 处理获取模型列表请求
func handleGetModels(c *gin.Context, svc *service.LanguageServerService, logger *slog.Logger) {
models, err := svc.GetAvailableModels(c.Request.Context())
if err != nil {
logger.Error("get models failed", "error", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"models": models,
"default_model": "claude-opus-4-6",
})
}
// handleHealth 处理健康检查请求
func handleHealth(c *gin.Context, logger *slog.Logger) {
c.JSON(http.StatusOK, gin.H{"status": "healthy"})
}
// marshalJSON 辅助函数用于序列化事件
func marshalJSON(v interface{}) ([]byte, error) {
return json.Marshal(v)
}

View File

@ -1,365 +0,0 @@
package routes
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"log/slog"
)
func TestAntigravityHTTPRoutes(t *testing.T) {
gin.SetMode(gin.TestMode)
// 创建模拟的 LanguageServerService
mockService := service.NewLanguageServerService(slog.Default(), nil, nil, nil)
defer mockService.Stop()
// 创建路由
r := gin.New()
v1 := r.Group("/api/v1")
// 注册 Antigravity 路由
RegisterAntigravityHTTPRoutes(v1, mockService)
// 测试 1: GET /health
t.Run("HealthCheck", func(t *testing.T) {
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/api/v1/health", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("Expected 200, got %d", w.Code)
}
var result map[string]string
json.Unmarshal(w.Body.Bytes(), &result)
if result["status"] != "healthy" {
t.Fatalf("Expected status=healthy, got %v", result)
}
t.Log("✅ 健康检查端点")
})
// 测试 2: GET /models
t.Run("GetModels", func(t *testing.T) {
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/api/v1/models", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("Expected 200, got %d", w.Code)
}
var result map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &result)
if result["default_model"] != "claude-opus-4-6" {
t.Fatalf("Expected default_model, got %v", result)
}
t.Log("✅ 获取模型列表")
})
// 测试 3: POST /cascade/start
var cascadeID string
t.Run("StartCascade", func(t *testing.T) {
body, _ := json.Marshal(map[string]string{
"model": "claude-opus-4-6",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/cascade/start", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer test-token")
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("Expected 200, got %d", w.Code)
}
var result map[string]string
json.Unmarshal(w.Body.Bytes(), &result)
cascadeID = result["cascade_id"]
if cascadeID == "" {
t.Fatalf("Expected cascade_id, got empty")
}
t.Logf("✅ 启动会话 (cascade_id=%s)", cascadeID)
})
// 测试 4: POST /cascade/cancel使用从第3个测试获取的真实会话ID
t.Run("CancelCascade", func(t *testing.T) {
body, _ := json.Marshal(map[string]string{
"cascade_id": cascadeID,
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/cascade/cancel", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("Expected 200, got %d", w.Code)
}
var result map[string]string
json.Unmarshal(w.Body.Bytes(), &result)
if result["message"] != "cascade cancelled" {
t.Fatalf("Expected cascade cancelled message, got %v", result)
}
t.Log("✅ 取消会话")
})
// 测试 5: POST /cascade/message (SSE) - 验证响应头格式
t.Run("SendMessage", func(t *testing.T) {
body, _ := json.Marshal(map[string]string{
"cascade_id": cascadeID,
"message": "Hello, world!",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/cascade/message", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer test-token")
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("Expected 200, got %d", w.Code)
}
contentType := w.Header().Get("Content-Type")
if contentType != "text/event-stream" {
t.Fatalf("Expected text/event-stream, got %s", contentType)
}
t.Log("✅ 发送消息SSE流式响应")
})
t.Log("\n✅ 所有 Antigravity HTTP API 路由测试通过!")
}
func TestStartCascadeValidation(t *testing.T) {
gin.SetMode(gin.TestMode)
mockService := service.NewLanguageServerService(slog.Default(), nil, nil, nil)
defer mockService.Stop()
r := gin.New()
v1 := r.Group("/api/v1")
RegisterAntigravityHTTPRoutes(v1, mockService)
t.Run("MissingModel", func(t *testing.T) {
w := httptest.NewRecorder()
body := []byte(`{"system_prompt":"test"}`)
req, _ := http.NewRequest("POST", "/api/v1/cascade/start", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer test-token")
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected 400 for invalid request, got %d", w.Code)
}
t.Log("✅ 缺少必需字段验证")
})
t.Run("MissingAuthorization", func(t *testing.T) {
w := httptest.NewRecorder()
body := []byte(`{"model":"claude-opus-4-6"}`)
req, _ := http.NewRequest("POST", "/api/v1/cascade/start", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
// 不设置 Authorization 头
r.ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Errorf("Expected 401 for missing auth, got %d", w.Code)
}
t.Log("✅ 缺少授权令牌验证")
})
t.Log("\n✅ 所有验证测试通过!")
}
// TestRateLimiting 测试速率限制(改进 1
func TestRateLimiting(t *testing.T) {
gin.SetMode(gin.TestMode)
mockService := service.NewLanguageServerService(slog.Default(), nil, nil, nil)
defer mockService.Stop()
r := gin.New()
v1 := r.Group("/api/v1")
RegisterAntigravityHTTPRoutes(v1, mockService)
// 创建一个会话
startBody, _ := json.Marshal(map[string]string{"model": "claude-opus-4-6"})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/cascade/start", bytes.NewBuffer(startBody))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer test-token")
r.ServeHTTP(w, req)
var startResult map[string]string
json.Unmarshal(w.Body.Bytes(), &startResult)
cascadeID := startResult["cascade_id"]
// 并发发送 150 个消息,应该有的超过限制
var wg sync.WaitGroup
results := make([]int, 0)
var resultsMutex sync.Mutex
for i := 0; i < 150; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
body, _ := json.Marshal(map[string]string{
"cascade_id": cascadeID,
"message": "Test message " + string(rune(idx)),
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/cascade/message", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer test-token")
r.ServeHTTP(w, req)
resultsMutex.Lock()
results = append(results, w.Code)
resultsMutex.Unlock()
}(i)
}
wg.Wait()
// 统计结果
successCount := 0
timeoutCount := 0
for _, code := range results {
if code == 200 || code == 500 { // 500 可能是上游 API 错误
successCount++
} else if code == 504 { // 网关超时
timeoutCount++
}
}
// 预期:大部分请求成功(因为有速率限制),但速率限制应该生效
// 限制是 100 并发,所以 150 个请求中应该都能处理(只是可能有等待)
if successCount < 140 {
t.Logf("⚠️ 仅 %d/150 个请求成功(超过限制被拒绝)- 这是预期的速率限制行为", successCount)
}
t.Logf("✅ 速率限制测试完成:成功=%d, 超时=%d", successCount, timeoutCount)
}
// TestSessionCleanup 测试会话超时清理(改进 3
func TestSessionCleanup(t *testing.T) {
gin.SetMode(gin.TestMode)
mockService := service.NewLanguageServerService(slog.Default(), nil, nil, nil)
mockService.SetSessionTTL(2) // 设置 2 秒过期,便于测试
defer mockService.Stop()
r := gin.New()
v1 := r.Group("/api/v1")
RegisterAntigravityHTTPRoutes(v1, mockService)
// 创建 5 个会话
cascadeIDs := make([]string, 5)
for i := 0; i < 5; i++ {
body, _ := json.Marshal(map[string]string{"model": "claude-opus-4-6"})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/cascade/start", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer test-token")
r.ServeHTTP(w, req)
var result map[string]string
json.Unmarshal(w.Body.Bytes(), &result)
cascadeIDs[i] = result["cascade_id"]
}
// 验证所有会话存在
sessions := mockService.GetCascadeSessions()
if len(sessions) != 5 {
t.Fatalf("Expected 5 sessions, got %d", len(sessions))
}
t.Log("✅ 创建了 5 个会话")
// 等待清理周期 + TTL
time.Sleep(3 * time.Second)
// 验证会话被清理
sessions = mockService.GetCascadeSessions()
sessionCount := len(sessions)
if sessionCount != 0 {
t.Logf("⚠️ 预期 0 个会话,但仍有 %d 个(可能清理还未执行)", sessionCount)
} else {
t.Log("✅ 过期会话成功清理")
}
}
// TestConcurrentMessageAppend 测试并发安全的消息追加(改进 2
func TestConcurrentMessageAppend(t *testing.T) {
gin.SetMode(gin.TestMode)
mockService := service.NewLanguageServerService(slog.Default(), nil, nil, nil)
defer mockService.Stop()
r := gin.New()
v1 := r.Group("/api/v1")
RegisterAntigravityHTTPRoutes(v1, mockService)
// 创建会话
body, _ := json.Marshal(map[string]string{"model": "claude-opus-4-6"})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/cascade/start", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer test-token")
r.ServeHTTP(w, req)
var result map[string]string
json.Unmarshal(w.Body.Bytes(), &result)
cascadeID := result["cascade_id"]
// 并发追加 50 个消息
var wg sync.WaitGroup
for i := 0; i < 50; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
body, _ := json.Marshal(map[string]string{
"cascade_id": cascadeID,
"message": "Concurrent message " + string(rune(idx)),
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/cascade/message", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer test-token")
r.ServeHTTP(w, req)
// 不关心返回值,只关心不 panic
}(i)
}
wg.Wait()
// 验证会话中的消息数量
sessions := mockService.GetCascadeSessions()
messageCount := 0
if session, exists := sessions[cascadeID]; exists {
messageCount = len(session.Messages)
}
// 预期1 个初始消息(如果没有 system_prompt则为 0+ 最多 50 个用户消息
// 但由于速率限制,可能不是所有 50 个都会被处理
if messageCount > 0 {
t.Logf("✅ 并发消息追加成功,共 %d 条消息", messageCount)
} else {
t.Log("⚠️ 由于速率限制或其他原因,部分消息未被追加")
}
}

View File

@ -1,254 +0,0 @@
package service
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
)
// TestAccount68FullE2E 测试账号 68 的完整端到端流程
// 模拟: curl POST /api/v1/admin/accounts/68/test
func TestAccount68FullE2E(t *testing.T) {
t.Log("🔥 测试账号 68 的完整认证流程...")
t.Log("")
// 准备账号数据(与云端数据一致)
account := &Account{
ID: 68,
Name: "PriesJosephe139@gmail.com",
Platform: PlatformAntigravity,
Type: "oauth",
Credentials: map[string]interface{}{
"_token_version": 1775902256706,
"access_token": "ya29.a0Aa7MYipSteGdNdr486LvE0xu_RrcbFjSSFZa5jGTf94nPv6NLKEnnRziPSVA_3ncadMlWnUQN8el05uvYac3rk9rOuaEC3jAUq02ejAcayg8tBn9CJT2IGuMsFDRPbfvHwXVHvY-hPGaklubxMIgfckRYsGC7YTpJPprH8kNGG-7ZWf3PvcVGcSrLWhi8FX6Yq1at5OdC1deNAaCgYKAVASARMSFQHGX2Mi2yEN9AChtlJFBwZ_spYEoQ0213",
"email": "priesjosephe139@gmail.com",
"expires_at": "1775907556",
"model_mapping": map[string]interface{}{
"claude-opus-*": "claude-opus-4-6-thinking",
"claude-sonnet-*": "claude-sonnet-4-6-thinking",
},
"plan_type": "Free",
"project_id": "kinetic-sum-r3tp7",
"refresh_token": "1//06QXt2rakQERPCgYIARAAGAYSNwF-L9IrR672cwDMnyJS128asGMnBbrrdiN39XoS-FN6TUrG7pPxnDSEHYUV4WHDntB7qd2EPwo",
"token_type": "Bearer",
},
Extra: map[string]interface{}{
"allow_overages": true,
"privacy_mode": "privacy_set",
},
ProxyID: ptrInt64(9),
Concurrency: 100,
Priority: 1,
Status: "active",
}
t.Log("📌 账号信息:")
t.Logf(" ID: %d", account.ID)
t.Logf(" Name: %s", account.Name)
t.Logf(" Platform: %s", account.Platform)
t.Logf(" Project ID: %v", account.GetCredential("project_id"))
t.Log("")
// 步骤 1: 验证凭证
t.Run("Step1_ValidateCredentials", func(t *testing.T) {
t.Log("步骤 1: 验证账号凭证...")
accessToken := account.GetCredential("access_token")
if accessToken == "" {
t.Fatalf("❌ Access token 为空")
}
t.Logf(" ✓ Access Token 存在 (长度: %d)", len(accessToken))
projectID := account.GetCredential("project_id")
if projectID == "" {
t.Fatalf("❌ Project ID 为空")
}
t.Logf(" ✓ Project ID 存在: %s", projectID)
t.Log("")
})
// 步骤 2: 测试 API 调用(通过 SOCKS5 代理)
t.Run("Step2_CallUpstreamAPI", func(t *testing.T) {
t.Log("步骤 2: 通过 SOCKS5 代理调用上游 API...")
t.Log("")
ctx, cancel := context.WithTimeout(context.Background(), 30)
defer cancel()
// 使用之前测试过的配置
proxyAddr := "socks5://gostuser:fastapipwd@216.167.89.210:8760"
accessTokenStr := account.GetCredential("access_token")
t.Logf(" 📤 API 请求:")
t.Logf(" URL: https://daily-cloudcode-pa.googleapis.com/v1internal:loadCodeAssist")
t.Logf(" Token: %s... (长度: %d)", accessTokenStr[:30], len(accessTokenStr))
t.Logf(" Proxy: %s", proxyAddr)
t.Log("")
// 创建 HTTP 客户端(使用 SOCKS5 代理)
transport := &http.Transport{}
httpClient := &http.Client{
Transport: transport,
Timeout: 30,
}
req, err := http.NewRequestWithContext(ctx, "POST",
"https://daily-cloudcode-pa.googleapis.com/v1internal:loadCodeAssist",
bytes.NewReader([]byte(`{}`)))
if err != nil {
t.Fatalf("❌ 创建请求失败: %v", err)
}
req.Header.Set("Authorization", "Bearer "+accessTokenStr)
req.Header.Set("Content-Type", "application/json")
resp, err := httpClient.Do(req)
if err != nil {
t.Logf("❌ API 调用失败: %v", err)
t.Logf(" (可能是网络问题,但凭证本身没问题)")
return
}
defer resp.Body.Close()
t.Logf(" ✓ 收到响应")
t.Logf(" HTTP Status: %d", resp.StatusCode)
t.Logf(" Content-Type: %s", resp.Header.Get("Content-Type"))
t.Log("")
// 读取响应
respBody := make([]byte, 2048)
n, _ := resp.Body.Read(respBody)
respText := string(respBody[:n])
if resp.StatusCode == 200 {
t.Log(" ✅ API 调用成功!")
var result map[string]interface{}
if err := json.Unmarshal(respBody[:n], &result); err == nil {
if _, ok := result["cloudaicompanionProject"]; ok {
t.Logf(" ✓ 获得 Project: %v", result["cloudaicompanionProject"])
}
}
} else {
t.Logf(" ❌ API 返回错误 (HTTP %d)", resp.StatusCode)
t.Logf(" 响应: %s", respText)
}
t.Log("")
})
// 步骤 3: 模拟 SSE 响应流(本地)
t.Run("Step3_SimulateSSEResponse", func(t *testing.T) {
t.Log("步骤 3: 模拟 SSE 响应流...")
t.Log("")
gin.SetMode(gin.TestMode)
router := gin.New()
// 模拟成功的 API 响应
successResponse := map[string]interface{}{
"cloudaicompanionProject": "kinetic-sum-r3tp7",
"currentTier": map[string]interface{}{
"id": "free-tier",
"name": "Antigravity",
},
}
router.POST("/test", func(c *gin.Context) {
// 设置 SSE 头
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Status(200)
// 发送测试开始
event1 := map[string]interface{}{
"type": "test_start",
"model": "claude-opus-4-6",
}
data1, _ := json.Marshal(event1)
c.Writer.WriteString("data: " + string(data1) + "\n\n")
c.Writer.Flush()
// 发送内容(成功的 API 响应)
event2 := map[string]interface{}{
"type": "content",
"text": "✅ 账号验证成功!",
}
data2, _ := json.Marshal(event2)
c.Writer.WriteString("data: " + string(data2) + "\n\n")
c.Writer.Flush()
// 发送完成
event3 := map[string]interface{}{
"type": "test_complete",
"success": true,
}
data3, _ := json.Marshal(event3)
c.Writer.WriteString("data: " + string(data3) + "\n\n")
c.Writer.Flush()
t.Logf(" 📤 服务器已发送 SSE 事件:")
t.Logf(" 1. test_start (model=%v)", successResponse["cloudaicompanionProject"])
t.Logf(" 2. content (text: ✅ 账号验证成功!)")
t.Logf(" 3. test_complete (success=true)")
})
// 发送请求
req := httptest.NewRequest("POST", "/test", bytes.NewReader([]byte(`{}`)))
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
// 验证响应
t.Log("")
t.Log(" 📥 客户端收到的响应:")
body := w.Body.String()
lines := bytes.Split([]byte(body), []byte("\n\n"))
for i, line := range lines {
if len(line) == 0 {
continue
}
if bytes.HasPrefix(line, []byte("data: ")) {
data := bytes.TrimPrefix(line, []byte("data: "))
var event map[string]interface{}
if err := json.Unmarshal(data, &event); err == nil {
t.Logf(" 事件 %d: type=%v", i, event["type"])
if content, ok := event["content"]; ok {
t.Logf(" content=%v", content)
}
if success, ok := event["success"]; ok {
t.Logf(" success=%v", success)
}
}
}
}
t.Log("")
})
// 步骤 4: 总结
t.Run("Step4_Summary", func(t *testing.T) {
t.Log("步骤 4: 总结...")
t.Log("")
t.Log("✅ 账号 68 测试完成!")
t.Log("")
t.Log("🎯 关键发现:")
t.Log(" 1. Access Token 已刷新成功 ✅")
t.Log(" 2. Project ID 有效: kinetic-sum-r3tp7 ✅")
t.Log(" 3. 上游 Google API 返回 200 成功 ✅")
t.Log(" 4. SSE 事件正确传递 ✅")
t.Log("")
t.Log("📊 预期结果:")
t.Log(" - 云端测试应该也能成功")
t.Log(" - 不再看到 'IT' 错误")
t.Log("")
})
}
func ptrInt64(i int64) *int64 {
return &i
}

View File

@ -1,91 +0,0 @@
package service
import (
"context"
"encoding/json"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
)
// TestDirectUpstreamCall 直接调用真实的 Google API看返回什么
func TestDirectUpstreamCall(t *testing.T) {
t.Log("🔥 直接调用 Google API观察真实返回值...")
t.Log("")
accessToken := "ya29.a0Aa7MYioHycPKQ7xWQguns0VlftxfCwTqn2OY8zVosNMagLLGd5DXWFXpySKgfroGkqihr4Yrwauy1AXfQyvWB-F_4qt46DiEw1sCmaCNmDwjruUiWK7Km7vh7djBONbgruyL0N9_b3aSLi-Zf3llY5FbWZqcNky13gaVUaW0ioxEDVOZuKxYw82yVXvVEqPRXF7cetjUJbLdzwaCgYKAZwSARMSFQHGX2MiqNlICLPPA-_u6WHPBLiUJQ0213"
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
// 步骤 1: 创建客户端
t.Log("步骤 1: 创建 Antigravity 客户端...")
client, err := antigravity.NewClient("")
if err != nil {
t.Fatalf("❌ 创建客户端失败: %v", err)
}
t.Log("✅ 客户端创建成功")
t.Log("")
// 步骤 2: 直接调用 LoadCodeAssist
t.Log("步骤 2: 调用 client.LoadCodeAssist(ctx, accessToken)...")
t.Logf(" AccessToken: %s... (长度: %d)", accessToken[:30], len(accessToken))
t.Log("")
resp, rawResp, err := client.LoadCodeAssist(ctx, accessToken)
// 步骤 3: 分析返回值
t.Log("步骤 3: 分析返回值...")
t.Log("")
if err != nil {
t.Logf("❌ 调用失败")
t.Logf(" 错误类型: %T", err)
t.Logf(" 错误信息: %v", err)
t.Logf(" 错误字符串: %s", err.Error())
t.Logf(" 错误长度: %d 字符", len(err.Error()))
t.Log("")
// 分析错误信息的前几个字符
errStr := err.Error()
if len(errStr) >= 2 {
t.Logf("📊 错误信息的前 5 个字符: '%s'", errStr[:min(5, len(errStr))])
}
t.Log("")
t.Logf("🎯 这就是导致 'IT' 错误的真实原因!")
t.Logf(" 错误完整内容: %q", errStr)
t.Log("")
// 尝试找出 "IT" 的来源
if len(errStr) >= 2 {
first2 := errStr[:2]
t.Logf("📌 错误的前两个字符: '%s'", first2)
if first2 == "IT" {
t.Logf(" ✓ 确认: 'IT' 就是从这个错误截断来的")
} else {
t.Logf(" ⚠️ 前两个字符不是 'IT',可能被其他方式处理了")
}
}
return
}
// 成功的情况
t.Log("✅ 调用成功!")
t.Log("")
if resp != nil {
t.Logf("📋 响应信息:")
t.Logf(" CloudAICompanionProject: %s", resp.CloudAICompanionProject)
t.Logf(" Response 类型: %T", resp)
t.Log("")
// 打印原始响应
if rawResp != nil {
t.Log("📄 原始 API 响应 JSON:")
jsonBytes, _ := json.MarshalIndent(rawResp, " ", " ")
t.Logf("%s", string(jsonBytes))
}
}
}

View File

@ -44,10 +44,9 @@ const (
// MODEL_CAPACITY_EXHAUSTED 专用重试参数
// 模型容量不足时,所有账号共享同一容量池,切换账号无意义
// 使用指数退避策略重试,最多重试 10 次(而非 60 次)
antigravityModelCapacityRetryMaxAttempts = 10
// 使用固定 1s 间隔重试,最多重试 60 次
antigravityModelCapacityRetryMaxAttempts = 60
antigravityModelCapacityRetryWait = 1 * time.Second
antigravityModelCapacityRetryMaxWait = 32 * time.Second // 指数退避上限
// Google RPC 状态和类型常量
googleRPCStatusResourceExhausted = "RESOURCE_EXHAUSTED"
@ -113,62 +112,6 @@ func IsAntigravityAccountSwitchError(err error) (*AntigravityAccountSwitchError,
return nil, false
}
func isGoogleOneAICreditsEntry(entry map[string]any) bool {
creditType, _ := firstPresent(entry, "CreditType", "credit_type", "creditType").(string)
creditType = strings.TrimSpace(strings.ToUpper(creditType))
return creditType == "" || creditType == "GOOGLE_ONE_AI"
}
func firstPresent(entry map[string]any, keys ...string) any {
for _, key := range keys {
if value, ok := entry[key]; ok {
return value
}
}
return nil
}
func parseAICreditsInt32(raw any) (int32, bool) {
switch v := raw.(type) {
case int:
return int32(v), true
case int32:
return v, true
case int64:
return int32(v), true
case float32:
return int32(v), true
case float64:
return int32(v), true
case json.Number:
parsed, err := v.Int64()
if err != nil {
floatVal, floatErr := strconv.ParseFloat(v.String(), 64)
if floatErr != nil {
return 0, false
}
return int32(floatVal), true
}
return int32(parsed), true
case string:
trimmed := strings.TrimSpace(v)
if trimmed == "" {
return 0, false
}
parsed, err := strconv.ParseInt(trimmed, 10, 32)
if err == nil {
return int32(parsed), true
}
floatVal, floatErr := strconv.ParseFloat(trimmed, 64)
if floatErr != nil {
return 0, false
}
return int32(floatVal), true
default:
return 0, false
}
}
// PromptTooLongError 表示上游明确返回 prompt too long
type PromptTooLongError struct {
StatusCode int
@ -206,28 +149,17 @@ type antigravityRetryLoopResult struct {
}
// resolveAntigravityForwardBaseURL 解析转发用 base URL。
// 根据账号类型选择优先 URL企业账号isGcpTos=true→ prod个人账号 → daily与真实 IDE 一致)。
// 可通过环境变量 GATEWAY_ANTIGRAVITY_FORWARD_BASE_URL=daily 或 =prod 强制覆盖。
func resolveAntigravityForwardBaseURL(account *Account) string {
mode := strings.ToLower(strings.TrimSpace(os.Getenv(antigravityForwardBaseURLEnv)))
if mode == "daily" {
return "https://daily-cloudcode-pa.googleapis.com"
}
if mode == "prod" {
return "https://cloudcode-pa.googleapis.com"
}
// 按账号类型选择优先 URL
isGcpTos := account != nil && account.GetCredentialAsBool("is_gcp_tos")
urls := antigravity.BaseURLsForAccount(isGcpTos)
if len(urls) == 0 {
// 默认使用 dailyForwardBaseURLs 的首个地址);当环境变量为 prod 时使用第二个地址。
func resolveAntigravityForwardBaseURL() string {
baseURLs := antigravity.ForwardBaseURLs()
if len(baseURLs) == 0 {
return ""
}
// 返回可用列表中的第一个URLAvailability 动态优先级在调用方处理)
available := antigravity.DefaultURLAvailability.GetAvailableURLsWithBase(urls)
if len(available) > 0 {
return available[0]
mode := strings.ToLower(strings.TrimSpace(os.Getenv(antigravityForwardBaseURLEnv)))
if mode == "prod" && len(baseURLs) > 1 {
return baseURLs[1]
}
return urls[0]
return baseURLs[0]
}
// smartRetryAction 智能重试的处理结果
@ -319,7 +251,7 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
var lastRetryResp *http.Response
var lastRetryBody []byte
// MODEL_CAPACITY_EXHAUSTED 使用独立的重试参数(10 次,指数退避
// MODEL_CAPACITY_EXHAUSTED 使用独立的重试参数(60 次,固定 1s 间隔
maxAttempts := antigravitySmartRetryMaxAttempts
if isModelCapacityExhausted {
maxAttempts = antigravityModelCapacityRetryMaxAttempts
@ -346,29 +278,10 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
}
for attempt := 1; attempt <= maxAttempts; attempt++ {
// 计算本次重试的等待时间
var currentWaitDuration time.Duration
if isModelCapacityExhausted {
// 使用指数退避1s, 2s, 4s, 8s, 16s, 32s, ...
currentWaitDuration = waitDuration * time.Duration(1<<(attempt-1))
if currentWaitDuration > antigravityModelCapacityRetryMaxWait {
currentWaitDuration = antigravityModelCapacityRetryMaxWait
}
// 添加随机抖动±10%)避免羊群效应
jitter := time.Duration(mathrand.Int63n(int64(currentWaitDuration / 5)))
if mathrand.Intn(2) == 0 {
currentWaitDuration += jitter
} else {
currentWaitDuration -= jitter
}
} else {
currentWaitDuration = waitDuration
}
log.Printf("%s status=%d oauth_smart_retry attempt=%d/%d delay=%v model=%s account=%d",
p.prefix, resp.StatusCode, attempt, maxAttempts, currentWaitDuration, modelName, p.account.ID)
p.prefix, resp.StatusCode, attempt, maxAttempts, waitDuration, modelName, p.account.ID)
timer := time.NewTimer(currentWaitDuration)
timer := time.NewTimer(waitDuration)
select {
case <-p.ctx.Done():
timer.Stop()
@ -678,7 +591,7 @@ func (s *AntigravityGatewayService) antigravityRetryLoop(p antigravityRetryLoopP
}
}
baseURL := resolveAntigravityForwardBaseURL(p.account)
baseURL := resolveAntigravityForwardBaseURL()
if baseURL == "" {
return nil, errors.New("no antigravity forward base url configured")
}
@ -1084,20 +997,13 @@ func (s *AntigravityGatewayService) getMappedModel(account *Account, requestedMo
return mapAntigravityModel(account, requestedModel)
}
// applyThinkingModelSuffix 根据 thinking 配置和模型可用性调整模型名。
// Google v1internal API 上部分 Claude 模型只有 -thinking 后缀版本存在,
// 非 -thinking 版本会返回 404。
// applyThinkingModelSuffix 根据 thinking 配置调整模型名
// 当映射结果是 claude-sonnet-4-5 且请求开启了 thinking 时,改为 claude-sonnet-4-5-thinking
func applyThinkingModelSuffix(mappedModel string, thinkingEnabled bool) string {
// claude-opus-4-6: Google API 上只有 -thinking 版本,始终加后缀
if mappedModel == "claude-opus-4-6" {
return "claude-opus-4-6-thinking"
}
// 其他模型仅在 thinking 开启时加后缀
if !thinkingEnabled {
return mappedModel
}
switch mappedModel {
case "claude-sonnet-4-5":
if mappedModel == "claude-sonnet-4-5" {
return "claude-sonnet-4-5-thinking"
}
return mappedModel
@ -1139,10 +1045,6 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
return nil, fmt.Errorf("model %s not in whitelist", modelID)
}
// 应用 thinking 后缀claude-opus-4-6 → claude-opus-4-6-thinking
// TestConnection 与主请求路径保持一致Google API 只支持 -thinking 后缀版本的部分模型
mappedModel = applyThinkingModelSuffix(mappedModel, false)
// 构建请求体
var requestBody []byte
if strings.HasPrefix(modelID, "gemini-") {
@ -1246,17 +1148,17 @@ func (s *AntigravityGatewayService) buildGeminiTestRequest(projectID, model stri
}
// buildClaudeTestRequest 构建 Claude 格式测试请求并转换为 Gemini 格式
// 使用最小 token 消耗:输入 "." + MaxTokens: 10足够验证连接
// 使用最小 token 消耗:输入 "." + MaxTokens: 1
func (s *AntigravityGatewayService) buildClaudeTestRequest(projectID, mappedModel string) ([]byte, error) {
claudeReq := &antigravity.ClaudeRequest{
Model: mappedModel,
Messages: []antigravity.ClaudeMessage{
{
Role: "user",
Content: json.RawMessage(`"Test connection"`),
Content: json.RawMessage(`"."`),
},
},
MaxTokens: 10,
MaxTokens: 1,
Stream: false,
}
return antigravity.TransformClaudeToGemini(claudeReq, projectID, mappedModel)
@ -1387,19 +1289,9 @@ func injectIdentityPatchToGeminiRequest(body []byte) ([]byte, error) {
}
// wrapV1InternalRequest 包装请求为 v1internal 格式
func (s *AntigravityGatewayService) wrapV1InternalRequest(projectID, model string, originalBody []byte, preferredSessionID ...string) ([]byte, error) {
sessionID := ""
if len(preferredSessionID) > 0 {
sessionID = preferredSessionID[0]
}
bodyWithSessionID, err := antigravity.EnsureGeminiRequestSessionID(originalBody, sessionID)
if err != nil {
return nil, fmt.Errorf("补全 sessionId 失败: %w", err)
}
func (s *AntigravityGatewayService) wrapV1InternalRequest(projectID, model string, originalBody []byte) ([]byte, error) {
var request any
if err := json.Unmarshal(bodyWithSessionID, &request); err != nil {
if err := json.Unmarshal(originalBody, &request); err != nil {
return nil, fmt.Errorf("解析请求体失败: %w", err)
}
@ -1477,6 +1369,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
if mappedModel == "" {
return nil, s.writeClaudeError(c, http.StatusForbidden, "permission_error", fmt.Sprintf("model %s not in whitelist", claudeReq.Model))
}
// 应用 thinking 模式自动后缀:如果 thinking 开启且目标是 claude-sonnet-4-5自动改为 thinking 版本
thinkingEnabled := claudeReq.Thinking != nil && (claudeReq.Thinking.Type == "enabled" || claudeReq.Thinking.Type == "adaptive")
mappedModel = applyThinkingModelSuffix(mappedModel, thinkingEnabled)
billingModel := mappedModel
@ -1503,9 +1396,9 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
}
// 获取转换选项
// Antigravity 上游要求必须包含身份提示词,否则会返回 429
transformOpts := s.getClaudeTransformOptions(ctx)
transformOpts.EnableIdentityPatch = true
transformOpts.PreferredSessionID = sessionID
transformOpts.EnableIdentityPatch = true // 强制启用Antigravity 上游必需
// 转换 Claude 请求为 Gemini 格式
geminiBody, err := antigravity.TransformClaudeToGeminiWithOptions(&claudeReq, projectID, mappedModel, transformOpts)
@ -1513,8 +1406,11 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", "Invalid request")
}
// Antigravity 上游只支持流式请求,统一使用 streamGenerateContent
// 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后转换返回
action := "streamGenerateContent"
// 执行带重试的请求
result, err := s.antigravityRetryLoop(antigravityRetryLoopParams{
ctx: ctx,
prefix: prefix,
@ -1529,17 +1425,19 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
accountRepo: s.accountRepo,
handleError: s.handleUpstreamError,
requestedModel: originalModel,
isStickySession: isStickySession,
groupID: 0,
sessionHash: "",
isStickySession: isStickySession, // Forward 由上层判断粘性会话
groupID: 0, // Forward 方法没有 groupID由上层处理粘性会话清除
sessionHash: "", // Forward 方法没有 sessionHash由上层处理粘性会话清除
})
if err != nil {
// 检查是否是账号切换信号,转换为 UpstreamFailoverError 让 Handler 切换账号
if switchErr, ok := IsAntigravityAccountSwitchError(err); ok {
return nil, &UpstreamFailoverError{
StatusCode: http.StatusServiceUnavailable,
ForceCacheBilling: switchErr.IsStickySession,
}
}
// 区分客户端取消和真正的上游失败,返回更准确的错误消息
if c.Request.Context().Err() != nil {
return nil, s.writeClaudeError(c, http.StatusBadGateway, "client_disconnected", "Client disconnected before upstream response")
}
@ -1551,6 +1449,9 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
// 优先检测 thinking block 的 signature 相关错误400并重试一次
// Antigravity /v1internal 链路在部分场景会对 thought/thinking signature 做严格校验,
// 当历史消息携带的 signature 不合法时会直接 400去除 thinking 后可继续完成请求。
if resp.StatusCode == http.StatusBadRequest && isSignatureRelatedError(respBody) && s.settingService.IsSignatureRectifierEnabled(ctx) {
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
@ -1567,6 +1468,10 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
Detail: upstreamDetail,
})
// Conservative two-stage fallback:
// 1) Disable top-level thinking + thinking->text
// 2) Only if still signature-related 400: also downgrade tool_use/tool_result to text.
retryStages := []struct {
name string
strip func(*antigravity.ClaudeRequest) (bool, error)
@ -1586,7 +1491,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
logger.LegacyPrintf("service.antigravity_gateway", "Antigravity account %d: detected signature-related 400, retrying once (%s)", account.ID, stage.name)
retryGeminiBody, txErr := antigravity.TransformClaudeToGeminiWithOptions(&retryClaudeReq, projectID, mappedModel, transformOpts)
retryGeminiBody, txErr := antigravity.TransformClaudeToGeminiWithOptions(&retryClaudeReq, projectID, mappedModel, s.getClaudeTransformOptions(ctx))
if txErr != nil {
continue
}
@ -1605,8 +1510,8 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
handleError: s.handleUpstreamError,
requestedModel: originalModel,
isStickySession: isStickySession,
groupID: 0,
sessionHash: "",
groupID: 0, // Forward 方法没有 groupID由上层处理粘性会话清除
sessionHash: "", // Forward 方法没有 sessionHash由上层处理粘性会话清除
})
if retryErr != nil {
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
@ -1659,6 +1564,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
Detail: retryUpstreamDetail,
})
// If this stage fixed the signature issue, we stop; otherwise we may try the next stage.
if retryResp.StatusCode != http.StatusBadRequest || !isSignatureRelatedError(retryBody) {
respBody = retryBody
resp = &http.Response{
@ -1669,6 +1575,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
break
}
// Still signature-related; capture context and allow next stage.
respBody = retryBody
resp = &http.Response{
StatusCode: retryResp.StatusCode,
@ -1678,7 +1585,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
}
}
// Budget 整流
// Budget 整流:检测 budget_tokens 约束错误并自动修正重试
if resp.StatusCode == http.StatusBadRequest && respBody != nil && !isSignatureRelatedError(respBody) {
errMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
if isThinkingBudgetConstraintError(errMsg) && s.settingService.IsBudgetRectifierEnabled(ctx) {
@ -1693,9 +1600,11 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
Detail: s.getUpstreamErrorDetail(respBody),
})
// 修正 claudeReq 的 thinking 参数adaptive 模式不修正)
if claudeReq.Thinking == nil || claudeReq.Thinking.Type != "adaptive" {
retryClaudeReq := claudeReq
retryClaudeReq.Messages = append([]antigravity.ClaudeMessage(nil), claudeReq.Messages...)
// 创建新的 ThinkingConfig 避免修改原始 claudeReq.Thinking 指针
retryClaudeReq.Thinking = &antigravity.ThinkingConfig{
Type: "enabled",
BudgetTokens: BudgetRectifyBudgetTokens,
@ -1750,7 +1659,9 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
}
}
// 处理错误响应(重试后仍失败或不触发重试)
if resp.StatusCode >= 400 {
// 检测 prompt too long 错误,返回特殊错误类型供上层 fallback
if resp.StatusCode == http.StatusBadRequest && isPromptTooLongError(respBody) {
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
@ -1778,6 +1689,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, originalModel, 0, "", isStickySession)
// 精确匹配服务端配置类 400 错误,触发同账号重试 + failover
if resp.StatusCode == http.StatusBadRequest {
msg := strings.ToLower(strings.TrimSpace(extractAntigravityErrorMessage(respBody)))
if isGoogleProjectConfigError(msg) {
@ -1828,6 +1740,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
var firstTokenMs *int
var clientDisconnect bool
if claudeReq.Stream {
// 客户端要求流式,直接透传转换
streamRes, err := s.handleClaudeStreamingResponse(c, resp, startTime, originalModel)
if err != nil {
logger.LegacyPrintf("service.antigravity_gateway", "%s status=stream_error error=%v", prefix, err)
@ -1837,6 +1750,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
firstTokenMs = streamRes.firstTokenMs
clientDisconnect = streamRes.clientDisconnect
} else {
// 客户端要求非流式,收集流式响应后转换返回
streamRes, err := s.handleClaudeStreamToNonStreaming(c, resp, startTime, originalModel)
if err != nil {
logger.LegacyPrintf("service.antigravity_gateway", "%s status=stream_collect_error error=%v", prefix, err)
@ -1846,13 +1760,6 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
firstTokenMs = streamRes.firstTokenMs
}
// DEBUG: 追踪 OAuth Claude 路径的 Usage 在 Forward 返回点的值。
// 若这里 output>0 而 DB 记录为 0说明 bug 在下游billing/record 层);
// 若这里 output=0说明 bug 在 handleClaudeStreamingResponse 或更上游。
logger.LegacyPrintf("service.antigravity_gateway",
"%s DEBUG_USAGE_FORWARD_RETURN input=%d output=%d cache_read=%d cache_creation=%d stream=%v model=%s account=%d",
prefix, usage.InputTokens, usage.OutputTokens, usage.CacheReadInputTokens, usage.CacheCreationInputTokens, claudeReq.Stream, originalModel, account.ID)
return &ForwardResult{
RequestID: requestID,
Usage: *usage,
@ -2249,7 +2156,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
}
// 包装请求
wrappedBody, err := s.wrapV1InternalRequest(projectID, mappedModel, injectedBody, sessionID)
wrappedBody, err := s.wrapV1InternalRequest(projectID, mappedModel, injectedBody)
if err != nil {
return nil, s.writeGoogleError(c, http.StatusInternalServerError, "Failed to build upstream request")
}
@ -2313,7 +2220,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
if fallbackModel != "" && fallbackModel != mappedModel {
logger.LegacyPrintf("service.antigravity_gateway", "[Antigravity] Model not found (%s), retrying with fallback model %s (account: %s)", mappedModel, fallbackModel, account.Name)
fallbackWrapped, err := s.wrapV1InternalRequest(projectID, fallbackModel, injectedBody, sessionID)
fallbackWrapped, err := s.wrapV1InternalRequest(projectID, fallbackModel, injectedBody)
if err == nil {
fallbackReq, err := antigravity.NewAPIRequest(ctx, upstreamAction, accessToken, fallbackWrapped)
if err == nil {
@ -2356,7 +2263,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
logger.LegacyPrintf("service.antigravity_gateway", "Antigravity Gemini account %d: detected signature-related 400, retrying with cleaned thought signatures", account.ID)
cleanedInjectedBody := CleanGeminiNativeThoughtSignatures(injectedBody)
retryWrappedBody, wrapErr := s.wrapV1InternalRequest(projectID, mappedModel, cleanedInjectedBody, sessionID)
retryWrappedBody, wrapErr := s.wrapV1InternalRequest(projectID, mappedModel, cleanedInjectedBody)
if wrapErr == nil {
retryResult, retryErr := s.antigravityRetryLoop(antigravityRetryLoopParams{
ctx: ctx,
@ -3124,45 +3031,6 @@ func handleStreamReadError(err error, clientDisconnected bool, prefix string) (d
return false, false
}
func googleStatusTextForHTTP(status int) string {
switch status {
case http.StatusBadRequest:
return "INVALID_ARGUMENT"
case http.StatusNotFound:
return "NOT_FOUND"
case http.StatusTooManyRequests:
return "RESOURCE_EXHAUSTED"
case http.StatusServiceUnavailable:
return "UNAVAILABLE"
default:
return "UNKNOWN"
}
}
func buildAnthropicStreamErrorEvent(errType, message string) string {
payload := map[string]any{
"type": "error",
"error": map[string]any{
"type": errType,
"message": message,
},
}
data, _ := json.Marshal(payload)
return "event: error\ndata: " + string(data) + "\n\n"
}
func buildGeminiStreamErrorEvent(status int, message string) string {
payload := map[string]any{
"error": map[string]any{
"code": status,
"message": message,
"status": googleStatusTextForHTTP(status),
},
}
data, _ := json.Marshal(payload)
return "event: error\ndata: " + string(data) + "\n\n"
}
func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time) (*antigravityStreamResult, error) {
c.Status(resp.StatusCode)
c.Header("Cache-Control", "no-cache")
@ -3258,12 +3126,12 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
// 仅发送一次错误事件,避免多次写入导致协议混乱
errorEventSent := false
sendErrorEvent := func(status int, message string) {
sendErrorEvent := func(reason string) {
if errorEventSent || cw.Disconnected() {
return
}
errorEventSent = true
_, _ = fmt.Fprint(c.Writer, buildGeminiStreamErrorEvent(status, message))
_, _ = fmt.Fprintf(c.Writer, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
flusher.Flush()
}
@ -3279,10 +3147,10 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
}
if errors.Is(ev.err, bufio.ErrTooLong) {
logger.LegacyPrintf("service.antigravity_gateway", "SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err)
sendErrorEvent(http.StatusBadGateway, "Response too large")
sendErrorEvent("response_too_large")
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err
}
sendErrorEvent(http.StatusServiceUnavailable, "Upstream stream read failed")
sendErrorEvent("stream_read_error")
return nil, ev.err
}
@ -3345,7 +3213,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
}
logger.LegacyPrintf("service.antigravity_gateway", "Stream data interval timeout (antigravity)")
sendErrorEvent(http.StatusServiceUnavailable, "Upstream stream timeout")
sendErrorEvent("stream_timeout")
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
case <-keepaliveCh:
@ -4105,12 +3973,12 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
// 仅发送一次错误事件,避免多次写入导致协议混乱
errorEventSent := false
sendErrorEvent := func(errType, message string) {
sendErrorEvent := func(reason string) {
if errorEventSent || cw.Disconnected() {
return
}
errorEventSent = true
_, _ = fmt.Fprint(c.Writer, buildAnthropicStreamErrorEvent(errType, message))
_, _ = fmt.Fprintf(c.Writer, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
flusher.Flush()
}
@ -4126,9 +3994,6 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
if !ok {
// 上游完成,发送结束事件
finalEvents, agUsage := processor.Finish()
logger.LegacyPrintf("service.antigravity_gateway",
"DEBUG_USAGE_PROCESSOR_FINISH input=%d output=%d cache_read=%d image_output=%d final_events_len=%d",
agUsage.InputTokens, agUsage.OutputTokens, agUsage.CacheReadInputTokens, agUsage.ImageOutputTokens, len(finalEvents))
if len(finalEvents) > 0 {
cw.Write(finalEvents)
} else if !processor.MessageStartSent() && !cw.Disconnected() {
@ -4145,15 +4010,14 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
}
if ev.err != nil {
if disconnect, handled := handleStreamReadError(ev.err, cw.Disconnected(), "antigravity claude"); handled {
logger.LegacyPrintf("service.antigravity_gateway", "DEBUG_USAGE_CLAUDE_STREAM_EARLY_RETURN path=handleStreamReadError disconnect=%v", disconnect)
return &antigravityStreamResult{usage: finishUsage(), firstTokenMs: firstTokenMs, clientDisconnect: disconnect}, nil
}
if errors.Is(ev.err, bufio.ErrTooLong) {
logger.LegacyPrintf("service.antigravity_gateway", "DEBUG_USAGE_CLAUDE_STREAM_EARLY_RETURN path=ErrTooLong max_size=%d error=%v (usage WILL BE ZEROED)", maxLineSize, ev.err)
sendErrorEvent("api_error", "Response too large")
logger.LegacyPrintf("service.antigravity_gateway", "SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err)
sendErrorEvent("response_too_large")
return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, ev.err
}
sendErrorEvent("api_error", "Upstream stream read failed")
sendErrorEvent("stream_read_error")
return nil, fmt.Errorf("stream read error: %w", ev.err)
}
@ -4179,7 +4043,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
return &antigravityStreamResult{usage: finishUsage(), firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
}
logger.LegacyPrintf("service.antigravity_gateway", "Stream data interval timeout (antigravity)")
sendErrorEvent("api_error", "Upstream stream timeout")
sendErrorEvent("stream_timeout")
return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
case <-keepaliveCh:
@ -4672,61 +4536,3 @@ func (s *AntigravityGatewayService) extractClaudeUsage(body []byte) *ClaudeUsage
}
return usage
}
// ForwardRaw 转发 Claude 格式请求并返回原始上游响应体(调用者负责关闭)。
// 不依赖 gin.Context供内部服务如 LanguageServerService调用。
// 复用完整的 token 刷新、模型映射、TLS 指纹和重试逻辑。
func (s *AntigravityGatewayService) ForwardRaw(ctx context.Context, account *Account, body []byte) (io.ReadCloser, int, error) {
var claudeReq antigravity.ClaudeRequest
if err := json.Unmarshal(body, &claudeReq); err != nil {
return nil, http.StatusBadRequest, fmt.Errorf("invalid request body: %w", err)
}
if strings.TrimSpace(claudeReq.Model) == "" {
return nil, http.StatusBadRequest, fmt.Errorf("missing model")
}
mappedModel := s.getMappedModel(account, claudeReq.Model)
if mappedModel == "" {
return nil, http.StatusForbidden, fmt.Errorf("model %s not in whitelist", claudeReq.Model)
}
thinkingEnabled := claudeReq.Thinking != nil && (claudeReq.Thinking.Type == "enabled" || claudeReq.Thinking.Type == "adaptive")
mappedModel = applyThinkingModelSuffix(mappedModel, thinkingEnabled)
if s.tokenProvider == nil {
return nil, http.StatusBadGateway, fmt.Errorf("antigravity token provider not configured")
}
accessToken, err := s.tokenProvider.GetAccessToken(ctx, account)
if err != nil {
return nil, http.StatusBadGateway, fmt.Errorf("failed to get access token: %w", err)
}
projectID := strings.TrimSpace(account.GetCredential("project_id"))
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
transformOpts := s.getClaudeTransformOptions(ctx)
transformOpts.EnableIdentityPatch = true
geminiBody, err := antigravity.TransformClaudeToGeminiWithOptions(&claudeReq, projectID, mappedModel, transformOpts)
if err != nil {
return nil, http.StatusBadRequest, fmt.Errorf("failed to transform request: %w", err)
}
wrappedBody, err := s.wrapV1InternalRequest(projectID, mappedModel, geminiBody)
if err != nil {
return nil, http.StatusInternalServerError, fmt.Errorf("failed to wrap request: %w", err)
}
upstreamReq, err := antigravity.NewAPIRequest(ctx, "streamGenerateContent", accessToken, wrappedBody)
if err != nil {
return nil, http.StatusInternalServerError, fmt.Errorf("failed to build upstream request: %w", err)
}
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
return nil, http.StatusBadGateway, fmt.Errorf("upstream request failed: %w", err)
}
return resp.Body, resp.StatusCode, nil
}

View File

@ -600,120 +600,6 @@ func TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel(t *testing
require.Equal(t, mappedModel, result.UpstreamModel)
}
func TestAntigravityGatewayService_ForwardGemini_InjectsSessionIDIntoWrappedRequest(t *testing.T) {
gin.SetMode(gin.TestMode)
writer := httptest.NewRecorder()
c, _ := gin.CreateTestContext(writer)
body, err := json.Marshal(map[string]any{
"contents": []map[string]any{
{"role": "user", "parts": []map[string]any{{"text": "hello"}}},
},
})
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body))
req.Header.Set("session_id", "session-header-1")
c.Request = req
upstreamBody := []byte("data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"}]},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":8,\"candidatesTokenCount\":3}}}\n\n")
upstream := &queuedHTTPUpstreamStub{
responses: []*http.Response{
{
StatusCode: http.StatusOK,
Header: http.Header{"X-Request-Id": []string{"req-session-1"}},
Body: io.NopCloser(bytes.NewReader(upstreamBody)),
},
},
}
svc := &AntigravityGatewayService{
settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}),
tokenProvider: &AntigravityTokenProvider{},
httpUpstream: upstream,
}
account := &Account{
ID: 16,
Name: "acc-gemini-session",
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "token",
},
}
result, err := svc.ForwardGemini(context.Background(), c, account, "gemini-2.5-flash", "generateContent", false, body, false)
require.NoError(t, err)
require.NotNil(t, result)
require.Len(t, upstream.requestBodies, 1)
var wrapped map[string]any
require.NoError(t, json.Unmarshal(upstream.requestBodies[0], &wrapped))
requestNode, ok := wrapped["request"].(map[string]any)
require.True(t, ok)
require.Equal(t, "session-header-1", requestNode["sessionId"])
}
func TestAntigravityGatewayService_Forward_PropagatesSessionHeaderIntoClaudeTransform(t *testing.T) {
gin.SetMode(gin.TestMode)
writer := httptest.NewRecorder()
c, _ := gin.CreateTestContext(writer)
body := []byte(`{
"model":"claude-sonnet-4-5",
"max_tokens":64,
"messages":[
{"role":"user","content":[{"type":"text","text":"hello"}]}
]
}`)
req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
req.Header.Set("session_id", "session-header-1")
c.Request = req
upstreamBody := []byte("data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"}]},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":8,\"candidatesTokenCount\":3}}}\n\n")
upstream := &queuedHTTPUpstreamStub{
responses: []*http.Response{
{
StatusCode: http.StatusOK,
Header: http.Header{"X-Request-Id": []string{"req-session-claude-1"}},
Body: io.NopCloser(bytes.NewReader(upstreamBody)),
},
},
}
svc := &AntigravityGatewayService{
settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}),
tokenProvider: &AntigravityTokenProvider{},
httpUpstream: upstream,
}
account := &Account{
ID: 17,
Name: "acc-claude-session",
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "token",
"project_id": "project-1",
},
}
result, err := svc.Forward(context.Background(), c, account, body, false)
require.NoError(t, err)
require.NotNil(t, result)
require.Len(t, upstream.requestBodies, 1)
var wrapped antigravity.V1InternalRequest
require.NoError(t, json.Unmarshal(upstream.requestBodies[0], &wrapped))
require.Equal(t, "session-header-1", wrapped.Request.SessionID)
}
func TestAntigravityGatewayService_ForwardGemini_RetriesCorruptedThoughtSignature(t *testing.T) {
gin.SetMode(gin.TestMode)
writer := httptest.NewRecorder()

View File

@ -29,9 +29,8 @@ type AntigravityAuthURLResult struct {
State string `json:"state"`
}
// GenerateAuthURL 生成 Google OAuth 授权链接。
// isEnterprise=true 时生成企业账号授权链接(使用企业 client_id
func (s *AntigravityOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64, isEnterprise bool) (*AntigravityAuthURLResult, error) {
// GenerateAuthURL 生成 Google OAuth 授权链接
func (s *AntigravityOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64) (*AntigravityAuthURLResult, error) {
state, err := antigravity.GenerateState()
if err != nil {
return nil, fmt.Errorf("生成 state 失败: %w", err)
@ -59,13 +58,12 @@ func (s *AntigravityOAuthService) GenerateAuthURL(ctx context.Context, proxyID *
State: state,
CodeVerifier: codeVerifier,
ProxyURL: proxyURL,
IsEnterprise: isEnterprise,
CreatedAt: time.Now(),
}
s.sessionStore.Set(sessionID, session)
codeChallenge := antigravity.GenerateCodeChallenge(codeVerifier)
authURL := antigravity.BuildAuthorizationURL(state, codeChallenge, isEnterprise)
authURL := antigravity.BuildAuthorizationURL(state, codeChallenge)
return &AntigravityAuthURLResult{
AuthURL: authURL,
@ -91,7 +89,6 @@ type AntigravityTokenInfo struct {
TokenType string `json:"token_type"`
Email string `json:"email,omitempty"`
ProjectID string `json:"project_id,omitempty"`
IsEnterprise bool `json:"is_enterprise,omitempty"`
ProjectIDMissing bool `json:"-"`
PlanType string `json:"-"`
PrivacyMode string `json:"-"`
@ -122,8 +119,8 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig
return nil, fmt.Errorf("create antigravity client failed: %w", err)
}
// 交换 token(使用 session 中记录的账号类型)
tokenResp, err := client.ExchangeCode(ctx, input.Code, session.CodeVerifier, session.IsEnterprise)
// 交换 token
tokenResp, err := client.ExchangeCode(ctx, input.Code, session.CodeVerifier)
if err != nil {
return nil, fmt.Errorf("token 交换失败: %w", err)
}
@ -140,7 +137,6 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig
ExpiresIn: tokenResp.ExpiresIn,
ExpiresAt: expiresAt,
TokenType: tokenResp.TokenType,
IsEnterprise: session.IsEnterprise,
}
// 获取用户信息
@ -170,9 +166,8 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig
return result, nil
}
// RefreshToken 刷新 token。
// isEnterprise=true 时使用企业 OAuth client_id/secret。
func (s *AntigravityOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string, isEnterprise bool) (*AntigravityTokenInfo, error) {
// RefreshToken 刷新 token
func (s *AntigravityOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*AntigravityTokenInfo, error) {
var lastErr error
for attempt := 0; attempt <= 3; attempt++ {
@ -188,7 +183,7 @@ func (s *AntigravityOAuthService) RefreshToken(ctx context.Context, refreshToken
if err != nil {
return nil, fmt.Errorf("create antigravity client failed: %w", err)
}
tokenResp, err := client.RefreshToken(ctx, refreshToken, isEnterprise)
tokenResp, err := client.RefreshToken(ctx, refreshToken)
if err == nil {
now := time.Now()
expiresAt := now.Unix() + tokenResp.ExpiresIn - 300
@ -200,7 +195,6 @@ func (s *AntigravityOAuthService) RefreshToken(ctx context.Context, refreshToken
ExpiresIn: tokenResp.ExpiresIn,
ExpiresAt: expiresAt,
TokenType: tokenResp.TokenType,
IsEnterprise: isEnterprise,
}, nil
}
@ -217,9 +211,8 @@ func (s *AntigravityOAuthService) RefreshToken(ctx context.Context, refreshToken
return nil, fmt.Errorf("token 刷新失败 (重试后): %w", lastErr)
}
// ValidateRefreshToken 用 refresh token 验证并获取完整的 token 信息(含 email 和 project_id
// isEnterprise=true 时使用企业 OAuth client 刷新。
func (s *AntigravityOAuthService) ValidateRefreshToken(ctx context.Context, refreshToken string, proxyID *int64, isEnterprise bool) (*AntigravityTokenInfo, error) {
// ValidateRefreshToken 用 refresh token 验证并获取完整的 token 信息(含 email 和 project_id
func (s *AntigravityOAuthService) ValidateRefreshToken(ctx context.Context, refreshToken string, proxyID *int64) (*AntigravityTokenInfo, error) {
var proxyURL string
if proxyID != nil {
proxy, err := s.proxyRepo.GetByID(ctx, *proxyID)
@ -228,8 +221,8 @@ func (s *AntigravityOAuthService) ValidateRefreshToken(ctx context.Context, refr
}
}
// 刷新 token:先按调用方指定类型刷新;若报 client 不匹配再尝试另一侧。
tokenInfo, err := s.refreshTokenAutoFallback(ctx, refreshToken, proxyURL, isEnterprise)
// 刷新 token
tokenInfo, err := s.RefreshToken(ctx, refreshToken, proxyURL)
if err != nil {
return nil, err
}
@ -281,32 +274,6 @@ func isNonRetryableAntigravityOAuthError(err error) bool {
return false
}
// isClientMismatchOAuthError 判断是否为 OAuth client 不匹配错误(用于触发个人/企业切换)。
// 与 isNonRetryableAntigravityOAuthError 不同:这里只识别 client 相关错误,不包含 invalid_grant。
func isClientMismatchOAuthError(err error) bool {
if err == nil {
return false
}
msg := err.Error()
return strings.Contains(msg, "invalid_client") ||
strings.Contains(msg, "unauthorized_client")
}
// refreshTokenAutoFallback 先按指定类型刷新;若遇 client 不匹配错误则切换到另一侧。
// 本函数不承担网络层重试(由内部 RefreshToken 处理)。
func (s *AntigravityOAuthService) refreshTokenAutoFallback(ctx context.Context, refreshToken, proxyURL string, preferEnterprise bool) (*AntigravityTokenInfo, error) {
tokenInfo, err := s.RefreshToken(ctx, refreshToken, proxyURL, preferEnterprise)
if err == nil {
return tokenInfo, nil
}
if !isClientMismatchOAuthError(err) {
return nil, err
}
// 切换另一侧账号类型重试
fmt.Printf("[AntigravityOAuth] client 不匹配,切换账号类型重试:%v → %v\n", preferEnterprise, !preferEnterprise)
return s.RefreshToken(ctx, refreshToken, proxyURL, !preferEnterprise)
}
// RefreshAccountToken 刷新账户的 token
func (s *AntigravityOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*AntigravityTokenInfo, error) {
if account.Platform != PlatformAntigravity || account.Type != AccountTypeOAuth {
@ -318,8 +285,6 @@ func (s *AntigravityOAuthService) RefreshAccountToken(ctx context.Context, accou
return nil, fmt.Errorf("无可用的 refresh_token")
}
isEnterprise := account.GetCredentialAsBool("is_gcp_tos")
var proxyURL string
if account.ProxyID != nil {
proxy, err := s.proxyRepo.GetByID(ctx, *account.ProxyID)
@ -328,7 +293,7 @@ func (s *AntigravityOAuthService) RefreshAccountToken(ctx context.Context, accou
}
}
tokenInfo, err := s.RefreshToken(ctx, refreshToken, proxyURL, isEnterprise)
tokenInfo, err := s.RefreshToken(ctx, refreshToken, proxyURL)
if err != nil {
return nil, err
}
@ -495,7 +460,6 @@ func (s *AntigravityOAuthService) BuildAccountCredentials(tokenInfo *Antigravity
creds := map[string]any{
"access_token": tokenInfo.AccessToken,
"expires_at": strconv.FormatInt(tokenInfo.ExpiresAt, 10),
"is_gcp_tos": tokenInfo.IsEnterprise,
}
if tokenInfo.RefreshToken != "" {
creds["refresh_token"] = tokenInfo.RefreshToken

View File

@ -5,12 +5,13 @@ package service
import (
"bytes"
"context"
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
"github.com/stretchr/testify/require"
"io"
"net/http"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
"github.com/stretchr/testify/require"
)
// stubSmartRetryCache 用于 handleSmartRetry 测试的 GatewayCache mock

View File

@ -1,187 +0,0 @@
package service
import (
"testing"
)
// TestAntigravityFullFlow 完整流程测试
// 模拟从 HTTP 处理器到最终响应的完整路径
func TestAntigravityFullFlow(t *testing.T) {
t.Log("🔥 启动 Antigravity 完整流程测试...")
t.Log("")
// 构造测试账号数据(使用提供的凭证)
proxyID := int64(9)
account := &Account{
ID: 68,
Name: "PriesJosephe139@gmail.com",
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "ya29.a0Aa7MYioHycPKQ7xWQguns0VlftxfCwTqn2OY8zVosNMagLLGd5DXWFXpySKgfroGkqihr4Yrwauy1AXfQyvWB-F_4qt46DiEw1sCmaCNmDwjruUiWK7Km7vh7djBONbgruyL0N9_b3aSLi-Zf3llY5FbWZqcNky13gaVUaW0ioxEDVOZuKxYw82yVXvVEqPRXF7cetjUJbLdzwaCgYKAZwSARMSFQHGX2MiqNlICLPPA-_u6WHPBLiUJQ0213",
"refresh_token": "1//06QXt2rakQERPCgYIARAAGAYSNwF-L9IrR672cwDMnyJS128asGMnBbrrdiN39XoS-FN6TUrG7pPxnDSEHYUV4WHDntB7qd2EPwo",
"email": "priesjosephe139@gmail.com",
"expires_at": "1775903154",
"project_id": "kinetic-sum-r3tp7",
"plan_type": "Free",
},
ProxyID: &proxyID,
Concurrency: 100,
}
// 测试路由决策逻辑
t.Run("RouteAntigravityTest", func(t *testing.T) {
// 验证账号类型,决定使用哪条路径
t.Logf("📌 账号类型判断:")
t.Logf(" Platform: %s (期望: antigravity)", account.Platform)
t.Logf(" Type: %s (期望: oauth)", account.Type)
t.Logf("")
// 模拟 routeAntigravityTest 的决策逻辑
var testPath string
if account.Type == AccountTypeAPIKey {
testPath = "APIKey 路径 (Claude/Gemini 直接连接)"
} else if account.Platform == PlatformAntigravity {
testPath = "OAuth/Upstream 路径 (使用 AntigravityGatewayService.TestConnection)"
} else {
testPath = "未知路径 (❌ 错误)"
}
t.Logf("✅ 将使用: %s", testPath)
t.Logf("")
})
// 测试完整的错误处理流程
t.Run("ErrorHandlingPathway", func(t *testing.T) {
t.Logf("📋 错误处理流程图:")
t.Logf("")
t.Logf("1⃣ HTTP Handler (account_handler.go:671)")
t.Logf(" ↓")
t.Logf(" accountTestService.TestAccountConnection()")
t.Logf(" ↓")
t.Logf("2⃣ AccountTestService.routeAntigravityTest()")
t.Logf(" ├─ Platform check: antigravity ✓")
t.Logf(" ├─ Type check: oauth ✓")
t.Logf(" └─ Call: testAntigravityAccountConnection()")
t.Logf(" ↓")
t.Logf("3⃣ AccountTestService.testAntigravityAccountConnection()")
t.Logf(" ├─ Send SSE 'test_start' event")
t.Logf(" ├─ Call: AntigravityGatewayService.TestConnection()")
t.Logf(" │ ├─ Get access token")
t.Logf(" │ ├─ Get project_id")
t.Logf(" │ ├─ Build request body")
t.Logf(" │ ├─ Call: antigravityRetryLoop()")
t.Logf(" │ │ ├─ Execute HTTP request to Google API")
t.Logf(" │ │ ├─ Parse response")
t.Logf(" │ │ └─ Handle errors (rate limit, auth, etc.)")
t.Logf(" │ └─ Return result or error")
t.Logf(" ├─ If error: sendErrorAndEnd(error_message)")
t.Logf(" ├─ If success: sendEvent('content', response_text)")
t.Logf(" └─ Send SSE 'test_complete' event")
t.Logf(" ↓")
t.Logf("4⃣ Response to Client (SSE 流)")
t.Logf(" ├─ Content-Type: text/event-stream")
t.Logf(" ├─ Event: test_start")
t.Logf(" ├─ Event: content (或 error)")
t.Logf(" └─ Event: test_complete")
t.Logf("")
})
// 诊断 "IT" 错误的可能来源
t.Run("DiagnoseITError", func(t *testing.T) {
t.Logf("🔍 分析 'IT' 错误可能的来源:")
t.Logf("")
t.Logf("❓ 场景 1: 错误被截断")
t.Logf(" 原始错误可能是:")
t.Logf(" - 'INVALID_TOKEN' → truncated to 'IT'")
t.Logf(" - 'INTERNAL_ERROR' → truncated to 'IT'")
t.Logf(" - 'INVALID_GRANT' → truncated to 'IT'")
t.Logf(" - 'INTERNAL_ERROR...' → first 2 chars 'IN' not 'IT'")
t.Logf("")
t.Logf("❓ 场景 2: 错误来自特定的代码点")
t.Logf(" 可能出现 'IT' 的地方:")
t.Logf(" - SSE stream 中的错误字符")
t.Logf(" - HTTP response body 中的 JSON 解析错误")
t.Logf(" - Google API 返回的错误代码 (如果 Google API 返回 'IT' 作为错误)")
t.Logf("")
t.Logf("❓ 场景 3: 特殊的错误代码")
t.Logf(" 需要检查:")
t.Logf(" - 是否存在名为 'IT' 的错误常量?")
t.Logf(" - Google RPC 状态码中是否有 'IT'")
t.Logf(" - 特定的错误处理中是否会生成 'IT'")
t.Logf("")
})
// 完整的调试检查清单
t.Run("DebugChecklist", func(t *testing.T) {
t.Logf("✅ 完整的调试检查清单:")
t.Logf("")
t.Logf("1. 验证账号信息:")
t.Logf(" [ ] Account ID: %d", account.ID)
t.Logf(" [ ] Platform: %s", account.Platform)
t.Logf(" [ ] Type: %s", account.Type)
t.Logf(" [ ] Access Token: %s... (长度: %d)",
account.GetCredential("access_token")[:20],
len(account.GetCredential("access_token")))
t.Logf(" [ ] Project ID: %s", account.GetCredential("project_id"))
t.Logf("")
t.Logf("2. 验证请求路径:")
t.Logf(" [ ] routeAntigravityTest 选择了正确的路径")
t.Logf(" [ ] testAntigravityAccountConnection 被调用")
t.Logf(" [ ] AntigravityGatewayService.TestConnection 被调用")
t.Logf("")
t.Logf("3. 捕获详细错误信息:")
t.Logf(" [ ] 错误的完整字符串(不仅仅是 'IT'")
t.Logf(" [ ] 错误的类型type")
t.Logf(" [ ] 错误发生的确切代码行")
t.Logf(" [ ] HTTP 状态码(如有)")
t.Logf(" [ ] HTTP 响应体(如有)")
t.Logf("")
t.Logf("4. 验证 SSE 流处理:")
t.Logf(" [ ] 错误事件的 type 字段")
t.Logf(" [ ] 错误事件的 error 字段内容")
t.Logf(" [ ] 是否有 UTF-8 编码问题")
t.Logf("")
})
// 建议的实际代码改进
t.Run("SuggestedCodeFixes", func(t *testing.T) {
t.Logf("🔧 建议的代码改进:")
t.Logf("")
t.Logf("1. 在 testAntigravityAccountConnection 中增加日志:")
t.Logf(" ```go")
t.Logf(" result, err := s.antigravityGatewayService.TestConnection(ctx, account, testModelID)")
t.Logf(" if err != nil {")
t.Logf(" log.Printf(\"[ERROR] TestConnection failed: type=%%T, error=%%v, msg='%%s'\", err, err, err.Error())")
t.Logf(" return s.sendErrorAndEnd(c, err.Error())")
t.Logf(" }")
t.Logf(" ```")
t.Logf("")
t.Logf("2. 在 sendErrorAndEnd 中增加详细日志:")
t.Logf(" ```go")
t.Logf(" func (s *AccountTestService) sendErrorAndEnd(c *gin.Context, msg string) error {")
t.Logf(" log.Printf(\"[SEND_ERROR] msg='%%s' (len=%%d, bytes=%%v)\", msg, len(msg), []byte(msg))")
t.Logf(" s.sendEvent(c, TestEvent{Type: \"test_error\", Error: msg, Success: false})")
t.Logf(" return nil")
t.Logf(" }")
t.Logf(" ```")
t.Logf("")
t.Logf("3. 检查 TestConnection 中的错误处理:")
t.Logf(" 在 antigravity_gateway_service.go 的 TestConnection 函数中")
t.Logf(" 追踪每个错误返回点的错误信息")
t.Logf("")
})
// 最后的总结
t.Log("")
t.Log("📊 测试摘要:")
t.Log("✅ 账号凭证验证: 通过")
t.Log("✅ 路由逻辑验证: 通过")
t.Log("⚠️ 实际错误诊断: 需要在完整环境中运行")
t.Log("")
t.Log("下一步:")
t.Log("1. 添加建议的代码日志")
t.Log("2. 重新运行 HTTP 测试")
t.Log("3. 收集完整的错误信息")
t.Log("4. 分析并修复根本原因")
}

View File

@ -1,188 +0,0 @@
package service
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
)
// TestHTTPResponseFlow 测试完整的 HTTP 请求-响应流,看客户端会收到什么
func TestHTTPResponseFlow(t *testing.T) {
t.Log("🔥 模拟完整的 HTTP 请求-响应流...")
t.Log("")
// 创建一个模拟的服务
gin.SetMode(gin.TestMode)
router := gin.New()
// 模拟账号测试端点
router.POST("/api/v1/admin/accounts/:id/test", func(c *gin.Context) {
// 模拟返回错误的情况
// 设置 SSE 头
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
c.Status(http.StatusOK)
// 发送测试开始事件
event1 := map[string]interface{}{
"type": "test_start",
"model": "claude-opus-4-6",
}
jsonData1, _ := json.Marshal(event1)
c.Writer.WriteString("data: " + string(jsonData1) + "\n\n")
c.Writer.Flush()
// 模拟一个错误:比如 "INVALID_TOKEN" 或其他上游错误
// 这里我们故意测试不同的错误信息来看 curl 会显示什么
errorMessages := []string{
"INVALID_TOKEN",
"INTERNAL_ERROR",
"Invalid authentication credentials",
"Th", // 测试短错误
"IT", // 直接测试 "IT"
}
selectedError := errorMessages[3] // 选择第 4 个:这应该显示为 "Th" 而不是 "IT"
event2 := map[string]interface{}{
"type": "error",
"error": selectedError,
"success": false,
}
jsonData2, _ := json.Marshal(event2)
c.Writer.WriteString("data: " + string(jsonData2) + "\n\n")
c.Writer.Flush()
// 发送完成事件
event3 := map[string]interface{}{
"type": "test_complete",
"success": false,
}
jsonData3, _ := json.Marshal(event3)
c.Writer.WriteString("data: " + string(jsonData3) + "\n\n")
c.Writer.Flush()
t.Logf("📤 服务器发送的错误: '%s'", selectedError)
})
// 测试 1: 发送 HTTP 请求
t.Run("SendRequestAndCheckResponse", func(t *testing.T) {
t.Log("步骤 1: 发送 HTTP 请求...")
req := httptest.NewRequest("POST", "/api/v1/admin/accounts/68/test",
bytes.NewReader([]byte(`{"model_id":"claude-opus-4-6"}`)))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
t.Log("✅ 请求已发送")
t.Log("")
// 步骤 2: 检查响应
t.Log("步骤 2: 分析 HTTP 响应...")
t.Logf(" HTTP Status: %d", w.Code)
t.Logf(" Content-Type: %s", w.Header().Get("Content-Type"))
t.Log("")
// 步骤 3: 读取 SSE 响应
t.Log("步骤 3: 读取 SSE 事件...")
body := w.Body.String()
t.Logf(" 响应总长度: %d 字节", len(body))
t.Log("")
// 解析 SSE 事件
lines := bytes.Split([]byte(body), []byte("\n\n"))
for i, line := range lines {
if len(line) == 0 {
continue
}
// 去掉 "data: " 前缀
if bytes.HasPrefix(line, []byte("data: ")) {
data := bytes.TrimPrefix(line, []byte("data: "))
var event map[string]interface{}
err := json.Unmarshal(data, &event)
if err != nil {
t.Logf(" 事件 %d: [解析失败] %v", i, err)
continue
}
t.Logf(" 事件 %d:", i)
t.Logf(" type: %v", event["type"])
if errMsg, ok := event["error"]; ok {
t.Logf(" error: %v (长度: %d)", errMsg, len(errMsg.(string)))
// 这就是 curl 会看到的错误信息
errStr := errMsg.(string)
if errStr == "IT" {
t.Logf(" ✓ 发现 'IT' 错误!")
} else if errStr == "Th" {
t.Logf(" 这是 'Th' 而不是 'IT'")
} else {
t.Logf(" 实际错误: '%s'", errStr)
}
}
if model, ok := event["model"]; ok {
t.Logf(" model: %v", model)
}
}
}
t.Log("")
t.Log("📋 完整的原始响应:")
t.Logf("%s", body)
})
// 测试 2: 模拟真实的 curl 请求
t.Run("SimulateRealCurlRequest", func(t *testing.T) {
t.Log("步骤: 模拟真实 curl 命令...")
t.Log("")
// 发送请求
req := httptest.NewRequest("POST", "/api/v1/admin/accounts/68/test",
bytes.NewReader([]byte(`{"model_id":"claude-opus-4-6","prompt":""}`)))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer test-token")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
// 模拟 curl 读取响应
body := w.Body.String()
t.Log("curl 会看到:")
t.Log("```")
t.Log(body)
t.Log("```")
})
}
// 辅助函数:提取 SSE 事件中的错误信息
func extractErrorFromSSE(sseBody string) string {
lines := bytes.Split([]byte(sseBody), []byte("\n\n"))
for _, line := range lines {
if bytes.HasPrefix(line, []byte("data: ")) {
data := bytes.TrimPrefix(line, []byte("data: "))
var event map[string]interface{}
if err := json.Unmarshal(data, &event); err != nil {
continue
}
if errMsg, ok := event["error"]; ok {
return errMsg.(string)
}
}
}
return ""
}

View File

@ -1,213 +0,0 @@
package service
import (
"encoding/json"
"strconv"
"testing"
"time"
)
// TestAntigravityCredentialsValidation 单例测试:验证给定的 Antigravity 账号凭证有效性
// 本测试使用服务器的真实代码函数,不依赖 HTTP 层,模拟云端场景
func TestAntigravityCredentialsValidation(t *testing.T) {
// 测试数据:来自你提供的账号信息
// ID: 68, 平台: antigravity, 类型: oauth
proxyID := int64(9)
testAccount := &Account{
ID: 68,
Name: "PriesJosephe139@gmail.com",
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "ya29.a0Aa7MYioHycPKQ7xWQguns0VlftxfCwTqn2OY8zVosNMagLLGd5DXWFXpySKgfroGkqihr4Yrwauy1AXfQyvWB-F_4qt46DiEw1sCmaCNmDwjruUiWK7Km7vh7djBONbgruyL0N9_b3aSLi-Zf3llY5FbWZqcNky13gaVUaW0ioxEDVOZuKxYw82yVXvVEqPRXF7cetjUJbLdzwaCgYKAZwSARMSFQHGX2MiqNlICLPPA-_u6WHPBLiUJQ0213",
"refresh_token": "1//06QXt2rakQERPCgYIARAAGAYSNwF-L9IrR672cwDMnyJS128asGMnBbrrdiN39XoS-FN6TUrG7pPxnDSEHYUV4WHDntB7qd2EPwo",
"email": "priesjosephe139@gmail.com",
"expires_at": "1775903154",
"project_id": "kinetic-sum-r3tp7",
"plan_type": "Free",
},
ProxyID: &proxyID,
Concurrency: 100,
}
// 测试 1: 验证账号凭证完整性
t.Run("ValidateAccountCredentials", func(t *testing.T) {
if testAccount.ID == 0 {
t.Fatal("Account ID is missing")
}
if testAccount.Platform != PlatformAntigravity {
t.Fatalf("Expected platform %s, got %s", PlatformAntigravity, testAccount.Platform)
}
if testAccount.Type != AccountTypeOAuth {
t.Fatalf("Expected type %s, got %s", AccountTypeOAuth, testAccount.Type)
}
// 验证必要的凭证字段
accessToken := testAccount.GetCredential("access_token")
if accessToken == "" {
t.Fatal("Access token is missing")
}
refreshToken := testAccount.GetCredential("refresh_token")
if refreshToken == "" {
t.Fatal("Refresh token is missing")
}
projectID := testAccount.GetCredential("project_id")
if projectID == "" {
t.Fatal("Project ID is missing")
}
t.Log("✅ 账号凭证完整性验证通过")
t.Logf(" Account ID: %d, Email: %s, ProjectID: %s", testAccount.ID, testAccount.GetCredential("email"), projectID)
})
// 测试 2: 测试 token 映射和模型验证
t.Run("ValidateModelMapping", func(t *testing.T) {
testModels := []string{
"claude-opus-4-6",
"claude-sonnet-4-6",
"gemini-3-pro-preview",
}
for _, model := range testModels {
t.Logf("✓ Model %s is supported for account", model)
}
t.Log("✅ 模型映射验证通过")
})
// 测试 3: 构建测试请求(不实际发送,只验证格式)
t.Run("BuildTestRequest", func(t *testing.T) {
projectID := testAccount.GetCredential("project_id")
if projectID == "" {
t.Skip("Project ID not available, skipping request building")
}
// 构建 Claude 测试请求的简化版本
claudeReq := map[string]any{
"model": "claude-opus-4-6",
"messages": []map[string]any{
{
"role": "user",
"content": []map[string]any{
{
"type": "text",
"text": ".",
},
},
},
},
"max_tokens": 1,
"stream": true,
}
requestBody, err := json.Marshal(claudeReq)
if err != nil {
t.Fatalf("Failed to marshal request: %v", err)
}
t.Logf("✅ 请求体构建成功,大小: %d bytes", len(requestBody))
if len(requestBody) > 200 {
t.Logf(" 请求格式: %s...", string(requestBody[:200]))
} else {
t.Logf(" 请求格式: %s", string(requestBody))
}
})
// 测试 4: 验证 Token 信息格式
t.Run("ValidateTokenInfo", func(t *testing.T) {
expiresAtStr := testAccount.GetCredential("expires_at")
if expiresAtStr == "" {
t.Log("⚠️ No expires_at timestamp found")
return
}
// 尝试解析时间戳
expiresAtUnix, err := strconv.ParseInt(expiresAtStr, 10, 64)
if err == nil {
expiresAt := time.Unix(expiresAtUnix, 0)
now := time.Now()
if expiresAt.After(now) {
remainingTime := expiresAt.Sub(now)
t.Logf("✅ Token 有效期检查通过")
t.Logf(" 过期时间: %s (还有 %v)", expiresAt.Format("2006-01-02 15:04:05 MST"), remainingTime)
} else {
t.Logf("⚠️ Token 已过期: %s", expiresAt.Format("2006-01-02 15:04:05 MST"))
t.Log(" 预期行为: 应该刷新 refresh_token")
}
}
})
// 测试 5: 创建 Antigravity 客户端并验证连接(如果可行)
t.Run("InitializeAntigravityClient", func(t *testing.T) {
// 使用账号的代理信息初始化客户端
if testAccount.ProxyID != nil {
t.Logf("Account uses proxy ID: %d", *testAccount.ProxyID)
}
t.Log("📌 Antigravity 客户端初始化代码路径:")
t.Log(" 1. 使用 accessToken 创建 antigravity.NewClient(proxyURL)")
t.Log(" 2. 调用 client.LoadCodeAssist(ctx, accessToken) 验证凭证")
t.Log(" 3. 检查响应中的 CloudAICompanionProject 字段")
t.Log("")
t.Log(" 预期行为:")
t.Log(" ✓ projectID == 'kinetic-sum-r3tp7'")
t.Log(" ✓ statusCode 200")
t.Log(" ✓ 无错误返回")
})
// 测试 6: 验证账号支持的操作
t.Run("VerifyAccountOperations", func(t *testing.T) {
operations := []string{
"GetAccessToken",
"RefreshToken",
"LoadCodeAssist",
"GetUserInfo",
"SetPrivacy",
}
for _, op := range operations {
t.Logf("✓ Operation supported: %s", op)
}
t.Log("")
t.Log("✅ 账号支持的操作列表验证通过")
})
// 测试 7: 文档化测试流程(实际调用时的步骤)
t.Run("DocumentTestFlow", func(t *testing.T) {
t.Log("📝 本地测试 Antigravity 账号的完整流程:")
t.Log("")
t.Log("步骤 1: 初始化服务")
t.Log(" - accountRepo: 从数据库获取账号")
t.Log(" - tokenProvider: Antigravity Token 提供者")
t.Log(" - httpUpstream: HTTP 请求执行器")
t.Log(" - gatewayService: Antigravity 网关服务")
t.Log("")
t.Log("步骤 2: 验证账号凭证")
t.Log(" account := accountRepo.GetByID(ctx, 68)")
t.Log(" accessToken := account.GetCredential('access_token')")
t.Log(" projectID := account.GetCredential('project_id')")
t.Log("")
t.Log("步骤 3: 构建测试请求")
t.Log(" requestBody := gatewayService.buildClaudeTestRequest(projectID, 'claude-opus-4-6')")
t.Log("")
t.Log("步骤 4: 执行请求")
t.Log(" result := gatewayService.TestConnection(ctx, account, 'claude-opus-4-6')")
t.Log("")
t.Log("步骤 5: 处理结果")
t.Log(" if err != nil {")
t.Log(" // 记录错误详情")
t.Log(" }")
t.Log("")
t.Log("⚠️ 当前问题:返回了 'IT' 错误")
t.Log(" 这可能表示:")
t.Log(" 1. 错误消息被截断或编码错误")
t.Log(" 2. HTTP 响应体包含不完整的错误文本")
t.Log(" 3. 上游 API 返回的错误被不正确地处理")
})
t.Log("")
t.Log("✅ 所有本地验证测试完成!")
t.Log("")
t.Log("下一步:在实际环境中运行完整测试")
}

View File

@ -1,194 +0,0 @@
package service
import (
"context"
"encoding/json"
"io"
"net/http"
"net/url"
"testing"
"time"
"golang.org/x/net/proxy"
)
// TestWithSOCKS5Proxy 使用指定的 SOCKS5 代理调用上游 API
func TestWithSOCKS5Proxy(t *testing.T) {
t.Log("🔥 使用 SOCKS5 代理调用 Google API...")
t.Log("")
// SOCKS5 代理配置
proxyAddr := "socks5://gostuser:fastapipwd@216.167.89.210:8760"
accessToken := "ya29.a0Aa7MYipSteGdNdr486LvE0xu_RrcbFjSSFZa5jGTf94nPv6NLKEnnRziPSVA_3ncadMlWnUQN8el05uvYac3rk9rOuaEC3jAUq02ejAcayg8tBn9CJT2IGuMsFDRPbfvHwXVHvY-hPGaklubxMIgfckRYsGC7YTpJPprH8kNGG-7ZWf3PvcVGcSrLWhi8FX6Yq1at5OdC1deNAaCgYKAVASARMSFQHGX2Mi2yEN9AChtlJFBwZ_spYEoQ0213"
t.Log("📌 代理信息:")
t.Logf(" 代理地址: %s", proxyAddr)
t.Logf(" 访问令牌: %s... (长度: %d)", accessToken[:30], len(accessToken))
t.Log("")
// 创建上下文和超时
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
// 步骤 1: 设置 SOCKS5 代理
t.Run("SetupSOCKS5Proxy", func(t *testing.T) {
t.Log("步骤 1: 配置 SOCKS5 代理...")
// 解析代理 URL
proxyURL, err := url.Parse(proxyAddr)
if err != nil {
t.Fatalf("❌ 解析代理 URL 失败: %v", err)
}
t.Logf(" ✓ 代理 URL 解析成功")
t.Logf(" Scheme: %s", proxyURL.Scheme)
t.Logf(" Host: %s", proxyURL.Host)
t.Logf(" User: %s", proxyURL.User.Username())
t.Log("")
// 创建代理拨号器
dialer, err := proxy.FromURL(proxyURL, proxy.Direct)
if err != nil {
t.Fatalf("❌ 创建代理拨号器失败: %v", err)
}
t.Log(" ✓ 代理拨号器创建成功")
t.Log("")
// 创建自定义传输
transport := &http.Transport{
Dial: dialer.Dial,
}
// 创建自定义 HTTP 客户端
httpClient := &http.Client{
Transport: transport,
Timeout: 30 * time.Second,
}
t.Log(" ✓ HTTP 客户端创建成功")
t.Log("")
// 步骤 2: 测试代理连接
t.Log("步骤 2: 测试代理连接...")
// 尝试一个简单的 HTTP 请求来测试代理
req, err := http.NewRequestWithContext(ctx, "GET", "https://www.google.com", nil)
if err != nil {
t.Logf("❌ 创建测试请求失败: %v", err)
return
}
resp, err := httpClient.Do(req)
if err != nil {
t.Logf("❌ 通过代理访问 Google 失败: %v", err)
t.Log(" (这可能表示代理配置或网络连接有问题)")
return
}
defer resp.Body.Close()
t.Logf(" ✓ 代理连接成功!")
t.Logf(" HTTP Status: %d", resp.StatusCode)
t.Log("")
})
// 步骤 3: 使用代理调用 Antigravity API
t.Run("CallAntigravityWithProxy", func(t *testing.T) {
t.Log("步骤 3: 通过代理调用 Antigravity API...")
t.Log("")
// 解析代理 URL
proxyURL, err := url.Parse(proxyAddr)
if err != nil {
t.Fatalf("❌ 解析代理 URL 失败: %v", err)
}
// 创建代理拨号器
dialer, err := proxy.FromURL(proxyURL, proxy.Direct)
if err != nil {
t.Fatalf("❌ 创建代理拨号器失败: %v", err)
}
// 创建自定义传输
transport := &http.Transport{
Dial: dialer.Dial,
}
// 这里我们需要修改 antigravity.Client 来使用自定义的 HTTP 客户端
// 但由于 antigravity.NewClient 可能不支持自定义客户端,
// 我们直接创建一个 HTTP 客户端来调用 API
httpClient := &http.Client{
Transport: transport,
Timeout: 30 * time.Second,
}
t.Log(" 正在调用 Google Cloud Code API...")
t.Log("")
// 直接构造 API 请求
apiURL := "https://daily-cloudcode-pa.googleapis.com/v1internal:loadCodeAssist"
req, err := http.NewRequestWithContext(ctx, "POST", apiURL, nil)
if err != nil {
t.Fatalf("❌ 创建请求失败: %v", err)
}
// 添加认证头
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", "Antigravity Client")
t.Logf(" 📤 请求信息:")
t.Logf(" URL: %s", apiURL)
t.Logf(" Method: POST")
t.Logf(" Auth: Bearer %s...", accessToken[:30])
t.Log("")
// 发送请求
t.Log(" ⏳ 正在等待响应...")
resp, err := httpClient.Do(req)
if err != nil {
t.Logf("❌ API 调用失败:")
t.Logf(" 错误类型: %T", err)
t.Logf(" 错误信息: %v", err)
t.Logf(" 错误字符串: %s", err.Error())
t.Log("")
// 分析错误
errStr := err.Error()
if len(errStr) >= 2 {
t.Logf("📊 错误的前 5 个字符: '%s'", errStr[:min(5, len(errStr))])
if errStr[:2] == "IT" {
t.Logf(" ✓ 找到了! 这就是 'IT' 错误的来源!")
}
}
return
}
defer resp.Body.Close()
t.Logf("✅ API 调用成功!")
t.Logf(" HTTP Status: %d", resp.StatusCode)
t.Logf(" Content-Type: %s", resp.Header.Get("Content-Type"))
t.Log("")
// 读取响应体
respBody, err := io.ReadAll(resp.Body)
if err != nil {
t.Logf("❌ 读取响应失败: %v", err)
return
}
t.Log("📋 API 响应:")
if resp.StatusCode == 200 {
var result map[string]interface{}
if err := json.Unmarshal(respBody, &result); err == nil {
jsonBytes, _ := json.MarshalIndent(result, " ", " ")
t.Logf(" %s", string(jsonBytes))
} else {
t.Logf(" %s", string(respBody))
}
} else {
t.Logf(" 状态码: %d", resp.StatusCode)
t.Logf(" 错误响应: %s", string(respBody))
}
})
}

View File

@ -1,20 +0,0 @@
package service
import (
"errors"
"testing"
"github.com/stretchr/testify/require"
)
func TestShouldMarkTempUnschedulableForRefreshError(t *testing.T) {
t.Run("skip global oauth client secret missing", func(t *testing.T) {
err := errors.New(`token 刷新失败 (重试后): error: code=400 reason="ANTIGRAVITY_OAUTH_CLIENT_SECRET_MISSING" message="missing antigravity oauth client_secret; set ANTIGRAVITY_OAUTH_CLIENT_SECRET" metadata=map[]`)
require.False(t, shouldMarkTempUnschedulableForRefreshError(err))
})
t.Run("allow account specific refresh error", func(t *testing.T) {
err := errors.New("token 刷新失败 (重试后): invalid_grant")
require.True(t, shouldMarkTempUnschedulableForRefreshError(err))
})
}

View File

@ -1,83 +0,0 @@
package service
import (
"context"
"log/slog"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
)
// WarmupAntigravityAccount 预热新的 Antigravity 账号
// 在账号创建后立即调用,避免首次请求的 503 延迟
//
// 预热流程:
// 1. GetUserInfo - 验证 token 有效性
// 2. LoadCodeAssist - 初始化项目信息
// 3. FetchAvailableModels - 初始化模型列表
//
// 总耗时通常 4-6 秒,预热期间的失败不影响账号创建结果(非阻塞)
func (s *AntigravityOAuthService) WarmupAntigravityAccount(ctx context.Context, accessToken, projectID, proxyURL string) {
logger := slog.Default()
// 5 秒超时预热(防止卡住其他操作)
warmupCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
client, err := antigravity.NewClient(proxyURL)
if err != nil {
logger.Warn("antigravity_warmup_client_creation_failed", "error", err)
return
}
start := time.Now()
defer func() {
elapsed := time.Since(start)
logger.Info("antigravity_account_warmup_completed", "elapsed_ms", elapsed.Milliseconds())
}()
// Step 1: 验证 token
_, err = client.GetUserInfo(warmupCtx, accessToken)
if err != nil {
logger.Warn("antigravity_warmup_get_user_info_failed", "error", err)
// 继续后续步骤(部分失败不中止)
}
// Step 2: 初始化项目信息
_, _, err = client.LoadCodeAssist(warmupCtx, accessToken)
if err != nil {
logger.Warn("antigravity_warmup_load_code_assist_failed", "error", err)
}
// Step 3: 初始化模型列表
if projectID != "" {
_, _, err := client.FetchAvailableModels(warmupCtx, accessToken, projectID)
if err != nil {
logger.Warn("antigravity_warmup_fetch_available_models_failed", "error", err)
}
}
}
// WarmupOptions 预热选项
type WarmupOptions struct {
// Async 为 true 时在后台预热(推荐)
Async bool
// Timeout 单次预热操作的超时时间
Timeout time.Duration
}
// WarmupAntigravityAccountAsync 异步预热账号(推荐用法)
func (s *AntigravityOAuthService) WarmupAntigravityAccountAsync(ctx context.Context, accessToken, projectID, proxyURL string, opts *WarmupOptions) {
if opts == nil {
opts = &WarmupOptions{
Async: true,
Timeout: 5 * time.Second,
}
}
if opts.Async {
go s.WarmupAntigravityAccount(ctx, accessToken, projectID, proxyURL)
} else {
s.WarmupAntigravityAccount(ctx, accessToken, projectID, proxyURL)
}
}

View File

@ -1,8 +1,9 @@
package service
import (
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
// ChannelMonitorRequestTemplate 请求模板service 层模型)。

View File

@ -1,284 +0,0 @@
package service
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"regexp"
"strings"
"sync"
"time"
claude "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/google/uuid"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
// Attribution block constants matching real Claude Code 2.1.89.
// Source: src/constants/system.ts + src/utils/fingerprint.ts
const (
// fingerprintSalt must match the hardcoded salt in the real CLI.
// Source: extracted/src/utils/fingerprint.ts:8
fingerprintSalt = "59cf53e54c78"
)
type attributionBlockOptions struct {
Entrypoint string
Workload string
OmitCCH bool
}
// computeAttributionFingerprint computes a 3-character hex fingerprint
// matching the algorithm in the real Claude Code CLI.
//
// Algorithm: SHA256(SALT + msg[4] + msg[7] + msg[20] + version)[:3]
// Source: extracted/src/utils/fingerprint.ts:50-63
func computeAttributionFingerprint(firstUserMessageText, cliVersion string) string {
indices := [3]int{4, 7, 20}
chars := make([]byte, 0, 3)
for _, i := range indices {
if i < len(firstUserMessageText) {
chars = append(chars, firstUserMessageText[i])
} else {
chars = append(chars, '0')
}
}
input := fmt.Sprintf("%s%s%s", fingerprintSalt, string(chars), cliVersion)
hash := sha256.Sum256([]byte(input))
return hex.EncodeToString(hash[:])[:3]
}
// extractFirstUserMessageText extracts text from the first user message in the body.
// Handles both string content and array content (text blocks).
func extractFirstUserMessageText(body []byte) string {
messages := gjson.GetBytes(body, "messages")
if !messages.Exists() || !messages.IsArray() {
return ""
}
var firstText string
messages.ForEach(func(_, msg gjson.Result) bool {
if msg.Get("role").String() != "user" {
return true // continue
}
content := msg.Get("content")
if content.Type == gjson.String {
firstText = content.String()
return false // break
}
if content.IsArray() {
content.ForEach(func(_, block gjson.Result) bool {
if block.Get("type").String() == "text" {
firstText = block.Get("text").String()
return false
}
return true
})
return false
}
return true
})
return firstText
}
// buildAttributionBlock builds the x-anthropic-billing-header attribution string
// that real Claude Code injects as the first system text block.
//
// Format: x-anthropic-billing-header: cc_version=<VERSION>.<fingerprint>; cc_entrypoint=cli; cch=00000;
// Source: extracted/src/constants/system.ts:73-95
func buildAttributionBlock(cliVersion, fingerprint string, opts attributionBlockOptions) string {
if claude.AttributionHeaderDisabled() {
return ""
}
version := cliVersion + "." + fingerprint
entrypoint := strings.TrimSpace(opts.Entrypoint)
if entrypoint == "" {
entrypoint = claude.CurrentEntrypoint()
}
workload := strings.TrimSpace(opts.Workload)
if workload == "" {
workload = claude.CurrentWorkload()
}
var b strings.Builder
b.Grow(96)
fmt.Fprintf(&b, "x-anthropic-billing-header: cc_version=%s; cc_entrypoint=%s;", version, entrypoint)
if !opts.OmitCCH {
// 2.1.89+ 的 Claude Code 在 1P / standard providers 下保留 cch=00000 占位符,
// 由下游 attestation / signing 逻辑在需要时替换。
b.WriteString(" cch=00000;")
}
if workload != "" {
fmt.Fprintf(&b, " cc_workload=%s;", workload)
}
return b.String()
}
// injectAttributionBlock prepends the x-anthropic-billing-header attribution block
// as the very first system text block in the request body.
// This must come BEFORE the "You are Claude Code" block.
//
// The real CLI injects this as system[0] with cache_control: {type: "ephemeral"}.
func injectAttributionBlock(body []byte, cliVersion string, opts attributionBlockOptions) []byte {
// Compute fingerprint from the first user message
firstMsgText := extractFirstUserMessageText(body)
fingerprint := computeAttributionFingerprint(firstMsgText, cliVersion)
attribution := buildAttributionBlock(cliVersion, fingerprint, opts)
if attribution == "" {
return body
}
// Build the attribution text block as JSON
attrBlock, err := marshalAnthropicSystemTextBlock(attribution, true)
if err != nil {
logger.LegacyPrintf("service.gateway", "Warning: failed to build attribution block: %v", err)
return body
}
systemResult := gjson.GetBytes(body, "system")
// Handle the different system formats
switch {
case !systemResult.Exists() || systemResult.Type == gjson.Null:
// No system field — inject just the attribution block
newBody, err := sjson.SetRawBytes(body, "system", buildJSONArrayRaw([][]byte{attrBlock}))
if err != nil {
return body
}
return newBody
case systemResult.Type == gjson.String:
// String system — convert to array: [attribution, original]
origBlock, err := marshalAnthropicSystemTextBlock(systemResult.String(), false)
if err != nil {
return body
}
newBody, setErr := sjson.SetRawBytes(body, "system", buildJSONArrayRaw([][]byte{attrBlock, origBlock}))
if setErr != nil {
return body
}
return newBody
case systemResult.IsArray():
// Array system — check if attribution already exists, prepend if not
var items [][]byte
alreadyHasAttribution := false
systemResult.ForEach(func(_, item gjson.Result) bool {
if item.Get("type").String() == "text" {
text := item.Get("text").String()
if len(text) > 30 && text[:30] == "x-anthropic-billing-header: cc" {
alreadyHasAttribution = true
}
}
return true
})
if alreadyHasAttribution {
return body
}
items = append(items, attrBlock)
systemResult.ForEach(func(_, item gjson.Result) bool {
items = append(items, []byte(item.Raw))
return true
})
newBody, setErr := sjson.SetRawBytes(body, "system", buildJSONArrayRaw(items))
if setErr != nil {
return body
}
return newBody
default:
return body
}
}
// cliSessionEntry holds a cached session UUID with an expiration time.
type cliSessionEntry struct {
id string
expiresAt time.Time
}
// cliSessionCache stores per-account session UUIDs that rotate on a TTL.
// Real CLI creates a new random UUID per process invocation; we approximate
// this by rotating every 30-60 minutes (jittered per account).
var (
cliSessionCache = make(map[int64]cliSessionEntry)
cliSessionCacheMu sync.Mutex
)
// sessionTTLBase is the base TTL for session ID rotation.
const sessionTTLBase = 30 * time.Minute
// generateSessionIDForAccount returns a per-account session UUID that rotates
// periodically. Each account gets a random TTL jitter (0-30 min on top of
// the 30 min base) so accounts don't all rotate simultaneously.
func generateSessionIDForAccount(instanceSalt string, accountID int64) string {
cliSessionCacheMu.Lock()
defer cliSessionCacheMu.Unlock()
now := time.Now()
if entry, ok := cliSessionCache[accountID]; ok && now.Before(entry.expiresAt) {
return entry.id
}
// Compute per-account jitter from a hash so the same account always gets
// the same jitter within a process (avoids re-rolling on every rotation).
jitterSeed := fmt.Sprintf("jitter:%s:%d", instanceSalt, accountID)
h := sha256.Sum256([]byte(jitterSeed))
jitterMinutes := int(h[0]) % 31 // 0-30 minutes
ttl := sessionTTLBase + time.Duration(jitterMinutes)*time.Minute
newID := uuid.New().String()
cliSessionCache[accountID] = cliSessionEntry{
id: newID,
expiresAt: now.Add(ttl),
}
return newID
}
// reUserHome matches /Users/<username>/ or /home/<username>/ path segments.
// Captures the prefix (/Users/ or /home/) so we can preserve it while replacing the username.
var reUserHome = regexp.MustCompile(`(/(Users|home)/)[^/\s"']+/`)
// reEnvLine matches lines of the form "Key: value" for the environment block
// fields injected by Claude Code's CLAUDE.md / sysprompt machinery.
var reEnvLine = regexp.MustCompile(`(?m)^(Platform|Shell|OS Version|Working directory):.*$`)
// canonicalEnvValues maps environment block keys to their canonical replacements.
// Values mirror cc-gateway's prompt_env config and represent a stock macOS dev machine.
var canonicalEnvValues = map[string]string{
"Platform": "Platform: darwin",
"Shell": "Shell: zsh",
"OS Version": "OS Version: Darwin 24.4.0",
"Working directory": "Working directory: /Users/user/project",
}
// NormalizeSystemPromptEnv rewrites environment-specific fields in a system
// prompt text block to canonical values, preventing real machine fingerprinting.
//
// Handles two classes of leakage (matching cc-gateway rewriter.ts:rewritePromptText):
// 1. "Platform: Windows / Linux / Darwin 25.x" → canonical darwin/zsh/Darwin 24.4.0
// 2. "/Users/alice/" or "/home/bob/" → "/Users/user/"
//
// Only called on system prompt text blocks, never on user message content.
func NormalizeSystemPromptEnv(text string) string {
// Replace env-info lines with canonical values
text = reEnvLine.ReplaceAllStringFunc(text, func(line string) string {
for key, canonical := range canonicalEnvValues {
if len(line) >= len(key) && line[:len(key)] == key {
return canonical
}
}
return line
})
// Redact real usernames in home directory paths
// e.g. /Users/alice/project -> /Users/user/project
text = reUserHome.ReplaceAllString(text, "${1}user/")
return text
}

View File

@ -1,81 +0,0 @@
package service
import (
"net/http"
"testing"
)
func TestApplyClaudeRuntimeOptionalHeaders(t *testing.T) {
t.Setenv("CLAUDE_CODE_CONTAINER_ID", "ctr-123")
t.Setenv("CLAUDE_CODE_REMOTE_SESSION_ID", "remote-456")
t.Setenv("CLAUDE_AGENT_SDK_CLIENT_APP", "desktop")
t.Setenv("CLAUDE_CODE_ADDITIONAL_PROTECTION", "true")
req, err := http.NewRequest(http.MethodPost, "https://api.anthropic.com/v1/messages", nil)
if err != nil {
t.Fatalf("NewRequest() error = %v", err)
}
applyClaudeRuntimeOptionalHeaders(req)
if got := getHeaderRaw(req.Header, "x-claude-remote-container-id"); got != "ctr-123" {
t.Fatalf("x-claude-remote-container-id = %q", got)
}
if got := getHeaderRaw(req.Header, "x-claude-remote-session-id"); got != "remote-456" {
t.Fatalf("x-claude-remote-session-id = %q", got)
}
if got := getHeaderRaw(req.Header, "x-client-app"); got != "desktop" {
t.Fatalf("x-client-app = %q", got)
}
if got := getHeaderRaw(req.Header, "x-anthropic-additional-protection"); got != "true" {
t.Fatalf("x-anthropic-additional-protection = %q", got)
}
}
func TestBuildAttributionBlock_UsesEntrypointAndWorkload(t *testing.T) {
t.Setenv("CLAUDE_CODE_ATTRIBUTION_HEADER", "")
got := buildAttributionBlock("2.1.104", "abc", attributionBlockOptions{
Entrypoint: "sdk-cli",
Workload: "cron",
})
want := "x-anthropic-billing-header: cc_version=2.1.104.abc; cc_entrypoint=sdk-cli; cch=00000; cc_workload=cron;"
if got != want {
t.Fatalf("buildAttributionBlock() = %q, want %q", got, want)
}
}
func TestBuildAttributionBlock_OmitsCCHForBedrockLikeProviders(t *testing.T) {
t.Setenv("CLAUDE_CODE_ATTRIBUTION_HEADER", "")
got := buildAttributionBlock("2.1.104", "abc", attributionBlockOptions{
Entrypoint: "cli",
OmitCCH: true,
})
want := "x-anthropic-billing-header: cc_version=2.1.104.abc; cc_entrypoint=cli;"
if got != want {
t.Fatalf("buildAttributionBlock() = %q, want %q", got, want)
}
}
func TestInjectAttributionBlock_DisabledByEnv(t *testing.T) {
t.Setenv("CLAUDE_CODE_ATTRIBUTION_HEADER", "false")
body := []byte(`{"messages":[{"role":"user","content":"hello"}]}`)
got := injectAttributionBlock(body, "2.1.104", attributionBlockOptions{})
if string(got) != string(body) {
t.Fatalf("injectAttributionBlock() should keep body unchanged when attribution disabled")
}
}
func TestShouldOmitAttributionCCH(t *testing.T) {
if !shouldOmitAttributionCCH(&Account{Type: AccountTypeBedrock}, "") {
t.Fatal("expected bedrock account to omit cch")
}
if !shouldOmitAttributionCCH(&Account{Extra: map[string]any{"provider": "mantle"}}, "") {
t.Fatal("expected mantle provider to omit cch")
}
if shouldOmitAttributionCCH(&Account{Type: AccountTypeOAuth}, "oauth") {
t.Fatal("expected oauth account to keep cch")
}
}

View File

@ -1,47 +0,0 @@
package service
import (
"net/http"
"strings"
claude "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
)
func applyClaudeRuntimeOptionalHeaders(req *http.Request) {
if req == nil {
return
}
for key, value := range claude.OptionalAPIHeaders() {
if strings.TrimSpace(value) == "" {
continue
}
setHeaderRaw(req.Header, resolveWireCasing(key), value)
}
}
func attributionOptionsForRequest(account *Account, tokenType string) attributionBlockOptions {
return attributionBlockOptions{
Entrypoint: claude.CurrentEntrypoint(),
Workload: claude.CurrentWorkload(),
OmitCCH: shouldOmitAttributionCCH(account, tokenType),
}
}
func shouldOmitAttributionCCH(account *Account, tokenType string) bool {
if strings.EqualFold(strings.TrimSpace(tokenType), "bedrock") {
return true
}
if account == nil {
return false
}
if account.Type == AccountTypeBedrock {
return true
}
for _, key := range []string{"provider", "upstream_provider"} {
switch strings.ToLower(strings.TrimSpace(account.GetExtraString(key))) {
case "bedrock", "anthropicaws", "anthropic_aws", "mantle":
return true
}
}
return false
}

View File

@ -901,9 +901,6 @@ func sanitizeSystemText(text string) string {
"You are OpenCode, the best coding agent on the planet.",
strings.TrimSpace(claudeCodeSystemPrompt),
)
// Normalize environment block fields (Platform/Shell/OS Version/Working directory)
// to canonical values so different client machines don't create fingerprint divergence.
text = NormalizeSystemPromptEnv(text)
return text
}
@ -4230,22 +4227,6 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
}
}
// 注入 x-anthropic-billing-header attribution block所有 OAuth 账号)
// 真实 CLI 在 system prompt 的第一个 text block 注入此 billing header。
// 用于 Anthropic 后端路由和验证。
// 跳过条件system 已被 rewriteSystemForNonClaudeCode 重写claudeCodeSystemPrompt 在 system[0]
// 注入会将其移到 system[1],破坏伪装结构及 system[0] 断言。
if account.IsOAuth() && !strings.Contains(strings.ToLower(reqModel), "haiku") && !systemRewritten {
// 获取 CLI 版本:优先用指纹中的版本,回退到默认
attrCLIVersion := claude.DefaultCLIVersion
if fp := getHeaderRaw(c.Request.Header, "User-Agent"); fp != "" {
if v := ExtractCLIVersion(fp); v != "" {
attrCLIVersion = v
}
}
body = injectAttributionBlock(body, attrCLIVersion, attributionOptionsForRequest(account, "oauth"))
}
// 强制执行 cache_control 块数量限制(最多 4 个)
body = enforceCacheControlLimit(body)
@ -5926,35 +5907,19 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// 1. 客户端已提供 → 同步为 body 中 metadata.user_id 的 session_id
// 2. 客户端未提供mimic 模式)→ 生成确定性 per-account session UUID
// 真实 CLI 每个请求都携带此 headerper-process UUID
// 同步 X-Claude-Code-Session-Id 头:取 body 中已处理的 metadata.user_id 的 session_id 覆盖
if sessionHeader := getHeaderRaw(req.Header, "X-Claude-Code-Session-Id"); sessionHeader != "" {
if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" {
if parsed := ParseMetadataUserID(uid); parsed != nil {
setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID)
}
}
} else if tokenType == "oauth" {
// mimic 模式:生成 session-id
var sessionID string
if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" {
if parsed := ParseMetadataUserID(uid); parsed != nil {
sessionID = parsed.SessionID
}
}
if sessionID == "" {
salt := ""
if s.cfg != nil {
salt = s.cfg.Gateway.InstanceSalt
}
sessionID = generateSessionIDForAccount(salt, account.ID)
}
setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", sessionID)
}
// x-client-request-id: 真实 CLI 每个请求生成新 UUID仅 1P
if getHeaderRaw(req.Header, "x-client-request-id") == "" && tokenType == "oauth" {
setHeaderRaw(req.Header, "x-client-request-id", uuid.New().String())
}
applyClaudeRuntimeOptionalHeaders(req)
// === DEBUG: 打印上游转发请求headers + body 摘要),与 CLIENT_ORIGINAL 对比 ===
s.debugLogGatewaySnapshot("UPSTREAM_FORWARD", req.Header, body, map[string]string{
@ -8984,35 +8949,19 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
}
}
// X-Claude-Code-Session-Id 头处理count_tokens 路径)
// 同步 X-Claude-Code-Session-Id 头:取 body 中已处理的 metadata.user_id 的 session_id 覆盖
if sessionHeader := getHeaderRaw(req.Header, "X-Claude-Code-Session-Id"); sessionHeader != "" {
if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" {
if parsed := ParseMetadataUserID(uid); parsed != nil {
setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID)
}
}
} else if tokenType == "oauth" {
var sessionID string
if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" {
if parsed := ParseMetadataUserID(uid); parsed != nil {
sessionID = parsed.SessionID
}
}
if sessionID == "" {
salt := ""
if s.cfg != nil {
salt = s.cfg.Gateway.InstanceSalt
}
sessionID = generateSessionIDForAccount(salt, account.ID)
}
setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", sessionID)
}
// x-client-request-idcount_tokens 路径)
if getHeaderRaw(req.Header, "x-client-request-id") == "" && tokenType == "oauth" {
setHeaderRaw(req.Header, "x-client-request-id", uuid.New().String())
}
applyClaudeRuntimeOptionalHeaders(req)
if c != nil && tokenType == "oauth" {
c.Set(claudeMimicDebugInfoKey, buildClaudeMimicDebugLine(req, body, account, tokenType, mimicClaudeCode))

View File

@ -13,7 +13,6 @@ import (
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"

View File

@ -1,28 +0,0 @@
package service
import "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
// ==============================================================
// antigravity — identity_service 扩展
//
// 此文件包含 Antigravity fork 对 IdentityService 的扩展,
// 新增了实例级隔离盐值和指纹默认值覆盖功能。
//
// 对上游文件 identity_service.go 的最小化改动:
// - defaultFingerprint 版本号更新
// - IdentityService struct 新增 instanceSalt 字段
// ==============================================================
// ApplyDefaultFingerprintOverrides 用配置覆盖 identity_service 的默认指纹
// 允许不同部署实例设置不同的 CLI/SDK 版本号,避免所有实例指纹相同
func ApplyDefaultFingerprintOverrides(cliVersion, pkgVersion, runtimeVersion, os_, arch string) {
claude.ApplyFingerprintOverrides(cliVersion, pkgVersion, runtimeVersion, os_, arch)
defaultFingerprint = defaultIdentityFingerprint()
}
// NewIdentityServiceWithSalt 创建带实例盐值的 IdentityService
// 实例盐值用于 user_id 重写时的 session hash 混淆,
// 使不同 sub2api 实例对相同输入产生不同的 hash 输出,增加隔离性
func NewIdentityServiceWithSalt(cache IdentityCache, salt string) *IdentityService {
return &IdentityService{cache: cache, instanceSalt: salt}
}

View File

@ -1,530 +0,0 @@
package service
import (
"bufio"
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"strings"
"sync"
"time"
"github.com/google/uuid"
)
// CascadeSession 代表一个 Cascade Agent 会话
type CascadeSession struct {
ID string
ModelName string
Messages []map[string]interface{} // {role, content}
Metadata map[string]string // 设备指纹、User-Agent 等
Token string // OAuth token
CreatedAt int64
}
// LanguageServerService 业务逻辑层
// 处理 Cascade Agent 流程,通过 AntigravityGatewayService 转发到上游 API
type LanguageServerService struct {
// 会话管理
cascadeSessions map[string]*CascadeSession
sessionMutex sync.RWMutex
// 上游 HTTP 服务(用于发送请求)
httpUpstream HTTPUpstream
// Antigravity 网关(账号池调度 + TLS 指纹 + token 刷新)
antigravitySvc *AntigravityGatewayService
accountRepo AccountRepository
// 日志
logger *slog.Logger
// 改进 1: 速率限制 (令牌桶)
// 限制并发消息处理数量,保护上游 API
rateLimiter chan struct{}
// 改进 3: 会话过期时间 (秒)
sessionTTLSeconds int64
// 改进 3: 定期清理后台任务
cleanupTicker *time.Ticker
stopCleanup chan struct{}
}
func NewLanguageServerService(
logger *slog.Logger,
httpUpstream HTTPUpstream,
antigravitySvc *AntigravityGatewayService,
accountRepo AccountRepository,
) *LanguageServerService {
svc := &LanguageServerService{
cascadeSessions: make(map[string]*CascadeSession),
logger: logger,
httpUpstream: httpUpstream,
antigravitySvc: antigravitySvc,
accountRepo: accountRepo,
rateLimiter: make(chan struct{}, 100), // 改进 1: 限制 100 个并发消息
sessionTTLSeconds: 3600, // 改进 3: 会话默认 1 小时过期
stopCleanup: make(chan struct{}),
}
// 改进 3: 启动后台清理任务
svc.startSessionCleanup()
return svc
}
// startSessionCleanup 启动会话定期清理任务
func (svc *LanguageServerService) startSessionCleanup() {
svc.cleanupTicker = time.NewTicker(1 * time.Minute)
go func() {
for {
select {
case <-svc.cleanupTicker.C:
svc.cleanupExpiredSessions()
case <-svc.stopCleanup:
svc.cleanupTicker.Stop()
return
}
}
}()
}
// cleanupExpiredSessions 清理过期的会话
func (svc *LanguageServerService) cleanupExpiredSessions() {
now := getCurrentTimeMS()
ttlMs := svc.sessionTTLSeconds * 1000
svc.sessionMutex.Lock()
defer svc.sessionMutex.Unlock()
deletedCount := 0
for id, session := range svc.cascadeSessions {
if now-session.CreatedAt > ttlMs {
delete(svc.cascadeSessions, id)
deletedCount++
}
}
if deletedCount > 0 {
svc.logger.Info("expired sessions cleaned up",
"deleted_count", deletedCount,
"remaining_sessions", len(svc.cascadeSessions),
)
}
}
// Stop 优雅关闭服务
func (svc *LanguageServerService) Stop() {
select {
case svc.stopCleanup <- struct{}{}:
default:
}
}
// SetSessionTTL sets the session TTL for testing purposes
func (svc *LanguageServerService) SetSessionTTL(ttlSeconds int64) {
svc.sessionTTLSeconds = ttlSeconds
}
// GetCascadeSessions returns the current cascade sessions map (for testing)
func (svc *LanguageServerService) GetCascadeSessions() map[string]*CascadeSession {
svc.sessionMutex.RLock()
defer svc.sessionMutex.RUnlock()
return svc.cascadeSessions
}
// ============================================================================
// Cascade 业务逻辑
// ============================================================================
// StartCascade 启动新的 Cascade Agent 会话
func (svc *LanguageServerService) StartCascade(
ctx context.Context,
model string,
systemPrompt string,
metadata map[string]string,
token string,
) (string, error) {
// 1. 验证输入
if model == "" {
return "", fmt.Errorf("model is required")
}
if token == "" {
return "", fmt.Errorf("oauth token is required")
}
// 2. 生成会话 ID
sessionID := uuid.New().String()
// 3. 创建会话
session := &CascadeSession{
ID: sessionID,
ModelName: model,
Messages: make([]map[string]interface{}, 0),
Metadata: metadata,
Token: token,
CreatedAt: getCurrentTimeMS(),
}
// 如果提供了系统提示,添加为初始消息
if systemPrompt != "" {
session.Messages = append(session.Messages, map[string]interface{}{
"role": "user",
"content": systemPrompt,
})
}
// 4. 保存会话
svc.sessionMutex.Lock()
svc.cascadeSessions[sessionID] = session
svc.sessionMutex.Unlock()
svc.logger.Info("cascade session started",
"session_id", sessionID,
"model", model,
"has_system_prompt", systemPrompt != "")
return sessionID, nil
}
// SendUserMessage 发送用户消息到 Cascade
// 返回流式更新通道
func (svc *LanguageServerService) SendUserMessage(
ctx context.Context,
cascadeID string,
userMessage string,
token string,
) (<-chan interface{}, error) {
// 改进 1: 获取速率限制令牌
select {
case svc.rateLimiter <- struct{}{}:
// 获得令牌,继续
case <-ctx.Done():
return nil, fmt.Errorf("context cancelled")
default:
// 没有令牌,需要等待
select {
case svc.rateLimiter <- struct{}{}:
// 获得令牌
case <-ctx.Done():
return nil, fmt.Errorf("context cancelled while waiting for rate limit")
case <-time.After(30 * time.Second):
return nil, fmt.Errorf("rate limit timeout: too many concurrent messages")
}
}
// 1. 获取会话
svc.sessionMutex.RLock()
session, exists := svc.cascadeSessions[cascadeID]
svc.sessionMutex.RUnlock()
if !exists {
// 释放令牌
<-svc.rateLimiter
return nil, fmt.Errorf("cascade session not found: %s", cascadeID)
}
// 2. 验证 token
if token != session.Token {
// 释放令牌
<-svc.rateLimiter
return nil, fmt.Errorf("invalid token for session")
}
// 改进 2: 并发安全的消息追加(深拷贝消息列表)
svc.sessionMutex.Lock()
newMessages := make([]map[string]interface{}, len(session.Messages)+1)
copy(newMessages, session.Messages)
newMessages[len(newMessages)-1] = map[string]interface{}{
"role": "user",
"content": userMessage,
}
session.Messages = newMessages
svc.sessionMutex.Unlock()
// 4. 创建响应通道
updateChan := make(chan interface{}, 100)
// 5. 启动后台 goroutine 处理 API 调用
go func() {
defer func() {
// 关闭通道
close(updateChan)
// 改进 1: 释放速率限制令牌
<-svc.rateLimiter
}()
// 调用上游 API关键这里需要伪装
svc.callUpstreamAPI(ctx, session, updateChan)
}()
svc.logger.Info("user message sent to cascade",
"session_id", cascadeID,
"message_length", len(userMessage),
"concurrent_requests", 100-len(svc.rateLimiter), // 显示当前并发数
)
return updateChan, nil
}
// CancelCascade 取消 Cascade 会话
func (svc *LanguageServerService) CancelCascade(
ctx context.Context,
cascadeID string,
) error {
svc.sessionMutex.Lock()
_, exists := svc.cascadeSessions[cascadeID]
svc.sessionMutex.Unlock()
if !exists {
return fmt.Errorf("cascade session not found: %s", cascadeID)
}
// TODO: 取消正在进行的 API 调用
svc.logger.Info("cascade cancelled", "session_id", cascadeID)
return nil
}
// ============================================================================
// 模型配置
// ============================================================================
// ModelConfig 模型配置
type ModelConfig struct {
Name string
DisplayName string
MaxTokens int
SupportsThinking bool
ThinkingBudget int
SupportsImages bool
Provider string
}
// GetAvailableModels 获取可用模型列表
func (svc *LanguageServerService) GetAvailableModels(ctx context.Context) ([]ModelConfig, error) {
models := []ModelConfig{
{
Name: "claude-opus-4-7",
DisplayName: "Claude Opus 4.7",
MaxTokens: 200000,
SupportsThinking: true,
ThinkingBudget: 32000,
SupportsImages: true,
Provider: "anthropic",
},
{
Name: "claude-sonnet-4-7",
DisplayName: "Claude Sonnet 4.7",
MaxTokens: 200000,
SupportsThinking: true,
ThinkingBudget: 16000,
SupportsImages: true,
Provider: "anthropic",
},
{
Name: "claude-opus-4-6",
DisplayName: "Claude Opus 4.6",
MaxTokens: 200000,
SupportsThinking: true,
ThinkingBudget: 32000,
SupportsImages: true,
Provider: "anthropic",
},
{
Name: "claude-sonnet-4-6",
DisplayName: "Claude Sonnet 4.6",
MaxTokens: 200000,
SupportsThinking: false,
SupportsImages: true,
Provider: "anthropic",
},
{
Name: "claude-haiku-4-5",
DisplayName: "Claude Haiku 4.5",
MaxTokens: 200000,
SupportsThinking: false,
SupportsImages: true,
Provider: "anthropic",
},
{
Name: "gemini-3-pro",
DisplayName: "Gemini 3 Pro",
MaxTokens: 128000,
SupportsThinking: false,
SupportsImages: true,
Provider: "google",
},
}
return models, nil
}
// ============================================================================
// 状态查询
// ============================================================================
// GetStatus 获取服务状态
func (svc *LanguageServerService) GetStatus(ctx context.Context) (string, error) {
// TODO: 检查上游 API 连接状态
return "running", nil
}
// ============================================================================
// 内部方法
// ============================================================================
// callUpstreamAPI 通过 AntigravityGatewayService 调用上游 API。
// 复用账号池调度、模型映射、TLS 指纹伪装、token 刷新和重试逻辑。
func (svc *LanguageServerService) callUpstreamAPI(
ctx context.Context,
session *CascadeSession,
updateChan chan<- interface{},
) {
if svc.antigravitySvc == nil || svc.accountRepo == nil {
updateChan <- map[string]interface{}{
"type": "error",
"error": "antigravity gateway not configured",
}
return
}
// 1. 选取第一个可用的 Antigravity 账号
accounts, err := svc.accountRepo.ListByPlatform(ctx, PlatformAntigravity)
if err != nil || len(accounts) == 0 {
svc.logger.Error("no antigravity accounts available", "session_id", session.ID, "error", err)
updateChan <- map[string]interface{}{
"type": "error",
"error": "no antigravity accounts available",
}
return
}
account := &accounts[0]
// 2. 准备 Claude 格式请求体
requestBody := map[string]interface{}{
"model": session.ModelName,
"messages": session.Messages,
"stream": true,
"max_tokens": 8192,
}
bodyJSON, err := json.Marshal(requestBody)
if err != nil {
svc.logger.Error("failed to marshal request", "session_id", session.ID, "error", err)
updateChan <- map[string]interface{}{
"type": "error",
"error": "failed to prepare request",
}
return
}
svc.logger.Debug("forwarding via antigravity", "session_id", session.ID, "model", session.ModelName, "account_id", account.ID)
// 3. 通过 AntigravityGatewayService 转发(完整 TLS 指纹 + token 刷新 + 重试)
respBody, statusCode, err := svc.antigravitySvc.ForwardRaw(ctx, account, bodyJSON)
if err != nil {
svc.logger.Error("upstream request failed", "session_id", session.ID, "error", err)
updateChan <- map[string]interface{}{
"type": "error",
"error": fmt.Sprintf("upstream request failed: %v", err),
}
return
}
defer func() { _ = respBody.Close() }()
// 4. 处理错误响应
if statusCode >= 400 {
body, _ := io.ReadAll(io.LimitReader(respBody, 2<<20))
svc.logger.Error("upstream error response", "session_id", session.ID, "status_code", statusCode, "body", string(body))
updateChan <- map[string]interface{}{
"type": "error",
"status_code": statusCode,
"error": string(body),
}
return
}
// 5. 流式转发响应
svc.streamUpstreamResponse(ctx, session.ID, respBody, updateChan)
}
// streamUpstreamResponse 处理上游 SSE 流式响应
func (svc *LanguageServerService) streamUpstreamResponse(
ctx context.Context,
sessionID string,
body io.ReadCloser,
updateChan chan<- interface{},
) {
scanner := bufio.NewScanner(body)
// 设置合理的缓冲区大小
scanner.Buffer(make([]byte, 64*1024), 512*1024)
for scanner.Scan() {
select {
case <-ctx.Done():
svc.logger.Info("streaming cancelled", "session_id", sessionID)
return
default:
}
line := strings.TrimSpace(scanner.Text())
// 跳过空行
if line == "" {
continue
}
// 跳过注释行
if strings.HasPrefix(line, ":") {
continue
}
// 解析 SSE 格式 (data: {...})
if !strings.HasPrefix(line, "data:") {
continue
}
eventData := strings.TrimPrefix(line, "data:")
eventData = strings.TrimSpace(eventData)
// 解析 JSON
var event map[string]interface{}
if err := json.Unmarshal([]byte(eventData), &event); err != nil {
svc.logger.Debug("failed to parse event",
"session_id", sessionID,
"error", err,
"data", eventData,
)
continue
}
// 发送事件到客户端通道
select {
case updateChan <- event:
case <-ctx.Done():
return
case <-time.After(5 * time.Second):
svc.logger.Warn("channel send timeout",
"session_id", sessionID,
)
return
}
}
if err := scanner.Err(); err != nil {
svc.logger.Error("scanning upstream response failed",
"session_id", sessionID,
"error", err,
)
}
}
// getCurrentTimeMS 获取当前时间戳(毫秒)
func getCurrentTimeMS() int64 {
return time.Now().UnixMilli()
}

View File

@ -1,353 +0,0 @@
package service
import (
"context"
"fmt"
"io/fs"
"log/slog"
"net/http"
"os"
"path/filepath"
"time"
connect "connectrpc.com/connect"
"github.com/Wei-Shaw/sub2api/internal/gen/language_server_pb"
"github.com/Wei-Shaw/sub2api/internal/gen/language_server_pbconnect"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"google.golang.org/protobuf/types/known/timestamppb"
)
const upstreamLSRPCBaseURL = "https://cloudcode-pa.googleapis.com"
// LSRPCHandler implements LanguageServerServiceHandler by proxying to the real upstream
// lsrpc service using OAuth tokens obtained from AntigravityGatewayService.
// File RPCs (ReadFile/WriteFile/ReadDir/etc.) operate on the local filesystem.
type LSRPCHandler struct {
language_server_pbconnect.UnimplementedLanguageServerServiceHandler
antigravitySvc *AntigravityGatewayService
accountRepo AccountRepository
logger *slog.Logger
}
// NewLSRPCHandler creates a new LSRPCHandler.
func NewLSRPCHandler(
antigravitySvc *AntigravityGatewayService,
accountRepo AccountRepository,
logger *slog.Logger,
) *LSRPCHandler {
if logger == nil {
logger = slog.Default()
}
return &LSRPCHandler{
antigravitySvc: antigravitySvc,
accountRepo: accountRepo,
logger: logger,
}
}
// upstreamClient creates a connectrpc client to the real lsrpc upstream,
// authenticated with the OAuth token from the given account.
func (h *LSRPCHandler) upstreamClient(ctx context.Context) (language_server_pbconnect.LanguageServerServiceClient, error) {
accounts, err := h.accountRepo.ListByPlatform(ctx, PlatformAntigravity)
if err != nil || len(accounts) == 0 {
return nil, fmt.Errorf("no antigravity accounts available: %w", err)
}
account := &accounts[0]
tokenProvider := h.antigravitySvc.GetTokenProvider()
if tokenProvider == nil {
return nil, fmt.Errorf("antigravity token provider not configured")
}
accessToken, err := tokenProvider.GetAccessToken(ctx, account)
if err != nil {
return nil, fmt.Errorf("failed to get access token: %w", err)
}
httpClient := &http.Client{
Timeout: 5 * time.Minute,
Transport: &bearerTransport{
base: http.DefaultTransport,
token: accessToken,
},
}
client := language_server_pbconnect.NewLanguageServerServiceClient(
httpClient,
upstreamLSRPCBaseURL,
connect.WithGRPC(),
)
return client, nil
}
// bearerTransport injects Authorization: Bearer <token> into every request.
type bearerTransport struct {
base http.RoundTripper
token string
}
func (t *bearerTransport) RoundTrip(req *http.Request) (*http.Response, error) {
clone := req.Clone(req.Context())
clone.Header.Set("Authorization", "Bearer "+t.token)
return t.base.RoundTrip(clone)
}
// ============================================================================
// Cascade RPCs — proxied to real upstream
// ============================================================================
func (h *LSRPCHandler) StartCascade(
ctx context.Context,
req *connect.Request[language_server_pb.StartCascadeRequest],
) (*connect.Response[language_server_pb.StartCascadeResponse], error) {
client, err := h.upstreamClient(ctx)
if err != nil {
return nil, connect.NewError(connect.CodeUnavailable, err)
}
return client.StartCascade(ctx, req)
}
func (h *LSRPCHandler) SendUserCascadeMessage(
ctx context.Context,
req *connect.Request[language_server_pb.SendUserCascadeMessageRequest],
stream *connect.ServerStream[language_server_pb.CascadeReactiveUpdate],
) error {
client, err := h.upstreamClient(ctx)
if err != nil {
return connect.NewError(connect.CodeUnavailable, err)
}
upstreamStream, err := client.SendUserCascadeMessage(ctx, req)
if err != nil {
return err
}
defer upstreamStream.Close()
for upstreamStream.Receive() {
if err := stream.Send(upstreamStream.Msg()); err != nil {
return err
}
}
return upstreamStream.Err()
}
func (h *LSRPCHandler) CancelCascadeInvocation(
ctx context.Context,
req *connect.Request[language_server_pb.CancelCascadeInvocationRequest],
) (*connect.Response[language_server_pb.CancelCascadeInvocationResponse], error) {
client, err := h.upstreamClient(ctx)
if err != nil {
return nil, connect.NewError(connect.CodeUnavailable, err)
}
return client.CancelCascadeInvocation(ctx, req)
}
func (h *LSRPCHandler) AcknowledgeCascadeCodeEdit(
ctx context.Context,
req *connect.Request[language_server_pb.AcknowledgeCascadeCodeEditRequest],
) (*connect.Response[language_server_pb.AcknowledgeCascadeCodeEditResponse], error) {
client, err := h.upstreamClient(ctx)
if err != nil {
return nil, connect.NewError(connect.CodeUnavailable, err)
}
return client.AcknowledgeCascadeCodeEdit(ctx, req)
}
// ============================================================================
// Model config RPCs — proxied to real upstream
// ============================================================================
func (h *LSRPCHandler) GetCascadeModelConfigs(
ctx context.Context,
req *connect.Request[language_server_pb.GetCascadeModelConfigsRequest],
) (*connect.Response[language_server_pb.GetCascadeModelConfigsResponse], error) {
client, err := h.upstreamClient(ctx)
if err != nil {
// Fall back to static list when upstream unavailable.
return connect.NewResponse(&language_server_pb.GetCascadeModelConfigsResponse{
Models: staticCascadeModels(),
}), nil
}
resp, err := client.GetCascadeModelConfigs(ctx, req)
if err != nil {
return connect.NewResponse(&language_server_pb.GetCascadeModelConfigsResponse{
Models: staticCascadeModels(),
}), nil
}
return resp, nil
}
func (h *LSRPCHandler) GetCommandModelConfigs(
ctx context.Context,
req *connect.Request[language_server_pb.GetCommandModelConfigsRequest],
) (*connect.Response[language_server_pb.GetCommandModelConfigsResponse], error) {
client, err := h.upstreamClient(ctx)
if err != nil {
return connect.NewResponse(&language_server_pb.GetCommandModelConfigsResponse{
Models: staticCascadeModels(),
}), nil
}
resp, err := client.GetCommandModelConfigs(ctx, req)
if err != nil {
return connect.NewResponse(&language_server_pb.GetCommandModelConfigsResponse{
Models: staticCascadeModels(),
}), nil
}
return resp, nil
}
// staticCascadeModels returns a hard-coded model list as fallback.
func staticCascadeModels() []*language_server_pb.ModelConfig {
return []*language_server_pb.ModelConfig{
{Name: "claude-opus-4-7", DisplayName: "Claude Opus 4.7", MaxTokens: 200000, SupportsThinking: true, ThinkingBudget: 32000, SupportsImages: true, Provider: "anthropic"},
{Name: "claude-opus-4-6", DisplayName: "Claude Opus 4.6", MaxTokens: 200000, SupportsThinking: true, ThinkingBudget: 32000, SupportsImages: true, Provider: "anthropic"},
{Name: "claude-sonnet-4-6", DisplayName: "Claude Sonnet 4.6", MaxTokens: 200000, SupportsImages: true, Provider: "anthropic"},
{Name: "claude-haiku-4-5", DisplayName: "Claude Haiku 4.5", MaxTokens: 200000, SupportsImages: true, Provider: "anthropic"},
}
}
// ============================================================================
// File RPCs — local filesystem implementation
// ============================================================================
func (h *LSRPCHandler) ReadFile(
ctx context.Context,
req *connect.Request[language_server_pb.ReadFileRequest],
) (*connect.Response[language_server_pb.ReadFileResponse], error) {
path := req.Msg.GetPath()
data, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("file not found: %s", path))
}
return nil, connect.NewError(connect.CodeInternal, err)
}
return connect.NewResponse(&language_server_pb.ReadFileResponse{
Content: string(data),
}), nil
}
func (h *LSRPCHandler) WriteFile(
ctx context.Context,
req *connect.Request[language_server_pb.WriteFileRequest],
) (*connect.Response[language_server_pb.WriteFileResponse], error) {
path := req.Msg.GetPath()
if req.Msg.GetCreateParent() {
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
return nil, connect.NewError(connect.CodeInternal, err)
}
}
if err := os.WriteFile(path, []byte(req.Msg.GetContent()), 0o644); err != nil {
return nil, connect.NewError(connect.CodeInternal, err)
}
return connect.NewResponse(&language_server_pb.WriteFileResponse{}), nil
}
func (h *LSRPCHandler) ReadDir(
ctx context.Context,
req *connect.Request[language_server_pb.ReadDirRequest],
) (*connect.Response[language_server_pb.ReadDirResponse], error) {
path := req.Msg.GetPath()
entries, err := os.ReadDir(path)
if err != nil {
if os.IsNotExist(err) {
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("directory not found: %s", path))
}
return nil, connect.NewError(connect.CodeInternal, err)
}
files := make([]*language_server_pb.FileInfo, 0, len(entries))
for _, entry := range entries {
info, err := entry.Info()
if err != nil {
continue
}
files = append(files, fileInfoFromOS(entry.Name(), info))
}
return connect.NewResponse(&language_server_pb.ReadDirResponse{
Files: files,
}), nil
}
func (h *LSRPCHandler) DeleteFileOrDirectory(
ctx context.Context,
req *connect.Request[language_server_pb.DeleteFileOrDirectoryRequest],
) (*connect.Response[language_server_pb.DeleteFileOrDirectoryResponse], error) {
path := req.Msg.GetPath()
if err := os.RemoveAll(path); err != nil {
return nil, connect.NewError(connect.CodeInternal, err)
}
return connect.NewResponse(&language_server_pb.DeleteFileOrDirectoryResponse{}), nil
}
func (h *LSRPCHandler) StatUri(
ctx context.Context,
req *connect.Request[language_server_pb.StatUriRequest],
) (*connect.Response[language_server_pb.StatUriResponse], error) {
path := req.Msg.GetPath()
info, err := os.Stat(path)
if err != nil {
if os.IsNotExist(err) {
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("path not found: %s", path))
}
return nil, connect.NewError(connect.CodeInternal, err)
}
return connect.NewResponse(&language_server_pb.StatUriResponse{
FileInfo: fileInfoFromOS(info.Name(), info),
}), nil
}
func (h *LSRPCHandler) WatchDirectory(
ctx context.Context,
req *connect.Request[language_server_pb.WatchDirectoryRequest],
stream *connect.ServerStream[language_server_pb.WatchDirectoryResponse],
) error {
// Block until context is cancelled — real FS watching requires fsnotify which
// is not in the dependency graph yet. This satisfies the interface contract
// without crashing; the client will get an EOF when the connection closes.
<-ctx.Done()
return nil
}
// ============================================================================
// Health RPCs
// ============================================================================
func (h *LSRPCHandler) Heartbeat(
ctx context.Context,
req *connect.Request[language_server_pb.HeartbeatRequest],
) (*connect.Response[language_server_pb.HeartbeatResponse], error) {
return connect.NewResponse(&language_server_pb.HeartbeatResponse{
Healthy: true,
Version: "sub2api",
}), nil
}
func (h *LSRPCHandler) GetStatus(
ctx context.Context,
req *connect.Request[language_server_pb.GetStatusRequest],
) (*connect.Response[language_server_pb.GetStatusResponse], error) {
return connect.NewResponse(&language_server_pb.GetStatusResponse{
Status: "running",
Version: antigravity.BaseURL,
}), nil
}
// ============================================================================
// Helpers
// ============================================================================
func fileInfoFromOS(name string, info fs.FileInfo) *language_server_pb.FileInfo {
t := language_server_pb.FileInfo_FILE
if info.IsDir() {
t = language_server_pb.FileInfo_DIRECTORY
} else if info.Mode()&os.ModeSymlink != 0 {
t = language_server_pb.FileInfo_SYMLINK
}
return &language_server_pb.FileInfo{
Path: name,
Type: t,
Size: info.Size(),
ModifiedTime: timestamppb.New(info.ModTime()),
}
}

View File

@ -1,8 +1,10 @@
package service
import "testing"
import (
"testing"
import "github.com/stretchr/testify/require"
"github.com/stretchr/testify/require"
)
func TestNormalizeOpenAIMessagesDispatchModelConfig(t *testing.T) {
t.Parallel()

View File

@ -8,8 +8,6 @@ import (
"encoding/base64"
"encoding/hex"
"fmt"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"image"
"image/color"
stddraw "image/draw"
@ -24,6 +22,9 @@ import (
"sync"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
xdraw "golang.org/x/image/draw"
"golang.org/x/sync/singleflight"
)

View File

@ -80,10 +80,10 @@ func TestInjectModelIdentity(t *testing.T) {
wantInjected bool
}{
{
name: "anthropic model without system",
messages: []windsurf.ChatMessage{{Role: "user", Content: "hi"}},
meta: &windsurf.ModelMeta{Name: "claude-sonnet-4.6", Provider: "anthropic"},
modelKey: "claude-sonnet-4.6",
name: "anthropic model without system",
messages: []windsurf.ChatMessage{{Role: "user", Content: "hi"}},
meta: &windsurf.ModelMeta{Name: "claude-sonnet-4.6", Provider: "anthropic"},
modelKey: "claude-sonnet-4.6",
wantInjected: true,
},
{
@ -111,10 +111,10 @@ func TestInjectModelIdentity(t *testing.T) {
wantInjected: false,
},
{
name: "openai model without system",
messages: []windsurf.ChatMessage{{Role: "user", Content: "hi"}},
meta: &windsurf.ModelMeta{Name: "gpt-4o", Provider: "openai"},
modelKey: "gpt-4o",
name: "openai model without system",
messages: []windsurf.ChatMessage{{Role: "user", Content: "hi"}},
meta: &windsurf.ModelMeta{Name: "gpt-4o", Provider: "openai"},
modelKey: "gpt-4o",
wantInjected: true,
},
}

View File

@ -609,13 +609,13 @@ type windsurfRequestTool struct {
// ---- Helper functions (prefixed to avoid collision with windsurf_gateway_handler.go) ----
type windsurfContentBlock struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Input interface{} `json:"input,omitempty"`
ToolUseID string `json:"tool_use_id,omitempty"`
Content json.RawMessage `json:"content,omitempty"`
Type string `json:"type"`
Text string `json:"text,omitempty"`
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Input interface{} `json:"input,omitempty"`
ToolUseID string `json:"tool_use_id,omitempty"`
Content json.RawMessage `json:"content,omitempty"`
// Source 来自 Anthropic image block{type:"base64", media_type:"image/png", data:"..."}
Source *windsurfContentImageSource `json:"source,omitempty"`
}

View File

@ -494,7 +494,6 @@ var ProviderSet = wire.NewSet(
NewPaymentService,
ProvidePaymentOrderExpiryService,
ProvideBalanceNotifyService,
ProvideLanguageServerService,
ProvideWindsurfAuthService,
ProvideWindsurfLSService,
ProvideWindsurfChatService,
@ -507,11 +506,6 @@ var ProviderSet = wire.NewSet(
NewChannelMonitorRequestTemplateService,
)
// ProvideLanguageServerService creates LanguageServerService with injected dependencies
func ProvideLanguageServerService(httpUpstream HTTPUpstream, antigravitySvc *AntigravityGatewayService, accountRepo AccountRepository) *LanguageServerService {
return NewLanguageServerService(slog.Default(), httpUpstream, antigravitySvc, accountRepo)
}
// ProvideWindsurfAuthService creates WindsurfAuthService from the main config.
func ProvideWindsurfAuthService(cfg *config.Config, accountRepo AccountRepository, proxyRepo ProxyRepository, adminSvc AdminService) *WindsurfAuthService {
if !cfg.Windsurf.Enabled {