From 898a65314c69195573a5e02bbdaf169613d7db0d Mon Sep 17 00:00:00 2001 From: win Date: Sat, 25 Apr 2026 22:35:48 +0800 Subject: [PATCH] =?UTF-8?q?chore:=20=E5=88=A0=E9=99=A4=20Antigravity=20?= =?UTF-8?q?=E8=AE=A2=E5=88=B6=E4=BB=A3=E7=A0=81=EF=BC=8C=E5=9B=9E=E9=80=80?= =?UTF-8?q?=E8=87=B3=E4=B8=8A=E6=B8=B8=20v0.1.118?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 删除自定义文件:gateway_attribution, gateway_claude_runtime_headers, identity_service_antigravity, language_server_service, lsrpc_handler, antigravity_http handler/routes, 所有 antigravity 专项测试 - 将 antigravity pkg/service 文件回退至上游版本(移除 IsEnterprise、 claude_code_tool_map、dynamic fingerprint 等定制逻辑) - 修复 gateway_service.go:移除 NormalizeSystemPromptEnv、 generateSessionIDForAccount、applyClaudeRuntimeOptionalHeaders 调用, 使用上游的 session-id 同步逻辑 - 恢复 language_server_pb gen 文件(Windsurf local_ls.go 依赖) - 保留全部 Windsurf 集成代码不变 --- backend/cmd/server/wire_gen.go | 17 +- backend/cmd/test_antigravity_privacy/main.go | 114 --- backend/cmd/test_antigravity_warmup/main.go | 316 -------- .../language_server_simplified.pb.go | 7 +- .../language_server_simplified.connect.go | 5 +- .../admin/antigravity_oauth_handler.go | 8 +- backend/internal/handler/antigravity_http.go | 267 ------- backend/internal/handler/user_handler_test.go | 22 +- .../pkg/antigravity/claude_code_tool_map.go | 70 -- .../antigravity/claude_code_tool_map_test.go | 160 ---- .../internal/pkg/antigravity/claude_types.go | 20 +- backend/internal/pkg/antigravity/client.go | 55 +- .../internal/pkg/antigravity/gemini_types.go | 3 +- backend/internal/pkg/antigravity/oauth.go | 84 +-- .../pkg/antigravity/oauth_runtime_env_test.go | 19 - .../pkg/antigravity/request_transformer.go | 682 ++---------------- .../antigravity/request_transformer_test.go | 345 +-------- .../pkg/antigravity/response_transformer.go | 13 +- .../pkg/antigravity/stream_transformer.go | 8 +- .../internal/pkg/windsurf/discovery_test.go | 2 +- backend/internal/server/http.go | 4 +- backend/internal/server/router.go | 17 +- .../server/routes/antigravity_http.go | 192 ----- .../server/routes/antigravity_http_test.go | 365 ---------- .../service/antigravity_account68_e2e_test.go | 254 ------- .../antigravity_direct_upstream_test.go | 91 --- .../service/antigravity_gateway_service.go | 328 ++------- .../antigravity_gateway_service_test.go | 114 --- .../service/antigravity_oauth_service.go | 62 +- .../service/antigravity_smart_retry_test.go | 5 +- .../antigravity_test_full_flow_test.go | 187 ----- .../antigravity_test_http_flow_test.go | 188 ----- .../antigravity_test_singleton_test.go | 213 ------ .../antigravity_test_socks5_proxy_test.go | 194 ----- ...gravity_token_provider_requestpath_test.go | 20 - .../internal/service/antigravity_warmup.go | 83 --- .../service/channel_monitor_template_types.go | 3 +- .../internal/service/gateway_attribution.go | 284 -------- .../service/gateway_attribution_test.go | 81 --- .../service/gateway_claude_runtime_headers.go | 47 -- backend/internal/service/gateway_service.go | 55 +- backend/internal/service/identity_service.go | 1 - .../service/identity_service_antigravity.go | 28 - .../service/language_server_service.go | 530 -------------- backend/internal/service/lsrpc_handler.go | 353 --------- .../service/openai_messages_dispatch_test.go | 6 +- backend/internal/service/user_service.go | 5 +- .../service/windsurf_chat_service_test.go | 16 +- .../service/windsurf_gateway_service.go | 14 +- backend/internal/service/wire.go | 6 - 50 files changed, 235 insertions(+), 5728 deletions(-) delete mode 100644 backend/cmd/test_antigravity_privacy/main.go delete mode 100644 backend/cmd/test_antigravity_warmup/main.go delete mode 100644 backend/internal/handler/antigravity_http.go delete mode 100644 backend/internal/pkg/antigravity/claude_code_tool_map.go delete mode 100644 backend/internal/pkg/antigravity/claude_code_tool_map_test.go delete mode 100644 backend/internal/pkg/antigravity/oauth_runtime_env_test.go delete mode 100644 backend/internal/server/routes/antigravity_http.go delete mode 100644 backend/internal/server/routes/antigravity_http_test.go delete mode 100644 backend/internal/service/antigravity_account68_e2e_test.go delete mode 100644 backend/internal/service/antigravity_direct_upstream_test.go delete mode 100644 backend/internal/service/antigravity_test_full_flow_test.go delete mode 100644 backend/internal/service/antigravity_test_http_flow_test.go delete mode 100644 backend/internal/service/antigravity_test_singleton_test.go delete mode 100644 backend/internal/service/antigravity_test_socks5_proxy_test.go delete mode 100644 backend/internal/service/antigravity_token_provider_requestpath_test.go delete mode 100644 backend/internal/service/antigravity_warmup.go delete mode 100644 backend/internal/service/gateway_attribution.go delete mode 100644 backend/internal/service/gateway_attribution_test.go delete mode 100644 backend/internal/service/gateway_claude_runtime_headers.go delete mode 100644 backend/internal/service/identity_service_antigravity.go delete mode 100644 backend/internal/service/language_server_service.go delete mode 100644 backend/internal/service/lsrpc_handler.go diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 710ebb58..0991af04 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -8,6 +8,11 @@ package main import ( "context" + "log" + "net/http" + "sync" + "time" + "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/handler" @@ -18,14 +23,9 @@ import ( "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/redis/go-redis/v9" - "log" - "net/http" - "sync" - "time" -) -import ( _ "embed" + _ "github.com/Wei-Shaw/sub2api/ent/runtime" ) @@ -257,9 +257,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService) adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService) apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig) - langServerService := service.ProvideLanguageServerService(httpUpstream, antigravityGatewayService, accountRepository) - lsrpcHandler := service.NewLSRPCHandler(antigravityGatewayService, accountRepository, nil) - engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware, apiKeyService, subscriptionService, opsService, settingService, redisClient, langServerService, lsrpcHandler) + engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware, apiKeyService, subscriptionService, opsService, settingService, redisClient) httpServer := server.ProvideHTTPServer(configConfig, engine) opsMetricsCollector := service.ProvideOpsMetricsCollector(opsRepository, settingRepository, accountRepository, concurrencyService, db, redisClient, configConfig) opsAggregationService := service.ProvideOpsAggregationService(opsRepository, settingRepository, db, redisClient, configConfig) @@ -271,7 +269,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository) scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig) paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService) -<<<<<<< HEAD channelMonitorRunner := service.ProvideChannelMonitorRunner(channelMonitorService, settingService) v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService, paymentOrderExpiryService, windsurfRefreshService, channelMonitorRunner, windsurfLSService) application := &Application{ diff --git a/backend/cmd/test_antigravity_privacy/main.go b/backend/cmd/test_antigravity_privacy/main.go deleted file mode 100644 index 58c6c891..00000000 --- a/backend/cmd/test_antigravity_privacy/main.go +++ /dev/null @@ -1,114 +0,0 @@ -package main - -import ( - "context" - "flag" - "fmt" - "log" - "strings" - "time" - - "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" -) - -func repeatStr(s string, count int) string { - return strings.Repeat(s, count) -} - -func main() { - accessToken := flag.String("token", "", "OAuth access token") - projectID := flag.String("project", "", "Project ID") - proxyURL := flag.String("proxy", "", "Proxy URL (optional)") - flag.Parse() - - if *accessToken == "" || *projectID == "" { - log.Fatal("missing required flags: -token and -project") - } - - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - - client, err := antigravity.NewClient(*proxyURL) - if err != nil { - log.Fatalf("failed to create client: %v", err) - } - - fmt.Println(repeatStr("=", 80)) - fmt.Println("Antigravity Privacy Setup Diagnostic Test") - fmt.Println(repeatStr("=", 80)) - - // Step 1: Verify token is valid by fetching user info - fmt.Println("\n[Step 1] Verifying access token...") - userInfo, err := client.GetUserInfo(ctx, *accessToken) - if err != nil { - log.Fatalf("failed to get user info: %v", err) - } - fmt.Printf("✓ Email: %s\n", userInfo.Email) - - // Step 2: Call SetUserSettings - fmt.Println("\n[Step 2] Calling SetUserSettings (clear privacy settings)...") - setResp, err := client.SetUserSettings(ctx, *accessToken) - if err != nil { - log.Fatalf("SetUserSettings failed: %v", err) - } - - if setResp.IsSuccess() { - fmt.Println("✓ SetUserSettings succeeded") - fmt.Printf(" Response: %+v\n", setResp) - } else { - fmt.Println("✗ SetUserSettings returned non-empty userSettings") - fmt.Printf(" Response: %+v\n", setResp) - fmt.Println("\n ERROR: This indicates privacy settings were NOT cleared!") - fmt.Println(" Possible causes:") - fmt.Println(" 1. Account restrictions on privacy settings") - fmt.Println(" 2. Account still has telemetryEnabled=true") - fmt.Println(" 3. API response indicates settings persist") - } - - // Step 3: Verify by calling FetchUserInfo - fmt.Println("\n[Step 3] Calling FetchUserInfo to verify privacy status...") - userInfoResp, err := client.FetchUserInfo(ctx, *accessToken, *projectID) - if err != nil { - log.Fatalf("FetchUserInfo failed: %v", err) - } - - if userInfoResp.IsPrivate() { - fmt.Println("✓ Privacy is properly set (userSettings is empty)") - fmt.Printf(" Response: %+v\n", userInfoResp) - } else { - fmt.Println("✗ Privacy is NOT properly set (userSettings contains telemetryEnabled)") - fmt.Printf(" Response: %+v\n", userInfoResp) - fmt.Println("\n ERROR: This explains the 503 errors in gateway!") - fmt.Println(" Reason: Antigravity API rejects requests from accounts with") - fmt.Println(" telemetryEnabled=true to protect user privacy") - } - - // Summary - fmt.Println("\n" + repeatStr("=", 80)) - fmt.Println("DIAGNOSIS SUMMARY") - fmt.Println(repeatStr("=", 80)) - - if setResp.IsSuccess() && userInfoResp.IsPrivate() { - fmt.Println("✓ Privacy setup is SUCCESSFUL") - fmt.Println(" This account should NOT experience 503 errors due to privacy") - fmt.Println(" The 503 errors might be due to:") - fmt.Println(" 1. Temporary API outages") - fmt.Println(" 2. Rate limiting on new accounts") - fmt.Println(" 3. Other infrastructure issues") - } else if !setResp.IsSuccess() && !userInfoResp.IsPrivate() { - fmt.Println("✗ Privacy setup FAILED") - fmt.Println(" The account cannot clear privacy settings on Antigravity") - fmt.Println(" This causes the 503 Service Unavailable errors") - fmt.Println("\nSOLUTION:") - fmt.Println(" 1. Check if this is a restricted account type") - fmt.Println(" 2. Try re-authorizing the account") - fmt.Println(" 3. Check Antigravity API rate limiting") - fmt.Println(" 4. Inspect firewall/proxy settings") - } else { - fmt.Println("⚠ INCONSISTENT STATE:") - fmt.Println(" SetUserSettings and FetchUserInfo returned different results") - fmt.Println(" This might indicate a transient API issue or data sync delay") - } - - fmt.Println("\n" + repeatStr("=", 80)) -} diff --git a/backend/cmd/test_antigravity_warmup/main.go b/backend/cmd/test_antigravity_warmup/main.go deleted file mode 100644 index b8d2f466..00000000 --- a/backend/cmd/test_antigravity_warmup/main.go +++ /dev/null @@ -1,316 +0,0 @@ -package main - -import ( - "context" - "flag" - "fmt" - "log" - "sync" - "time" - - "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" -) - -// TestScenario 定义一个测试场景 -type TestScenario struct { - name string - description string - testFunc func(ctx context.Context, token, projectID string) (bool, string) -} - -var scenarios []TestScenario - -func init() { - scenarios = []TestScenario{ - { - name: "single_request", - description: "单次请求 - 检查是否立即成功", - testFunc: testSingleRequest, - }, - { - name: "sequential_requests", - description: "顺序发送 10 个请求 - 找到稳定点", - testFunc: testSequentialRequests, - }, - { - name: "concurrent_requests", - description: "并发发送 5 个请求 - 检查并发初始化行为", - testFunc: testConcurrentRequests, - }, - { - name: "warmup_then_request", - description: "预热(模型列表请求) + 业务请求 - 验证预热效果", - testFunc: testWarmupThenRequest, - }, - { - name: "delayed_request", - description: "延迟 5 秒后请求 - 检查账号初始化时间", - testFunc: testDelayedRequest, - }, - } -} - -// testSingleRequest 单次请求 -func testSingleRequest(ctx context.Context, token, projectID string) (bool, string) { - client, err := antigravity.NewClient("") - if err != nil { - return false, fmt.Sprintf("创建客户端失败: %v", err) - } - - start := time.Now() - resp, _, err := client.FetchAvailableModels(ctx, token, projectID) - elapsed := time.Since(start) - - if err != nil { - return false, fmt.Sprintf("请求失败 (%v): %v", elapsed, err) - } - - if resp == nil { - return false, fmt.Sprintf("响应为空 (%v)", elapsed) - } - - return true, fmt.Sprintf("✓ 单次请求成功 - 耗时 %v", elapsed) -} - -// testSequentialRequests 顺序发送多个请求 -func testSequentialRequests(ctx context.Context, token, projectID string) (bool, string) { - client, err := antigravity.NewClient("") - if err != nil { - return false, fmt.Sprintf("创建客户端失败: %v", err) - } - - var firstFailIdx = -1 - var firstSuccessIdx = -1 - var timings []time.Duration - - for i := 0; i < 10; i++ { - start := time.Now() - resp, _, err := client.FetchAvailableModels(ctx, token, projectID) - elapsed := time.Since(start) - timings = append(timings, elapsed) - - success := err == nil && resp != nil - fmt.Printf(" [%d] 耗时: %6v, 状态: %v\n", i+1, elapsed, map[bool]string{true: "✓", false: "✗"}[success]) - - if !success && firstFailIdx == -1 { - firstFailIdx = i - } - if success && firstSuccessIdx == -1 { - firstSuccessIdx = i - } - } - - var report string - if firstSuccessIdx == -1 { - report = "✗ 全部失败" - } else if firstSuccessIdx == 0 { - report = fmt.Sprintf("✓ 首次即成功 (耗时 %v)", timings[0]) - } else { - report = fmt.Sprintf("⚠ 第 %d 次才成功 (失败 %d 次), 首次耗时 %v", - firstSuccessIdx+1, firstSuccessIdx, timings[firstSuccessIdx]) - } - - return firstSuccessIdx >= 0, report -} - -// testConcurrentRequests 并发请求 -func testConcurrentRequests(ctx context.Context, token, projectID string) (bool, string) { - client, err := antigravity.NewClient("") - if err != nil { - return false, fmt.Sprintf("创建客户端失败: %v", err) - } - - var wg sync.WaitGroup - results := make([]bool, 5) - timings := make([]time.Duration, 5) - mu := sync.Mutex{} - - for i := 0; i < 5; i++ { - wg.Add(1) - go func(idx int) { - defer wg.Done() - start := time.Now() - resp, _, err := client.FetchAvailableModels(ctx, token, projectID) - elapsed := time.Since(start) - - mu.Lock() - results[idx] = err == nil && resp != nil - timings[idx] = elapsed - mu.Unlock() - - fmt.Printf(" [%d] 耗时: %6v, 状态: %v\n", idx+1, elapsed, map[bool]string{true: "✓", false: "✗"}[results[idx]]) - }(i) - } - - wg.Wait() - - successCount := 0 - for _, ok := range results { - if ok { - successCount++ - } - } - - return successCount > 0, fmt.Sprintf("%d/5 并发请求成功", successCount) -} - -// testWarmupThenRequest 预热测试 -func testWarmupThenRequest(ctx context.Context, token, projectID string) (bool, string) { - client, err := antigravity.NewClient("") - if err != nil { - return false, fmt.Sprintf("创建客户端失败: %v", err) - } - - // 第 1 步:预热 - 调用 LoadCodeAssist(获取项目信息) - fmt.Println(" [Warmup] 调用 LoadCodeAssist 预热...") - warmupStart := time.Now() - _, _, warmupErr := client.LoadCodeAssist(ctx, token) - warmupElapsed := time.Since(warmupStart) - fmt.Printf(" [Warmup] 耗时 %v, 状态: %v\n", warmupElapsed, map[bool]string{true: "✓", false: "✗"}[warmupErr == nil]) - - // 第 2 步:实际请求 - fmt.Println(" [Request] 发送业务请求...") - reqStart := time.Now() - resp, _, err := client.FetchAvailableModels(ctx, token, projectID) - reqElapsed := time.Since(reqStart) - success := err == nil && resp != nil - fmt.Printf(" [Request] 耗时 %v, 状态: %v\n", reqElapsed, map[bool]string{true: "✓", false: "✗"}[success]) - - return success, fmt.Sprintf("预热 %v + 请求 %v = 总耗时 %v", - warmupElapsed, reqElapsed, warmupElapsed+reqElapsed) -} - -// testDelayedRequest 延迟请求 -func testDelayedRequest(ctx context.Context, token, projectID string) (bool, string) { - client, err := antigravity.NewClient("") - if err != nil { - return false, fmt.Sprintf("创建客户端失败: %v", err) - } - - fmt.Println(" 等待 5 秒...") - time.Sleep(5 * time.Second) - - start := time.Now() - resp, _, err := client.FetchAvailableModels(ctx, token, projectID) - elapsed := time.Since(start) - - success := err == nil && resp != nil - return success, fmt.Sprintf("延迟 5s 后请求 - 耗时 %v, 状态: %v", elapsed, map[bool]string{true: "✓", false: "✗"}[success]) -} - -// testOAuthTokenRefresh OAuth Token 刷新测试 -func testOAuthTokenRefresh(ctx context.Context, refreshToken string) (bool, string) { - client, err := antigravity.NewClient("") - if err != nil { - return false, fmt.Sprintf("创建客户端失败: %v", err) - } - - start := time.Now() - tokenInfo, err := client.RefreshToken(ctx, refreshToken, false) - elapsed := time.Since(start) - - if err != nil { - return false, fmt.Sprintf("Token 刷新失败 (%v): %v", elapsed, err) - } - - return true, fmt.Sprintf("✓ Token 刷新成功 - 耗时 %v, 新 Token 有效期: %d 秒", - elapsed, tokenInfo.ExpiresIn) -} - -// testAccountInitializationWarmup 账号初始化预热 -func testAccountInitializationWarmup(ctx context.Context, token, projectID string) (bool, string) { - client, err := antigravity.NewClient("") - if err != nil { - return false, fmt.Sprintf("创建客户端失败: %v", err) - } - - fmt.Println(" 执行完整的账号初始化流程...") - - // 1. GetUserInfo - fmt.Println(" 1. GetUserInfo...") - start := time.Now() - _, err1 := client.GetUserInfo(ctx, token) - fmt.Printf(" 耗时: %v\n", time.Since(start)) - - // 2. LoadCodeAssist - fmt.Println(" 2. LoadCodeAssist...") - start = time.Now() - _, _, err2 := client.LoadCodeAssist(ctx, token) - fmt.Printf(" 耗时: %v\n", time.Since(start)) - - // 3. FetchAvailableModels - fmt.Println(" 3. FetchAvailableModels...") - start = time.Now() - _, _, err3 := client.FetchAvailableModels(ctx, token, projectID) - elapsed := time.Since(start) - fmt.Printf(" 耗时: %v\n", elapsed) - - success := err1 == nil && err2 == nil && err3 == nil - return success, fmt.Sprintf("账号初始化预热 - 状态: %v", map[bool]string{true: "✓", false: "✗"}[success]) -} - -func main() { - accessToken := flag.String("token", "", "OAuth access token") - projectID := flag.String("project", "", "Project ID") - refreshToken := flag.String("refresh", "", "Refresh token (optional)") - testName := flag.String("test", "all", "测试名称 (all, single_request, sequential_requests, etc.)") - flag.Parse() - - if *accessToken == "" || *projectID == "" { - log.Fatal("缺少必需参数: -token 和 -project") - } - - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) - defer cancel() - - fmt.Println("\n" + repeatStr("=", 80)) - fmt.Println("Antigravity 账号初始化诊断测试套件") - fmt.Println(repeatStr("=", 80) + "\n") - - // Token 刷新测试 - if *refreshToken != "" { - fmt.Println("[Token 刷新测试]") - _, report := testOAuthTokenRefresh(ctx, *refreshToken) - fmt.Printf("%s\n\n", report) - } - - // 账号初始化预热测试 - fmt.Println("[账号初始化预热]") - _, report := testAccountInitializationWarmup(ctx, *accessToken, *projectID) - fmt.Printf("%s\n\n", report) - - // 运行指定的测试 - if *testName == "all" { - for _, scenario := range scenarios { - fmt.Printf("[%s]\n%s\n", scenario.name, scenario.description) - _, report := scenario.testFunc(ctx, *accessToken, *projectID) - fmt.Printf("结果: %s\n\n", report) - } - } else { - found := false - for _, scenario := range scenarios { - if scenario.name == *testName { - found = true - fmt.Printf("[%s]\n%s\n", scenario.name, scenario.description) - _, report := scenario.testFunc(ctx, *accessToken, *projectID) - fmt.Printf("结果: %s\n\n", report) - break - } - } - if !found { - log.Fatalf("未找到测试: %s", *testName) - } - } - - fmt.Println(repeatStr("=", 80)) - fmt.Println("诊断完成") - fmt.Println(repeatStr("=", 80)) -} - -func repeatStr(s string, count int) string { - result := "" - for i := 0; i < count; i++ { - result += s - } - return result -} diff --git a/backend/internal/gen/language_server_pb/language_server_simplified.pb.go b/backend/internal/gen/language_server_pb/language_server_simplified.pb.go index a8b8ffc4..83183f5b 100644 --- a/backend/internal/gen/language_server_pb/language_server_simplified.pb.go +++ b/backend/internal/gen/language_server_pb/language_server_simplified.pb.go @@ -7,13 +7,14 @@ package language_server_pb import ( + reflect "reflect" + sync "sync" + unsafe "unsafe" + protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" _ "google.golang.org/protobuf/types/known/emptypb" timestamppb "google.golang.org/protobuf/types/known/timestamppb" - reflect "reflect" - sync "sync" - unsafe "unsafe" ) const ( diff --git a/backend/internal/gen/language_server_pbconnect/language_server_simplified.connect.go b/backend/internal/gen/language_server_pbconnect/language_server_simplified.connect.go index 10999497..a8e07231 100644 --- a/backend/internal/gen/language_server_pbconnect/language_server_simplified.connect.go +++ b/backend/internal/gen/language_server_pbconnect/language_server_simplified.connect.go @@ -5,12 +5,13 @@ package language_server_pbconnect import ( - connect "connectrpc.com/connect" context "context" errors "errors" - language_server_pb "github.com/Wei-Shaw/sub2api/internal/gen/language_server_pb" http "net/http" strings "strings" + + connect "connectrpc.com/connect" + language_server_pb "github.com/Wei-Shaw/sub2api/internal/gen/language_server_pb" ) // This is a compile-time assertion to ensure that this generated file and the connect package are diff --git a/backend/internal/handler/admin/antigravity_oauth_handler.go b/backend/internal/handler/admin/antigravity_oauth_handler.go index 8fcb148b..7488965d 100644 --- a/backend/internal/handler/admin/antigravity_oauth_handler.go +++ b/backend/internal/handler/admin/antigravity_oauth_handler.go @@ -15,8 +15,7 @@ func NewAntigravityOAuthHandler(antigravityOAuthService *service.AntigravityOAut } type AntigravityGenerateAuthURLRequest struct { - ProxyID *int64 `json:"proxy_id"` - IsEnterprise bool `json:"is_enterprise"` + ProxyID *int64 `json:"proxy_id"` } // GenerateAuthURL generates Google OAuth authorization URL @@ -28,7 +27,7 @@ func (h *AntigravityOAuthHandler) GenerateAuthURL(c *gin.Context) { return } - result, err := h.antigravityOAuthService.GenerateAuthURL(c.Request.Context(), req.ProxyID, req.IsEnterprise) + result, err := h.antigravityOAuthService.GenerateAuthURL(c.Request.Context(), req.ProxyID) if err != nil { response.InternalError(c, "生成授权链接失败: "+err.Error()) return @@ -71,7 +70,6 @@ func (h *AntigravityOAuthHandler) ExchangeCode(c *gin.Context) { type AntigravityRefreshTokenRequest struct { RefreshToken string `json:"refresh_token" binding:"required"` ProxyID *int64 `json:"proxy_id"` - IsEnterprise bool `json:"is_enterprise"` } // RefreshToken validates an Antigravity refresh token and returns full token info @@ -83,7 +81,7 @@ func (h *AntigravityOAuthHandler) RefreshToken(c *gin.Context) { return } - tokenInfo, err := h.antigravityOAuthService.ValidateRefreshToken(c.Request.Context(), req.RefreshToken, req.ProxyID, req.IsEnterprise) + tokenInfo, err := h.antigravityOAuthService.ValidateRefreshToken(c.Request.Context(), req.RefreshToken, req.ProxyID) if err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/handler/antigravity_http.go b/backend/internal/handler/antigravity_http.go deleted file mode 100644 index 5b186712..00000000 --- a/backend/internal/handler/antigravity_http.go +++ /dev/null @@ -1,267 +0,0 @@ -package handler - -import ( - "encoding/json" - "log/slog" - "net/http" - - "github.com/Wei-Shaw/sub2api/internal/service" - "github.com/gin-gonic/gin" -) - -// AntigravityHTTPHandler 处理下游客户端的 HTTP 请求 -// 内部调用 LanguageServerService,再转发到上游 API -type AntigravityHTTPHandler struct { - langServerService *service.LanguageServerService - logger *slog.Logger -} - -func NewAntigravityHTTPHandler( - langServerService *service.LanguageServerService, - logger *slog.Logger, -) *AntigravityHTTPHandler { - return &AntigravityHTTPHandler{ - langServerService: langServerService, - logger: logger, - } -} - -// ============================================================================ -// Cascade 流程 API -// ============================================================================ - -// StartCascadeRequest HTTP 请求格式 -type StartCascadeRequest struct { - Model string `json:"model"` // 模型名称 - SystemPrompt string `json:"system_prompt"` // 系统提示 - Metadata map[string]string `json:"metadata"` // 设备指纹等伪装信息 -} - -// StartCascadeResponse HTTP 响应格式 -type StartCascadeResponse struct { - CascadeID string `json:"cascade_id"` -} - -// POST /api/v1/cascade/start -// 启动新的 Cascade Agent 会话 -func (h *AntigravityHTTPHandler) StartCascade(c *gin.Context) { - var req StartCascadeRequest - if err := c.ShouldBindJSON(&req); err != nil { - h.logger.Error("invalid request", "error", err) - c.JSON(http.StatusBadRequest, gin.H{ - "error": "invalid request: " + err.Error(), - }) - return - } - - // 提取 OAuth token - token := c.GetHeader("Authorization") - if token == "" { - c.JSON(http.StatusUnauthorized, gin.H{ - "error": "missing authorization header", - }) - return - } - - // 调用内部 LanguageServerService - cascadeID, err := h.langServerService.StartCascade( - c.Request.Context(), - req.Model, - req.SystemPrompt, - req.Metadata, - token, - ) - if err != nil { - h.logger.Error("start cascade failed", "error", err) - c.JSON(http.StatusInternalServerError, gin.H{ - "error": err.Error(), - }) - return - } - - h.logger.Info("cascade started", "cascade_id", cascadeID, "model", req.Model) - - c.JSON(http.StatusOK, StartCascadeResponse{ - CascadeID: cascadeID, - }) -} - -// ============================================================================ - -// SendUserMessageRequest HTTP 请求格式 -type SendUserMessageRequest struct { - CascadeID string `json:"cascade_id"` // 会话 ID - Message string `json:"message"` // 用户消息 - Context map[string]string `json:"context"` // 上下文(可选) -} - -// CascadeUpdate 流式响应格式(Server-Sent Events) -type CascadeUpdate struct { - Type string `json:"type"` // "message_delta", "tool_call", etc. - Payload string `json:"payload"` // JSON 格式的负载 -} - -// POST /api/v1/cascade/message (流式) -// 发送用户消息,接收流式更新 -func (h *AntigravityHTTPHandler) SendUserMessage(c *gin.Context) { - var req SendUserMessageRequest - if err := c.ShouldBindJSON(&req); err != nil { - h.logger.Error("invalid request", "error", err) - c.JSON(http.StatusBadRequest, gin.H{ - "error": "invalid request: " + err.Error(), - }) - return - } - - // 提取 OAuth token - token := c.GetHeader("Authorization") - if token == "" { - c.JSON(http.StatusUnauthorized, gin.H{ - "error": "missing authorization header", - }) - return - } - - // 设置 Server-Sent Events 响应头 - c.Header("Content-Type", "text/event-stream") - c.Header("Cache-Control", "no-cache") - c.Header("Connection", "keep-alive") - - // 调用内部 LanguageServerService,获取流式响应 - updateChan, err := h.langServerService.SendUserMessage( - c.Request.Context(), - req.CascadeID, - req.Message, - token, - ) - if err != nil { - h.logger.Error("send user message failed", "error", err) - c.JSON(http.StatusInternalServerError, gin.H{ - "error": err.Error(), - }) - return - } - - // 逐个推送更新到客户端(SSE) - for update := range updateChan { - data, _ := json.Marshal(update) - c.SSEvent("update", string(data)) - c.Writer.Flush() - } - - h.logger.Info("cascade message processed", "cascade_id", req.CascadeID) -} - -// ============================================================================ - -// POST /api/v1/cascade/cancel -// 取消 Cascade 调用 -func (h *AntigravityHTTPHandler) CancelCascade(c *gin.Context) { - var req struct { - CascadeID string `json:"cascade_id"` - } - - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": "invalid request", - }) - return - } - - if err := h.langServerService.CancelCascade( - c.Request.Context(), - req.CascadeID, - ); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "error": err.Error(), - }) - return - } - - h.logger.Info("cascade cancelled", "cascade_id", req.CascadeID) - c.JSON(http.StatusOK, gin.H{ - "success": true, - }) -} - -// ============================================================================ -// 模型配置 API -// ============================================================================ - -// ModelConfig 模型配置 -type ModelConfig struct { - Name string `json:"name"` - DisplayName string `json:"display_name"` - MaxTokens int `json:"max_tokens"` - SupportsThinking bool `json:"supports_thinking"` - ThinkingBudget int `json:"thinking_budget,omitempty"` - SupportsImages bool `json:"supports_images"` - Provider string `json:"provider"` // anthropic, google, openai -} - -// GET /api/v1/models -// 获取可用模型列表 -func (h *AntigravityHTTPHandler) GetModels(c *gin.Context) { - models, err := h.langServerService.GetAvailableModels(c.Request.Context()) - if err != nil { - h.logger.Error("get models failed", "error", err) - c.JSON(http.StatusInternalServerError, gin.H{ - "error": err.Error(), - }) - return - } - - c.JSON(http.StatusOK, gin.H{ - "models": models, - "default_model": "claude-opus-4-7", - }) -} - -// ============================================================================ -// 健康检查 API -// ============================================================================ - -// GET /api/v1/health -// 健康检查 -func (h *AntigravityHTTPHandler) Health(c *gin.Context) { - status, err := h.langServerService.GetStatus(c.Request.Context()) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "status": "unhealthy", - "error": err.Error(), - }) - return - } - - c.JSON(http.StatusOK, gin.H{ - "status": status, - "version": "1.0.0", - }) -} - -// ============================================================================ - -// RegisterRoutes 注册所有 HTTP 路由 -func (h *AntigravityHTTPHandler) RegisterRoutes(router *gin.Engine) { - api := router.Group("/api/v1") - - // Cascade 流程 - api.POST("/cascade/start", h.StartCascade) - api.POST("/cascade/message", h.SendUserMessage) - api.POST("/cascade/cancel", h.CancelCascade) - - // 模型列表 - api.GET("/models", h.GetModels) - - // 健康检查 - api.GET("/health", h.Health) - - h.logger.Info("antigravity http handler registered", - "routes", []string{ - "/api/v1/cascade/start", - "/api/v1/cascade/message", - "/api/v1/cascade/cancel", - "/api/v1/models", - "/api/v1/health", - }) -} diff --git a/backend/internal/handler/user_handler_test.go b/backend/internal/handler/user_handler_test.go index 8a864b51..192ca1f6 100644 --- a/backend/internal/handler/user_handler_test.go +++ b/backend/internal/handler/user_handler_test.go @@ -270,19 +270,19 @@ func TestUserHandlerGetProfileReturnsLegacyCompatibilityFields(t *testing.T) { AvatarURL: "https://cdn.example.com/linuxdo.png", AvatarSource: "remote_url", }, - identities: []service.UserAuthIdentityRecord{ - { - ProviderType: "linuxdo", - ProviderKey: "linuxdo", - ProviderSubject: "linuxdo-subject-21", - VerifiedAt: &verifiedAt, - Metadata: map[string]any{ - "username": "linuxdo-handle", - "avatar_url": "https://cdn.example.com/linuxdo.png", - }, + identities: []service.UserAuthIdentityRecord{ + { + ProviderType: "linuxdo", + ProviderKey: "linuxdo", + ProviderSubject: "linuxdo-subject-21", + VerifiedAt: &verifiedAt, + Metadata: map[string]any{ + "username": "linuxdo-handle", + "avatar_url": "https://cdn.example.com/linuxdo.png", }, }, - } + }, + } handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil) recorder := httptest.NewRecorder() diff --git a/backend/internal/pkg/antigravity/claude_code_tool_map.go b/backend/internal/pkg/antigravity/claude_code_tool_map.go deleted file mode 100644 index 155ffeba..00000000 --- a/backend/internal/pkg/antigravity/claude_code_tool_map.go +++ /dev/null @@ -1,70 +0,0 @@ -package antigravity - -import "strings" - -var claudeCodeBuiltinToolNameMap = map[string]string{ - "read": "Read", - "read_file": "Read", - "readfile": "Read", - "write": "Write", - "write_file": "Write", - "writefile": "Write", - "edit": "Edit", - "apply_patch": "Edit", - "applypatch": "Edit", - "bash": "Bash", - "execute_bash": "Bash", - "executebash": "Bash", - "exec_bash": "Bash", - "execbash": "Bash", - "glob": "Glob", - "list_files": "Glob", - "listfiles": "Glob", - "grep": "Grep", - "search_files": "Grep", - "searchfiles": "Grep", - "webfetch": "WebFetch", - "web_fetch": "WebFetch", - "fetch": "WebFetch", - "websearch": "WebSearch", - "web_search": "WebSearch", - "agent": "Agent", - "askuserquestion": "AskUserQuestion", - "ask_user_question": "AskUserQuestion", - "enterplanmode": "EnterPlanMode", - "enter_plan_mode": "EnterPlanMode", - "exitplanmode": "ExitPlanMode", - "exit_plan_mode": "ExitPlanMode", - "croncreate": "CronCreate", - "cron_create": "CronCreate", - "crondelete": "CronDelete", - "cron_delete": "CronDelete", - "schedulewakeup": "ScheduleWakeup", - "schedule_wakeup": "ScheduleWakeup", - "sendmessage": "SendMessage", - "send_message": "SendMessage", - "skill": "Skill", - "taskcreate": "TaskCreate", - "task_create": "TaskCreate", - "tasklist": "TaskList", - "task_list": "TaskList", - "taskoutput": "TaskOutput", - "task_output": "TaskOutput", - "taskstop": "TaskStop", - "task_stop": "TaskStop", - "taskupdate": "TaskUpdate", - "task_update": "TaskUpdate", -} - -func normalizeClaudeCodeToolName(name string) string { - trimmed := strings.TrimSpace(name) - if trimmed == "" { - return "" - } - - if mapped, ok := claudeCodeBuiltinToolNameMap[strings.ToLower(trimmed)]; ok { - return mapped - } - - return trimmed -} diff --git a/backend/internal/pkg/antigravity/claude_code_tool_map_test.go b/backend/internal/pkg/antigravity/claude_code_tool_map_test.go deleted file mode 100644 index 9e9beae4..00000000 --- a/backend/internal/pkg/antigravity/claude_code_tool_map_test.go +++ /dev/null @@ -1,160 +0,0 @@ -package antigravity - -import ( - "encoding/json" - "strings" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestNormalizeClaudeCodeToolName(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - input string - expected string - }{ - {name: "read alias", input: "read_file", expected: "Read"}, - {name: "grep alias", input: "search_files", expected: "Grep"}, - {name: "webfetch alias", input: "fetch", expected: "WebFetch"}, - {name: "plan alias", input: "enter_plan_mode", expected: "EnterPlanMode"}, - {name: "native passthrough", input: "TaskUpdate", expected: "TaskUpdate"}, - {name: "mcp passthrough", input: "mcp__github__list_prs", expected: "mcp__github__list_prs"}, - } - - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - require.Equal(t, tt.expected, normalizeClaudeCodeToolName(tt.input)) - }) - } -} - -func TestBuildPartsNormalizesClaudeCodeToolNames(t *testing.T) { - t.Parallel() - - toolIDToName := make(map[string]string) - assistantParts, stripped, err := buildParts(json.RawMessage(`[ - {"type":"tool_use","id":"tool-1","name":"read_file","input":{"file_path":"/tmp/demo.txt"}} - ]`), toolIDToName, false) - require.NoError(t, err) - require.False(t, stripped) - require.Len(t, assistantParts, 1) - require.NotNil(t, assistantParts[0].FunctionCall) - require.Equal(t, "Read", assistantParts[0].FunctionCall.Name) - require.Equal(t, "Read", toolIDToName["tool-1"]) - - userParts, stripped, err := buildParts(json.RawMessage(`[ - {"type":"tool_result","tool_use_id":"tool-1","content":[{"type":"text","text":"ok"}]} - ]`), toolIDToName, false) - require.NoError(t, err) - require.False(t, stripped) - require.Len(t, userParts, 1) - require.NotNil(t, userParts[0].FunctionResponse) - require.Equal(t, "Read", userParts[0].FunctionResponse.Name) -} - -func TestBuildToolsNormalizesClaudeCodeBuiltinNamesOnly(t *testing.T) { - t.Parallel() - - result := buildTools([]ClaudeTool{ - { - Name: "search_files", - Description: "Search the workspace", - InputSchema: map[string]any{ - "type": "object", - }, - }, - { - Name: "mcp__github__list_prs", - Description: "List pull requests", - InputSchema: map[string]any{ - "type": "object", - }, - }, - }) - - require.Len(t, result, 1) - require.Len(t, result[0].FunctionDeclarations, 2) - require.Equal(t, "Grep", result[0].FunctionDeclarations[0].Name) - require.Equal(t, "mcp__github__list_prs", result[0].FunctionDeclarations[1].Name) -} - -func TestNonStreamingProcessorNormalizesClaudeCodeToolName(t *testing.T) { - t.Parallel() - - processor := NewNonStreamingProcessor() - response := processor.Process(&GeminiResponse{ - Candidates: []GeminiCandidate{ - { - Content: &GeminiContent{ - Parts: []GeminiPart{ - { - FunctionCall: &GeminiFunctionCall{ - Name: "web_fetch", - Args: map[string]any{"url": "https://example.com"}, - }, - }, - }, - }, - FinishReason: "STOP", - }, - }, - }, "resp-1", "claude-sonnet-4-5") - - require.Len(t, response.Content, 1) - require.Equal(t, "tool_use", response.Content[0].Type) - require.Equal(t, "WebFetch", response.Content[0].Name) - require.True(t, strings.HasPrefix(response.Content[0].ID, "WebFetch-")) - require.NotNil(t, response.Content[0].Caller) - require.Equal(t, "direct", response.Content[0].Caller.Type) - require.Equal(t, "tool_use", response.StopReason) -} - -func TestStreamingProcessorNormalizesClaudeCodeToolName(t *testing.T) { - t.Parallel() - - processor := NewStreamingProcessor("claude-sonnet-4-5") - output := processor.processFunctionCall(&GeminiFunctionCall{ - Name: "search_files", - Args: map[string]any{"pattern": "TODO"}, - }, "") - - events := parseSSEDataEvents(t, output) - require.Len(t, events, 3) - - contentBlock, ok := events[0]["content_block"].(map[string]any) - require.True(t, ok) - require.Equal(t, "tool_use", contentBlock["type"]) - require.Equal(t, "Grep", contentBlock["name"]) - - toolID, ok := contentBlock["id"].(string) - require.True(t, ok) - require.True(t, strings.HasPrefix(toolID, "Grep-")) - - caller, ok := contentBlock["caller"].(map[string]any) - require.True(t, ok) - require.Equal(t, "direct", caller["type"]) -} - -func parseSSEDataEvents(t *testing.T, payload []byte) []map[string]any { - t.Helper() - - lines := strings.Split(string(payload), "\n") - events := make([]map[string]any, 0) - - for _, line := range lines { - if !strings.HasPrefix(line, "data: ") { - continue - } - - var event map[string]any - require.NoError(t, json.Unmarshal([]byte(strings.TrimPrefix(line, "data: ")), &event)) - events = append(events, event) - } - - return events -} diff --git a/backend/internal/pkg/antigravity/claude_types.go b/backend/internal/pkg/antigravity/claude_types.go index 89b8fab9..0b8ae5f2 100644 --- a/backend/internal/pkg/antigravity/claude_types.go +++ b/backend/internal/pkg/antigravity/claude_types.go @@ -16,7 +16,6 @@ type ClaudeRequest struct { TopK *int `json:"top_k,omitempty"` Tools []ClaudeTool `json:"tools,omitempty"` Thinking *ThinkingConfig `json:"thinking,omitempty"` - ToolChoice json.RawMessage `json:"tool_choice,omitempty"` Metadata *ClaudeMetadata `json:"metadata,omitempty"` } @@ -73,10 +72,9 @@ type ContentBlock struct { Thinking string `json:"thinking,omitempty"` Signature string `json:"signature,omitempty"` // tool_use - ID string `json:"id,omitempty"` - Name string `json:"name,omitempty"` - Input any `json:"input,omitempty"` - Caller *ToolCaller `json:"caller,omitempty"` + ID string `json:"id,omitempty"` + Name string `json:"name,omitempty"` + Input any `json:"input,omitempty"` // tool_result ToolUseID string `json:"tool_use_id,omitempty"` Content json.RawMessage `json:"content,omitempty"` @@ -116,15 +114,9 @@ type ClaudeContentItem struct { Signature string `json:"signature,omitempty"` // tool_use - ID string `json:"id,omitempty"` - Name string `json:"name,omitempty"` - Input any `json:"input,omitempty"` - Caller *ToolCaller `json:"caller,omitempty"` -} - -// ToolCaller Claude Code tool_use 调用来源 -type ToolCaller struct { - Type string `json:"type"` + ID string `json:"id,omitempty"` + Name string `json:"name,omitempty"` + Input any `json:"input,omitempty"` } // ClaudeUsage Claude 用量统计 diff --git a/backend/internal/pkg/antigravity/client.go b/backend/internal/pkg/antigravity/client.go index 882a0cda..fdd7fea1 100644 --- a/backend/internal/pkg/antigravity/client.go +++ b/backend/internal/pkg/antigravity/client.go @@ -318,17 +318,16 @@ func shouldFallbackToNextURL(err error, statusCode int) bool { statusCode >= 500 } -// ExchangeCode 用 authorization code 交换 token。 -// isEnterprise=true 时使用企业 OAuth client_id/secret。 -func (c *Client) ExchangeCode(ctx context.Context, code, codeVerifier string, isEnterprise bool) (*TokenResponse, error) { - creds, err := GetClientCredentials(isEnterprise) +// ExchangeCode 用 authorization code 交换 token +func (c *Client) ExchangeCode(ctx context.Context, code, codeVerifier string) (*TokenResponse, error) { + clientSecret, err := getClientSecret() if err != nil { return nil, err } params := url.Values{} - params.Set("client_id", creds.ClientID) - params.Set("client_secret", creds.ClientSecret) + params.Set("client_id", ClientID) + params.Set("client_secret", clientSecret) params.Set("code", code) params.Set("redirect_uri", RedirectURI) params.Set("grant_type", "authorization_code") @@ -363,17 +362,16 @@ func (c *Client) ExchangeCode(ctx context.Context, code, codeVerifier string, is return &tokenResp, nil } -// RefreshToken 刷新 access_token。 -// isEnterprise=true 时使用企业 OAuth client_id/secret。 -func (c *Client) RefreshToken(ctx context.Context, refreshToken string, isEnterprise bool) (*TokenResponse, error) { - creds, err := GetClientCredentials(isEnterprise) +// RefreshToken 刷新 access_token +func (c *Client) RefreshToken(ctx context.Context, refreshToken string) (*TokenResponse, error) { + clientSecret, err := getClientSecret() if err != nil { return nil, err } params := url.Values{} - params.Set("client_id", creds.ClientID) - params.Set("client_secret", creds.ClientSecret) + params.Set("client_id", ClientID) + params.Set("client_secret", clientSecret) params.Set("refresh_token", refreshToken) params.Set("grant_type", "refresh_token") @@ -406,39 +404,6 @@ func (c *Client) RefreshToken(ctx context.Context, refreshToken string, isEnterp return &tokenResp, nil } -// RefreshTokenAuto 自动判定账号类型。 -// 先用个人凭证刷新;若 Google 返回 invalid_client/unauthorized_client(client 不匹配), -// 再用企业凭证重试。返回 token 和最终判定的 isEnterprise 标志。 -// -// 其他错误(invalid_grant、网络错误等)直接返回,不重试。 -func (c *Client) RefreshTokenAuto(ctx context.Context, refreshToken string) (*TokenResponse, bool, error) { - tok, err := c.RefreshToken(ctx, refreshToken, false) - if err == nil { - return tok, false, nil - } - if !isClientMismatchError(err) { - return nil, false, err - } - tok, err2 := c.RefreshToken(ctx, refreshToken, true) - if err2 == nil { - return tok, true, nil - } - // 企业也失败:返回合并后的诊断错误 - return nil, false, fmt.Errorf("auto-detect refresh failed: personal=%v enterprise=%v", err, err2) -} - -// isClientMismatchError 判断是否为 OAuth client 不匹配导致的错误。 -// 只有这种错误才会触发"切换账号类型重试"。 -func isClientMismatchError(err error) bool { - if err == nil { - return false - } - msg := err.Error() - return strings.Contains(msg, "invalid_client") || - strings.Contains(msg, "unauthorized_client") || - strings.Contains(msg, "client_id") -} - // GetUserInfo 获取用户信息 func (c *Client) GetUserInfo(ctx context.Context, accessToken string) (*UserInfo, error) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, UserInfoURL, nil) diff --git a/backend/internal/pkg/antigravity/gemini_types.go b/backend/internal/pkg/antigravity/gemini_types.go index 3ed149b9..033dccbd 100644 --- a/backend/internal/pkg/antigravity/gemini_types.go +++ b/backend/internal/pkg/antigravity/gemini_types.go @@ -117,8 +117,7 @@ type GeminiToolConfig struct { // GeminiFunctionCallingConfig 函数调用配置 type GeminiFunctionCallingConfig struct { - Mode string `json:"mode,omitempty"` // VALIDATED, AUTO, NONE, ANY - AllowedFunctionNames []string `json:"allowedFunctionNames,omitempty"` + Mode string `json:"mode,omitempty"` // VALIDATED, AUTO, NONE } // GeminiSafetySetting Gemini 安全设置 diff --git a/backend/internal/pkg/antigravity/oauth.go b/backend/internal/pkg/antigravity/oauth.go index 360b7a4e..7c963d9e 100644 --- a/backend/internal/pkg/antigravity/oauth.go +++ b/backend/internal/pkg/antigravity/oauth.go @@ -9,7 +9,6 @@ import ( "net/http" "net/url" "os" - "runtime" "strings" "sync" "time" @@ -23,22 +22,16 @@ const ( TokenURL = "https://oauth2.googleapis.com/token" UserInfoURL = "https://www.googleapis.com/oauth2/v2/userinfo" - // 个人账号 OAuth 凭证(isGcpTos=false,免费 Gemini Code Assist) + // Antigravity OAuth 客户端凭证 ClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" - // AntigravityOAuthClientSecretEnv 是个人账号 OAuth client_secret 的环境变量名。 + // AntigravityOAuthClientSecretEnv 是 Antigravity OAuth client_secret 的环境变量名。 AntigravityOAuthClientSecretEnv = "ANTIGRAVITY_OAUTH_CLIENT_SECRET" - // 企业账号 OAuth 凭证(isGcpTos=true,Google Cloud / Workspace 用户) - EnterpriseClientID = "884354919052-36trc1jjb3tguiac32ov6cod268c5blh.apps.googleusercontent.com" - - // AntigravityEnterpriseOAuthClientSecretEnv 是企业账号 OAuth client_secret 的环境变量名。 - AntigravityEnterpriseOAuthClientSecretEnv = "ANTIGRAVITY_ENTERPRISE_OAUTH_CLIENT_SECRET" - // 固定的 redirect_uri(用户需手动复制 code) RedirectURI = "http://localhost:8085/callback" - // OAuth scopes(企业和个人共用) + // OAuth scopes Scopes = "https://www.googleapis.com/auth/cloud-platform " + "https://www.googleapis.com/auth/userinfo.email " + "https://www.googleapis.com/auth/userinfo.profile " + @@ -53,18 +46,15 @@ const ( // Antigravity API 端点 antigravityProdBaseURL = "https://cloudcode-pa.googleapis.com" - antigravityDailyBaseURL = "https://daily-cloudcode-pa.googleapis.com" + antigravityDailyBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com" ) -// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.20.6(product.json ideVersion) -var defaultUserAgentVersion = "1.20.6" +// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.20.5 +var defaultUserAgentVersion = "1.21.9" -// defaultClientSecret 个人账号 client_secret,可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 覆盖 +// defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置 var defaultClientSecret = "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" -// defaultEnterpriseClientSecret 企业账号 client_secret,可通过环境变量 ANTIGRAVITY_ENTERPRISE_OAUTH_CLIENT_SECRET 覆盖 -var defaultEnterpriseClientSecret = "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" - func init() { // 从环境变量读取版本号,未设置则使用默认值 if version := os.Getenv("ANTIGRAVITY_USER_AGENT_VERSION"); version != "" { @@ -74,58 +64,14 @@ func init() { if secret := os.Getenv(AntigravityOAuthClientSecretEnv); secret != "" { defaultClientSecret = secret } - if secret := os.Getenv(AntigravityEnterpriseOAuthClientSecretEnv); secret != "" { - defaultEnterpriseClientSecret = secret - } } -// GetUserAgent 返回当前配置的 User-Agent(自动检测平台,匹配真实 IDE 行为) +// GetUserAgent 返回当前配置的 User-Agent func GetUserAgent() string { - return fmt.Sprintf("antigravity/%s %s/%s", defaultUserAgentVersion, runtime.GOOS, runtime.GOARCH) -} - -// ClientCredentials 持有一对 OAuth client_id/secret -type ClientCredentials struct { - ClientID string - ClientSecret string -} - -// GetClientCredentials 根据账号类型返回对应的 OAuth 凭证。 -// isEnterprise=true 时使用企业凭证(isGcpTos=true),否则使用个人凭证。 -func GetClientCredentials(isEnterprise bool) (ClientCredentials, error) { - if isEnterprise { - secret := strings.TrimSpace(os.Getenv(AntigravityEnterpriseOAuthClientSecretEnv)) - if secret == "" { - secret = strings.TrimSpace(defaultEnterpriseClientSecret) - } - if secret == "" { - return ClientCredentials{}, infraerrors.Newf(http.StatusBadRequest, - "ANTIGRAVITY_ENTERPRISE_OAUTH_CLIENT_SECRET_MISSING", - "missing enterprise oauth client_secret; set %s", AntigravityEnterpriseOAuthClientSecretEnv) - } - return ClientCredentials{ClientID: EnterpriseClientID, ClientSecret: secret}, nil - } - secret, err := getClientSecret() - if err != nil { - return ClientCredentials{}, err - } - return ClientCredentials{ClientID: ClientID, ClientSecret: secret}, nil -} - -// BaseURLsForAccount 根据 isGcpTos 返回有序 URL 列表。 -// 企业账号(isGcpTos=true)优先走 prod;个人账号优先走 daily(与真实 IDE 一致)。 -func BaseURLsForAccount(isGcpTos bool) []string { - if isGcpTos { - return []string{antigravityProdBaseURL, antigravityDailyBaseURL} - } - return []string{antigravityDailyBaseURL, antigravityProdBaseURL} + return fmt.Sprintf("antigravity/%s windows/amd64", defaultUserAgentVersion) } func getClientSecret() (string, error) { - if secret := strings.TrimSpace(os.Getenv(AntigravityOAuthClientSecretEnv)); secret != "" { - defaultClientSecret = secret - return secret, nil - } if v := strings.TrimSpace(defaultClientSecret); v != "" { return v, nil } @@ -265,7 +211,6 @@ type OAuthSession struct { State string `json:"state"` CodeVerifier string `json:"code_verifier"` ProxyURL string `json:"proxy_url,omitempty"` - IsEnterprise bool `json:"is_enterprise,omitempty"` CreatedAt time.Time `json:"created_at"` } @@ -380,15 +325,10 @@ func base64URLEncode(data []byte) string { return strings.TrimRight(base64.URLEncoding.EncodeToString(data), "=") } -// BuildAuthorizationURL 构建 Google OAuth 授权 URL。 -// isEnterprise=true 时使用企业 client_id;否则使用个人 client_id。 -func BuildAuthorizationURL(state, codeChallenge string, isEnterprise bool) string { - clientID := ClientID - if isEnterprise { - clientID = EnterpriseClientID - } +// BuildAuthorizationURL 构建 Google OAuth 授权 URL +func BuildAuthorizationURL(state, codeChallenge string) string { params := url.Values{} - params.Set("client_id", clientID) + params.Set("client_id", ClientID) params.Set("redirect_uri", RedirectURI) params.Set("response_type", "code") params.Set("scope", Scopes) diff --git a/backend/internal/pkg/antigravity/oauth_runtime_env_test.go b/backend/internal/pkg/antigravity/oauth_runtime_env_test.go deleted file mode 100644 index b84727d7..00000000 --- a/backend/internal/pkg/antigravity/oauth_runtime_env_test.go +++ /dev/null @@ -1,19 +0,0 @@ -package antigravity - -import "testing" - -func TestGetClientSecret_ReadsRuntimeEnvironment(t *testing.T) { - old := defaultClientSecret - defaultClientSecret = "" - t.Cleanup(func() { defaultClientSecret = old }) - - t.Setenv(AntigravityOAuthClientSecretEnv, "runtime-secret") - - secret, err := getClientSecret() - if err != nil { - t.Fatalf("getClientSecret returned error: %v", err) - } - if secret != "runtime-secret" { - t.Fatalf("unexpected secret: got %q want %q", secret, "runtime-secret") - } -} diff --git a/backend/internal/pkg/antigravity/request_transformer.go b/backend/internal/pkg/antigravity/request_transformer.go index eadc6bba..b5de8166 100644 --- a/backend/internal/pkg/antigravity/request_transformer.go +++ b/backend/internal/pkg/antigravity/request_transformer.go @@ -1,14 +1,12 @@ package antigravity import ( - "bytes" "crypto/sha256" "encoding/binary" "encoding/json" "fmt" "log" "math/rand" - "regexp" "strconv" "strings" "sync" @@ -18,16 +16,10 @@ import ( ) var ( - sessionRand = rand.New(rand.NewSource(time.Now().UnixNano())) - sessionRandMutex sync.Mutex - legacyMetadataUserIDSessionPattern = regexp.MustCompile(`^user_[a-fA-F0-9]{64}_account_[a-fA-F0-9-]*_session_([a-fA-F0-9-]{36})$`) - plainSessionIDPattern = regexp.MustCompile(`^(session_)?[a-fA-F0-9-]{36}$`) + sessionRand = rand.New(rand.NewSource(time.Now().UnixNano())) + sessionRandMutex sync.Mutex ) -type claudeMetadataUserIDPayload struct { - SessionID string `json:"session_id"` -} - // generateStableSessionID 基于用户消息内容生成稳定的 session ID func generateStableSessionID(contents []GeminiContent) string { // 查找第一个 user 消息的文本 @@ -47,82 +39,12 @@ func generateStableSessionID(contents []GeminiContent) string { return "-" + strconv.FormatInt(n, 10) } -// EnsureGeminiRequestSessionID fills request.sessionId when the caller omitted it. -// preferredSessionID wins; otherwise we derive a stable value from the first user turn. -func EnsureGeminiRequestSessionID(body []byte, preferredSessionID string) ([]byte, error) { - var payload map[string]any - if err := json.Unmarshal(body, &payload); err != nil { - return nil, err - } - - if raw, ok := payload["sessionId"].(string); ok && strings.TrimSpace(raw) != "" { - return body, nil - } - - sessionID := strings.TrimSpace(preferredSessionID) - if sessionID == "" { - var req GeminiRequest - if err := json.Unmarshal(body, &req); err != nil { - return nil, err - } - sessionID = generateStableSessionID(req.Contents) - } - if sessionID == "" { - return body, nil - } - - payload["sessionId"] = sessionID - return json.Marshal(payload) -} - -func extractSessionIDFromMetadataUserID(raw string) string { - raw = strings.TrimSpace(raw) - if raw == "" { - return "" - } - - if strings.HasPrefix(raw, "{") { - var payload claudeMetadataUserIDPayload - if err := json.Unmarshal([]byte(raw), &payload); err == nil { - return strings.TrimSpace(payload.SessionID) - } - return "" - } - - if matches := legacyMetadataUserIDSessionPattern.FindStringSubmatch(raw); len(matches) == 2 { - return strings.TrimSpace(matches[1]) - } - - if plainSessionIDPattern.MatchString(raw) { - return raw - } - - return "" -} - -func resolveClaudeRequestSessionID(metadata *ClaudeMetadata, preferredSessionID string, contents []GeminiContent) string { - if metadata != nil { - if sessionID := extractSessionIDFromMetadataUserID(metadata.UserID); sessionID != "" { - return sessionID - } - } - - if sessionID := strings.TrimSpace(preferredSessionID); sessionID != "" { - return sessionID - } - - return generateStableSessionID(contents) -} - type TransformOptions struct { EnableIdentityPatch bool // IdentityPatch 可选:自定义注入到 systemInstruction 开头的身份防护提示词; // 为空时使用默认模板(包含 [IDENTITY_PATCH] 及 SYSTEM_PROMPT_BEGIN 标记)。 IdentityPatch string EnableMCPXML bool - // PreferredSessionID 可选:当 metadata.user_id 不可用于恢复真实会话时, - // 允许调用方显式指定 Antigravity 上游 request.sessionId。 - PreferredSessionID string } func DefaultTransformOptions() TransformOptions { @@ -163,24 +85,12 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st // TransformClaudeToGeminiWithOptions 将 Claude 请求转换为 v1internal Gemini 格式(可配置身份补丁等行为) func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, mappedModel string, opts TransformOptions) ([]byte, error) { - normalizedReq, err := normalizeClaudeRequestForAntigravity(claudeReq) - if err != nil { - return nil, fmt.Errorf("normalize messages: %w", err) - } - // 用于存储 tool_use id -> name 映射 toolIDToName := make(map[string]string) // 检测是否有 web_search 工具 - hasWebSearchTool := hasWebSearchTool(normalizedReq.Tools) - // requestType 映射策略: - // - Gemini 模型: "agent"(与 Antigravity 官方客户端一致) - // - Claude 模型: 不设置(避免 Google 后端路由到容量受限的 agent 池,降低 503 率) - // - web_search: "web_search"(触发 Google 搜索增强路由) + hasWebSearchTool := hasWebSearchTool(claudeReq.Tools) requestType := "agent" - if strings.HasPrefix(mappedModel, "claude-") { - requestType = "" // Claude 模型走默认容量池,避免 agent 池 503 - } targetModel := mappedModel if hasWebSearchTool { requestType = "web_search" @@ -190,27 +100,27 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map } // 检测是否启用 thinking - isThinkingEnabled := normalizedReq.Thinking != nil && (normalizedReq.Thinking.Type == "enabled" || normalizedReq.Thinking.Type == "adaptive") + isThinkingEnabled := claudeReq.Thinking != nil && (claudeReq.Thinking.Type == "enabled" || claudeReq.Thinking.Type == "adaptive") // 只有 Gemini 模型支持 dummy thought workaround // Claude 模型通过 Vertex/Google API 需要有效的 thought signatures allowDummyThought := strings.HasPrefix(targetModel, "gemini-") // 1. 构建 contents - contents, strippedThinking, err := buildContents(normalizedReq.Messages, toolIDToName, isThinkingEnabled, allowDummyThought) + contents, strippedThinking, err := buildContents(claudeReq.Messages, toolIDToName, isThinkingEnabled, allowDummyThought) if err != nil { return nil, fmt.Errorf("build contents: %w", err) } // 2. 构建 systemInstruction(使用 targetModel 而非原始请求模型,确保身份注入基于最终模型) - systemInstruction := buildSystemInstruction(normalizedReq.System, targetModel, opts, normalizedReq.Tools) + systemInstruction := buildSystemInstruction(claudeReq.System, targetModel, opts, claudeReq.Tools) // 3. 构建 generationConfig - reqForConfig := normalizedReq + reqForConfig := claudeReq if strippedThinking { // If we had to downgrade thinking blocks to plain text due to missing/invalid signatures, // disable upstream thinking mode to avoid signature/structure validation errors. - reqCopy := *normalizedReq + reqCopy := *claudeReq reqCopy.Thinking = nil reqForConfig = &reqCopy } @@ -222,24 +132,19 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map generationConfig := buildGenerationConfig(reqForConfig) // 4. 构建 tools - // 对 Claude / Gemini 模型都保留 functionDeclarations: - // - Claude 分支如果完全丢掉 tools,模型只能看到消息历史中的 tool_use/tool_result, - // 但拿不到当前可用工具定义,容易导致“能还原名字但不会继续发工具调用”。 - // - Gemini 分支原本就依赖 functionDeclarations 触发 function_call。 - isClaudeModel := strings.HasPrefix(targetModel, "claude-") - tools := buildTools(normalizedReq.Tools) + tools := buildTools(claudeReq.Tools) // 5. 构建内部请求 innerRequest := GeminiRequest{ - Contents: contents, - SessionID: resolveClaudeRequestSessionID(normalizedReq.Metadata, opts.PreferredSessionID, contents), - } - - // Gemini 分支保持默认 VALIDATED; - // Claude 分支仅在声明了工具时附带 toolConfig,避免再把工具能力静默丢失。 - defaultValidated := !isClaudeModel || len(tools) > 0 - if toolConfig := buildToolConfig(normalizedReq.ToolChoice, defaultValidated); toolConfig != nil { - innerRequest.ToolConfig = toolConfig + Contents: contents, + // 总是设置 toolConfig,与官方客户端一致 + ToolConfig: &GeminiToolConfig{ + FunctionCallingConfig: &GeminiFunctionCallingConfig{ + Mode: "VALIDATED", + }, + }, + // 总是生成 sessionId,基于用户消息内容 + SessionID: generateStableSessionID(contents), } if systemInstruction != nil { @@ -252,6 +157,11 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map innerRequest.Tools = tools } + // 如果提供了 metadata.user_id,优先使用 + if claudeReq.Metadata != nil && claudeReq.Metadata.UserID != "" { + innerRequest.SessionID = claudeReq.Metadata.UserID + } + // 6. 包装为 v1internal 请求 v1Req := V1InternalRequest{ Project: projectID, @@ -265,319 +175,6 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map return json.Marshal(v1Req) } -const ( - maxAntigravityToolDescriptionChars = 400 - maxAntigravitySchemaDescriptionChars = 200 - maxAntigravityToolResultChars = 200000 -) - -func normalizeClaudeRequestForAntigravity(claudeReq *ClaudeRequest) (*ClaudeRequest, error) { - if claudeReq == nil { - return nil, nil - } - - reqCopy := *claudeReq - if len(claudeReq.Messages) == 0 { - return &reqCopy, nil - } - - normalizedMessages, err := normalizeClaudeMessagesForAntigravity(claudeReq.Messages) - if err != nil { - return nil, err - } - reqCopy.Messages = normalizedMessages - return &reqCopy, nil -} - -func normalizeClaudeMessagesForAntigravity(messages []ClaudeMessage) ([]ClaudeMessage, error) { - normalized := make([]ClaudeMessage, 0, len(messages)+1) - pendingToolUseIDs := make([]string, 0) - - for _, message := range messages { - blocks, hasBlocks := parseClaudeMessageBlocks(message.Content) - - switch message.Role { - case "assistant": - if len(pendingToolUseIDs) > 0 { - synthetic, err := buildSyntheticToolResultMessage(pendingToolUseIDs) - if err != nil { - return nil, err - } - normalized = append(normalized, synthetic) - pendingToolUseIDs = pendingToolUseIDs[:0] - } - - if !hasBlocks { - normalized = append(normalized, cloneClaudeMessage(message)) - continue - } - - stripped := stripNonToolPartsAfterToolUse(reorderAssistantThinkingBlocks(blocks)) - pendingToolUseIDs = append(pendingToolUseIDs, collectToolUseIDs(stripped)...) - - nextMessage, err := buildClaudeMessageWithBlocks(message.Role, stripped) - if err != nil { - return nil, err - } - normalized = append(normalized, nextMessage) - - case "user": - if !hasBlocks { - if len(pendingToolUseIDs) > 0 { - synthetic, err := buildSyntheticToolResultMessage(pendingToolUseIDs) - if err != nil { - return nil, err - } - normalized = append(normalized, synthetic) - pendingToolUseIDs = pendingToolUseIDs[:0] - } - normalized = append(normalized, cloneClaudeMessage(message)) - continue - } - - parts := cloneJSONBlocks(blocks) - if len(pendingToolUseIDs) > 0 { - toolResults, nonToolResults := partitionToolResultBlocks(parts) - existingIDs := collectToolResultIDs(toolResults) - missingIDs := diffStringSlice(pendingToolUseIDs, existingIDs) - if len(missingIDs) > 0 { - parts = append(append(toolResults, buildSyntheticToolResultBlocks(missingIDs)...), nonToolResults...) - } - pendingToolUseIDs = pendingToolUseIDs[:0] - } - - toolResults, nonToolResults := partitionToolResultBlocks(parts) - switch { - case len(toolResults) == 0: - nextMessage, err := buildClaudeMessageWithBlocks(message.Role, parts) - if err != nil { - return nil, err - } - normalized = append(normalized, nextMessage) - case len(nonToolResults) == 0: - nextMessage, err := buildClaudeMessageWithBlocks(message.Role, toolResults) - if err != nil { - return nil, err - } - normalized = append(normalized, nextMessage) - default: - toolResultMessage, err := buildClaudeMessageWithBlocks(message.Role, toolResults) - if err != nil { - return nil, err - } - userTextMessage, err := buildClaudeMessageWithBlocks(message.Role, nonToolResults) - if err != nil { - return nil, err - } - normalized = append(normalized, toolResultMessage, userTextMessage) - } - - default: - normalized = append(normalized, cloneClaudeMessage(message)) - } - } - - if len(pendingToolUseIDs) > 0 { - synthetic, err := buildSyntheticToolResultMessage(pendingToolUseIDs) - if err != nil { - return nil, err - } - normalized = append(normalized, synthetic) - } - - return normalized, nil -} - -func parseClaudeMessageBlocks(content json.RawMessage) ([]map[string]any, bool) { - var blocks []map[string]any - if err := json.Unmarshal(content, &blocks); err != nil { - return nil, false - } - return blocks, true -} - -func cloneClaudeMessage(message ClaudeMessage) ClaudeMessage { - cloned := ClaudeMessage{Role: message.Role} - if len(message.Content) > 0 { - cloned.Content = append(json.RawMessage(nil), message.Content...) - } - return cloned -} - -func cloneJSONBlocks(blocks []map[string]any) []map[string]any { - cloned := make([]map[string]any, 0, len(blocks)) - for _, block := range blocks { - cloned = append(cloned, cloneJSONMap(block)) - } - return cloned -} - -func cloneJSONMap(block map[string]any) map[string]any { - if block == nil { - return nil - } - if cloned, ok := deepCopy(block).(map[string]any); ok { - return cloned - } - fallback := make(map[string]any, len(block)) - for key, value := range block { - fallback[key] = value - } - return fallback -} - -func buildClaudeMessageWithBlocks(role string, blocks []map[string]any) (ClaudeMessage, error) { - payload, err := json.Marshal(blocks) - if err != nil { - return ClaudeMessage{}, fmt.Errorf("marshal %s message blocks: %w", role, err) - } - return ClaudeMessage{Role: role, Content: payload}, nil -} - -func buildSyntheticToolResultMessage(toolUseIDs []string) (ClaudeMessage, error) { - return buildClaudeMessageWithBlocks("user", buildSyntheticToolResultBlocks(toolUseIDs)) -} - -func buildSyntheticToolResultBlocks(toolUseIDs []string) []map[string]any { - blocks := make([]map[string]any, 0, len(toolUseIDs)) - for _, toolUseID := range toolUseIDs { - if strings.TrimSpace(toolUseID) == "" { - continue - } - blocks = append(blocks, map[string]any{ - "type": "tool_result", - "tool_use_id": toolUseID, - "is_error": true, - "content": []map[string]any{ - { - "type": "text", - "text": "[tool_result missing; tool execution interrupted]", - }, - }, - }) - } - return blocks -} - -func reorderAssistantThinkingBlocks(blocks []map[string]any) []map[string]any { - thinkingBlocks := make([]map[string]any, 0) - otherBlocks := make([]map[string]any, 0, len(blocks)) - - for _, block := range blocks { - cloned := cloneJSONMap(block) - blockType, _ := cloned["type"].(string) - if blockType == "thinking" || blockType == "redacted_thinking" { - delete(cloned, "cache_control") - thinkingBlocks = append(thinkingBlocks, cloned) - continue - } - otherBlocks = append(otherBlocks, cloned) - } - - if len(thinkingBlocks) == 0 { - return otherBlocks - } - return append(thinkingBlocks, otherBlocks...) -} - -func stripNonToolPartsAfterToolUse(blocks []map[string]any) []map[string]any { - cleaned := make([]map[string]any, 0, len(blocks)) - seenToolUse := false - - for _, block := range blocks { - blockType, _ := block["type"].(string) - if blockType == "tool_use" { - seenToolUse = true - cleaned = append(cleaned, block) - continue - } - if !seenToolUse { - cleaned = append(cleaned, block) - continue - } - if isIgnorableTrailingTextBlock(block) { - continue - } - } - - return cleaned -} - -func isIgnorableTrailingTextBlock(block map[string]any) bool { - blockType, _ := block["type"].(string) - if blockType != "text" { - return false - } - text, _ := block["text"].(string) - trimmed := strings.TrimSpace(text) - return trimmed == "" || trimmed == "(no content)" -} - -func collectToolUseIDs(blocks []map[string]any) []string { - ids := make([]string, 0) - for _, block := range blocks { - blockType, _ := block["type"].(string) - if blockType != "tool_use" { - continue - } - id, _ := block["id"].(string) - if strings.TrimSpace(id) != "" { - ids = append(ids, id) - } - } - return ids -} - -func collectToolResultIDs(blocks []map[string]any) []string { - ids := make([]string, 0, len(blocks)) - for _, block := range blocks { - id, _ := block["tool_use_id"].(string) - if strings.TrimSpace(id) != "" { - ids = append(ids, id) - } - } - return ids -} - -func diffStringSlice(left, right []string) []string { - if len(left) == 0 { - return nil - } - seen := make(map[string]struct{}, len(right)) - for _, value := range right { - if strings.TrimSpace(value) != "" { - seen[value] = struct{}{} - } - } - - diff := make([]string, 0, len(left)) - for _, value := range left { - value = strings.TrimSpace(value) - if value == "" { - continue - } - if _, ok := seen[value]; ok { - continue - } - diff = append(diff, value) - } - return diff -} - -func partitionToolResultBlocks(blocks []map[string]any) (toolResults []map[string]any, nonToolResults []map[string]any) { - toolResults = make([]map[string]any, 0) - nonToolResults = make([]map[string]any, 0) - for _, block := range blocks { - blockType, _ := block["type"].(string) - if blockType == "tool_result" { - toolResults = append(toolResults, block) - continue - } - nonToolResults = append(nonToolResults, block) - } - return toolResults, nonToolResults -} - // antigravityIdentity Antigravity identity 提示词 const antigravityIdentity = ` You are Antigravity, a powerful agentic AI coding assistant designed by the Google Deepmind team working on Advanced Agentic Coding. @@ -877,14 +474,13 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu case "tool_use": // 存储 id -> name 映射 - toolName := normalizeClaudeCodeToolName(block.Name) - if block.ID != "" && toolName != "" { - toolIDToName[block.ID] = toolName + if block.ID != "" && block.Name != "" { + toolIDToName[block.ID] = block.Name } part := GeminiPart{ FunctionCall: &GeminiFunctionCall{ - Name: toolName, + Name: block.Name, Args: block.Input, ID: block.ID, }, @@ -901,12 +497,12 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu case "tool_result": // 获取函数名 - funcName := normalizeClaudeCodeToolName(block.Name) + funcName := block.Name if funcName == "" { if name, ok := toolIDToName[block.ToolUseID]; ok { funcName = name } else { - funcName = normalizeClaudeCodeToolName(block.ToolUseID) + funcName = block.ToolUseID } } @@ -929,84 +525,47 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu } // parseToolResultContent 解析 tool_result 的 content -func parseToolResultContent(content json.RawMessage, isError bool) any { +func parseToolResultContent(content json.RawMessage, isError bool) string { if len(content) == 0 { - return defaultToolResultContent(isError) + if isError { + return "Tool execution failed with no output." + } + return "Command executed successfully." } // 尝试解析为字符串 var str string if err := json.Unmarshal(content, &str); err == nil { if strings.TrimSpace(str) == "" { - return defaultToolResultContent(isError) + if isError { + return "Tool execution failed with no output." + } + return "Command executed successfully." } - return truncateInlineText(str, maxAntigravityToolResultChars) + return str } - // 优先保留结构化 tool_result,避免上游把内容视为无效的纯文本降级。 + // 尝试解析为数组 var arr []map[string]any if err := json.Unmarshal(content, &arr); err == nil { - sanitized := sanitizeToolResultBlocksForAntigravity(arr) - if len(sanitized) == 0 { - return defaultToolResultContent(isError) + var texts []string + for _, item := range arr { + if text, ok := item["text"].(string); ok { + texts = append(texts, text) + } } - return sanitized - } - - var obj map[string]any - if err := json.Unmarshal(content, &obj); err == nil { - sanitized := sanitizeToolResultObjectForAntigravity(obj) - if len(sanitized) == 0 { - return defaultToolResultContent(isError) + result := strings.Join(texts, "\n") + if strings.TrimSpace(result) == "" { + if isError { + return "Tool execution failed with no output." + } + return "Command executed successfully." } - return sanitized + return result } // 返回原始 JSON - return truncateInlineText(string(content), maxAntigravityToolResultChars) -} - -func defaultToolResultContent(isError bool) string { - if isError { - return "Tool execution failed with no output." - } - return "Command executed successfully." -} - -func sanitizeToolResultBlocksForAntigravity(blocks []map[string]any) []map[string]any { - sanitized := make([]map[string]any, 0, len(blocks)) - for _, block := range blocks { - if isBase64ImageToolResultBlock(block) { - continue - } - cloned := cloneJSONMap(block) - if text, ok := cloned["text"].(string); ok { - cloned["text"] = truncateInlineText(text, maxAntigravityToolResultChars) - } - sanitized = append(sanitized, cloned) - } - return sanitized -} - -func sanitizeToolResultObjectForAntigravity(block map[string]any) map[string]any { - if isBase64ImageToolResultBlock(block) { - return nil - } - cloned := cloneJSONMap(block) - if text, ok := cloned["text"].(string); ok { - cloned["text"] = truncateInlineText(text, maxAntigravityToolResultChars) - } - return cloned -} - -func isBase64ImageToolResultBlock(block map[string]any) bool { - blockType, _ := block["type"].(string) - if blockType != "image" { - return false - } - source, _ := block["source"].(map[string]any) - sourceType, _ := source["type"].(string) - return sourceType == "base64" + return string(content) } // buildGenerationConfig 构建 generationConfig @@ -1074,15 +633,6 @@ func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig { } } config.ThinkingConfig.ThinkingBudget = budget - } else if strings.HasSuffix(req.Model, "-thinking") || strings.HasPrefix(req.Model, "claude-sonnet-4-6") { - // 自动注入 thinkingConfig 的两种情形(客户端未显式开启 thinking): - // 1. 模型名以 -thinking 结尾(如 claude-opus-4-6-thinking):Google 要求此后缀模型必须携带 thinkingConfig。 - // 2. claude-sonnet-4-6:无 -thinking 变体(404),但模型本身要求携带 thinkingConfig;budget 必须为 -1(动态)。 - // 注:固定 budget(如 1024)在 max_tokens 较小时会触发 400(max_tokens 必须大于 budget)。 - config.ThinkingConfig = &GeminiThinkingConfig{ - IncludeThoughts: true, - ThinkingBudget: -1, // 动态预算,避免 max_tokens vs budget 冲突 - } } if config.MaxOutputTokens > maxLimit { @@ -1126,65 +676,6 @@ func isWebSearchTool(tool ClaudeTool) bool { } } -func buildToolConfig(toolChoice json.RawMessage, defaultValidated bool) *GeminiToolConfig { - raw := bytes.TrimSpace(toolChoice) - if len(raw) == 0 { - if !defaultValidated { - return nil - } - return &GeminiToolConfig{ - FunctionCallingConfig: &GeminiFunctionCallingConfig{ - Mode: "VALIDATED", - }, - } - } - - choiceType := "" - toolName := "" - - if len(raw) > 0 && raw[0] == '"' { - var choice string - if err := json.Unmarshal(raw, &choice); err == nil { - choiceType = strings.TrimSpace(choice) - } - } else { - var choice map[string]any - if err := json.Unmarshal(raw, &choice); err == nil { - if value, ok := choice["type"].(string); ok { - choiceType = strings.TrimSpace(value) - } - if value, ok := choice["name"].(string); ok { - toolName = normalizeClaudeCodeToolName(value) - } - } - } - - mode := "" - switch strings.ToLower(choiceType) { - case "auto": - mode = "AUTO" - case "none": - mode = "NONE" - case "any", "required": - mode = "ANY" - case "tool": - mode = "ANY" - case "validated": - mode = "VALIDATED" - default: - if !defaultValidated { - return nil - } - mode = "VALIDATED" - } - - cfg := &GeminiFunctionCallingConfig{Mode: mode} - if toolName != "" && mode == "ANY" { - cfg.AllowedFunctionNames = []string{toolName} - } - return &GeminiToolConfig{FunctionCallingConfig: cfg} -} - // buildTools 构建 tools func buildTools(tools []ClaudeTool) []GeminiToolDeclaration { if len(tools) == 0 { @@ -1215,12 +706,12 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration { continue } description = tool.Custom.Description - inputSchema = cloneStringAnyMap(tool.Custom.InputSchema) + inputSchema = tool.Custom.InputSchema } else { // 标准格式: 从顶层字段获取 description = tool.Description - inputSchema = cloneStringAnyMap(tool.InputSchema) + inputSchema = tool.InputSchema } // 清理 JSON Schema @@ -1235,11 +726,9 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration { "properties": map[string]any{}, } } - description = compactToolDescriptionForAntigravity(description) - params = compactSchemaDescriptionsForAntigravity(params) funcDecls = append(funcDecls, GeminiFunctionDecl{ - Name: normalizeClaudeCodeToolName(tool.Name), + Name: tool.Name, Description: description, Parameters: params, }) @@ -1268,64 +757,3 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration { return declarations } - -func cloneStringAnyMap(input map[string]any) map[string]any { - if input == nil { - return nil - } - if cloned, ok := deepCopy(input).(map[string]any); ok { - return cloned - } - fallback := make(map[string]any, len(input)) - for key, value := range input { - fallback[key] = value - } - return fallback -} - -func compactToolDescriptionForAntigravity(description string) string { - if strings.TrimSpace(description) == "" { - return "" - } - lines := strings.Split(strings.ReplaceAll(description, "\r\n", "\n"), "\n") - compacted := make([]string, 0, len(lines)) - for _, line := range lines { - line = strings.TrimSpace(line) - if line == "" { - continue - } - compacted = append(compacted, line) - if len(compacted) == 6 { - break - } - } - return truncateInlineText(strings.Join(compacted, " "), maxAntigravityToolDescriptionChars) -} - -func compactSchemaDescriptionsForAntigravity(schema map[string]any) map[string]any { - for key, value := range schema { - switch typed := value.(type) { - case string: - if key == "description" { - schema[key] = truncateInlineText(strings.Join(strings.Fields(typed), " "), maxAntigravitySchemaDescriptionChars) - } - case map[string]any: - schema[key] = compactSchemaDescriptionsForAntigravity(typed) - case []any: - for i, item := range typed { - if nested, ok := item.(map[string]any); ok { - typed[i] = compactSchemaDescriptionsForAntigravity(nested) - } - } - schema[key] = typed - } - } - return schema -} - -func truncateInlineText(text string, maxChars int) string { - if maxChars <= 0 || len(text) <= maxChars { - return text - } - return text[:maxChars] + "...[truncated " + strconv.Itoa(len(text)-maxChars) + " chars]" -} diff --git a/backend/internal/pkg/antigravity/request_transformer_test.go b/backend/internal/pkg/antigravity/request_transformer_test.go index f5e01379..6fae5b7c 100644 --- a/backend/internal/pkg/antigravity/request_transformer_test.go +++ b/backend/internal/pkg/antigravity/request_transformer_test.go @@ -8,112 +8,6 @@ import ( "github.com/stretchr/testify/require" ) -func TestEnsureGeminiRequestSessionID(t *testing.T) { - t.Run("prefers provided session id", func(t *testing.T) { - body := []byte(`{"contents":[{"role":"user","parts":[{"text":"hello"}]}]}`) - updated, err := EnsureGeminiRequestSessionID(body, "session-from-header") - require.NoError(t, err) - - var payload map[string]any - require.NoError(t, json.Unmarshal(updated, &payload)) - require.Equal(t, "session-from-header", payload["sessionId"]) - }) - - t.Run("keeps existing session id", func(t *testing.T) { - body := []byte(`{"sessionId":"session-in-body","contents":[{"role":"user","parts":[{"text":"hello"}]}]}`) - updated, err := EnsureGeminiRequestSessionID(body, "session-from-header") - require.NoError(t, err) - - var payload map[string]any - require.NoError(t, json.Unmarshal(updated, &payload)) - require.Equal(t, "session-in-body", payload["sessionId"]) - }) - - t.Run("derives stable fallback from contents", func(t *testing.T) { - body := []byte(`{"contents":[{"role":"user","parts":[{"text":"hello"}]}]}`) - first, err := EnsureGeminiRequestSessionID(body, "") - require.NoError(t, err) - second, err := EnsureGeminiRequestSessionID(body, "") - require.NoError(t, err) - - var firstPayload map[string]any - var secondPayload map[string]any - require.NoError(t, json.Unmarshal(first, &firstPayload)) - require.NoError(t, json.Unmarshal(second, &secondPayload)) - require.NotEmpty(t, firstPayload["sessionId"]) - require.Equal(t, firstPayload["sessionId"], secondPayload["sessionId"]) - }) -} - -func TestTransformClaudeToGeminiWithOptions_UsesMetadataSessionIDJSON(t *testing.T) { - claudeReq := &ClaudeRequest{ - Model: "claude-sonnet-4-5", - Messages: []ClaudeMessage{ - { - Role: "user", - Content: json.RawMessage(`[{"type":"text","text":"hello"}]`), - }, - }, - Metadata: &ClaudeMetadata{ - UserID: `{"device_id":"d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677","account_uuid":"acc-uuid","session_id":"c72554f2-1234-5678-abcd-123456789abc"}`, - }, - } - - body, err := TransformClaudeToGeminiWithOptions(claudeReq, "project-1", "claude-sonnet-4-5", DefaultTransformOptions()) - require.NoError(t, err) - - var req V1InternalRequest - require.NoError(t, json.Unmarshal(body, &req)) - require.Equal(t, "c72554f2-1234-5678-abcd-123456789abc", req.Request.SessionID) -} - -func TestTransformClaudeToGeminiWithOptions_UsesMetadataSessionIDLegacy(t *testing.T) { - claudeReq := &ClaudeRequest{ - Model: "claude-sonnet-4-5", - Messages: []ClaudeMessage{ - { - Role: "user", - Content: json.RawMessage(`[{"type":"text","text":"hello"}]`), - }, - }, - Metadata: &ClaudeMetadata{ - UserID: "user_a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2_account_550e8400-e29b-41d4-a716-446655440000_session_123e4567-e89b-12d3-a456-426614174000", - }, - } - - body, err := TransformClaudeToGeminiWithOptions(claudeReq, "project-1", "claude-sonnet-4-5", DefaultTransformOptions()) - require.NoError(t, err) - - var req V1InternalRequest - require.NoError(t, json.Unmarshal(body, &req)) - require.Equal(t, "123e4567-e89b-12d3-a456-426614174000", req.Request.SessionID) -} - -func TestTransformClaudeToGeminiWithOptions_PrefersExplicitSessionWhenMetadataIsNotSessionPayload(t *testing.T) { - opts := DefaultTransformOptions() - opts.PreferredSessionID = "session-header-1" - - claudeReq := &ClaudeRequest{ - Model: "claude-sonnet-4-5", - Messages: []ClaudeMessage{ - { - Role: "user", - Content: json.RawMessage(`[{"type":"text","text":"hello"}]`), - }, - }, - Metadata: &ClaudeMetadata{ - UserID: "custom-user-42", - }, - } - - body, err := TransformClaudeToGeminiWithOptions(claudeReq, "project-1", "claude-sonnet-4-5", opts) - require.NoError(t, err) - - var req V1InternalRequest - require.NoError(t, json.Unmarshal(body, &req)) - require.Equal(t, "session-header-1", req.Request.SessionID) -} - // TestBuildParts_ThinkingBlockWithoutSignature 测试thinking block无signature时的处理 func TestBuildParts_ThinkingBlockWithoutSignature(t *testing.T) { tests := []struct { @@ -436,36 +330,16 @@ func TestBuildGenerationConfig_ThinkingDynamicBudget(t *testing.T) { wantPresent: true, }, { - // Google v1internal 要求 -thinking 模型必须携带 thinkingConfig,即使客户端明确 disabled。 - // 不携带会导致 Google 立即返回错误(在生产中表现为快速 503)。 - name: "disabled on -thinking model auto-injects thinkingConfig (Google requires it)", + name: "disabled does not emit thinkingConfig", model: "claude-opus-4-6-thinking", thinking: &ThinkingConfig{Type: "disabled", BudgetTokens: 1024}, - wantBudget: -1, // auto-injected dynamic budget - wantPresent: true, + wantBudget: 0, + wantPresent: false, }, { - // Google v1internal 要求 -thinking 模型必须携带 thinkingConfig,nil 时自动注入。 - name: "nil thinking on -thinking model auto-injects thinkingConfig (Google requires it)", + name: "nil thinking does not emit thinkingConfig", model: "claude-opus-4-6-thinking", thinking: nil, - wantBudget: -1, // auto-injected dynamic budget - wantPresent: true, - }, - { - // claude-sonnet-4-6 需要 thinkingConfig(无 -thinking 变体),budget 必须为 -1(动态) - // 经测试:claude-sonnet-4-6-thinking → 404;claude-sonnet-4-6 + budget=-1 → 200 OK - name: "nil thinking on claude-sonnet-4-6 auto-injects thinkingConfig (no -thinking variant exists)", - model: "claude-sonnet-4-6", - thinking: nil, - wantBudget: -1, - wantPresent: true, - }, - { - // 非 -thinking 普通模型(如 claude-opus-4-6,服务层已转为 -thinking,此处测试原始名) - name: "nil thinking on plain non-thinking model does not emit thinkingConfig", - model: "claude-opus-4-6", - thinking: nil, wantBudget: 0, wantPresent: false, }, @@ -582,214 +456,3 @@ func TestTransformClaudeToGeminiWithOptions_PreservesWebSearchAlongsideFunctions require.Equal(t, "get_weather", req.Request.Tools[0].FunctionDeclarations[0].Name) require.NotNil(t, req.Request.Tools[1].GoogleSearch) } - -func TestTransformClaudeToGeminiWithOptions_ClaudeModelKeepsToolsAndValidatedToolConfig(t *testing.T) { - claudeReq := &ClaudeRequest{ - Model: "claude-sonnet-4-5", - Messages: []ClaudeMessage{ - { - Role: "user", - Content: json.RawMessage(`[{"type":"text","text":"read the file"}]`), - }, - }, - Tools: []ClaudeTool{ - { - Name: "read_file", - Description: "Read a file", - InputSchema: map[string]any{ - "type": "object", - "properties": map[string]any{ - "file_path": map[string]any{"type": "string"}, - }, - }, - }, - }, - } - - body, err := TransformClaudeToGeminiWithOptions(claudeReq, "project-1", "claude-sonnet-4-5", DefaultTransformOptions()) - require.NoError(t, err) - - var req V1InternalRequest - require.NoError(t, json.Unmarshal(body, &req)) - require.Len(t, req.Request.Tools, 1) - require.Len(t, req.Request.Tools[0].FunctionDeclarations, 1) - require.Equal(t, "Read", req.Request.Tools[0].FunctionDeclarations[0].Name) - require.NotNil(t, req.Request.ToolConfig) - require.NotNil(t, req.Request.ToolConfig.FunctionCallingConfig) - require.Equal(t, "VALIDATED", req.Request.ToolConfig.FunctionCallingConfig.Mode) -} - -func TestTransformClaudeToGeminiWithOptions_ClaudeModelToolChoiceSpecificTool(t *testing.T) { - claudeReq := &ClaudeRequest{ - Model: "claude-sonnet-4-5", - ToolChoice: json.RawMessage(`{"type":"tool","name":"search_files"}`), - Messages: []ClaudeMessage{ - { - Role: "user", - Content: json.RawMessage(`[{"type":"text","text":"find todo"}]`), - }, - }, - Tools: []ClaudeTool{ - { - Name: "search_files", - Description: "Search files", - InputSchema: map[string]any{ - "type": "object", - "properties": map[string]any{ - "pattern": map[string]any{"type": "string"}, - }, - }, - }, - }, - } - - body, err := TransformClaudeToGeminiWithOptions(claudeReq, "project-1", "claude-sonnet-4-5", DefaultTransformOptions()) - require.NoError(t, err) - - var req V1InternalRequest - require.NoError(t, json.Unmarshal(body, &req)) - require.NotNil(t, req.Request.ToolConfig) - require.NotNil(t, req.Request.ToolConfig.FunctionCallingConfig) - require.Equal(t, "ANY", req.Request.ToolConfig.FunctionCallingConfig.Mode) - require.Equal(t, []string{"Grep"}, req.Request.ToolConfig.FunctionCallingConfig.AllowedFunctionNames) -} - -func TestTransformClaudeToGeminiWithOptions_NormalizesInterruptedToolHistory(t *testing.T) { - claudeReq := &ClaudeRequest{ - Model: "claude-sonnet-4-5", - Messages: []ClaudeMessage{ - { - Role: "assistant", - Content: json.RawMessage(`[ - {"type":"tool_use","id":"tool-1","name":"Bash","input":{"command":"pwd"}}, - {"type":"text","text":"(no content)"} - ]`), - }, - { - Role: "user", - Content: json.RawMessage(`[{"type":"text","text":"继续"}]`), - }, - }, - } - - body, err := TransformClaudeToGeminiWithOptions(claudeReq, "project-1", "claude-sonnet-4-5", DefaultTransformOptions()) - require.NoError(t, err) - - var req V1InternalRequest - require.NoError(t, json.Unmarshal(body, &req)) - require.Len(t, req.Request.Contents, 3) - - first := req.Request.Contents[0] - require.Equal(t, "model", first.Role) - require.Len(t, first.Parts, 1) - require.NotNil(t, first.Parts[0].FunctionCall) - require.Equal(t, "tool-1", first.Parts[0].FunctionCall.ID) - - second := req.Request.Contents[1] - require.Equal(t, "user", second.Role) - require.Len(t, second.Parts, 1) - require.NotNil(t, second.Parts[0].FunctionResponse) - require.Equal(t, "tool-1", second.Parts[0].FunctionResponse.ID) - resultBlocks, ok := second.Parts[0].FunctionResponse.Response["result"].([]any) - require.True(t, ok) - require.Len(t, resultBlocks, 1) - resultBlock, ok := resultBlocks[0].(map[string]any) - require.True(t, ok) - require.Equal(t, "text", resultBlock["type"]) - require.Equal(t, "[tool_result missing; tool execution interrupted]", resultBlock["text"]) - - third := req.Request.Contents[2] - require.Equal(t, "user", third.Role) - require.Len(t, third.Parts, 1) - require.Equal(t, "继续", third.Parts[0].Text) -} - -func TestNormalizeClaudeMessagesForAntigravity_ReordersThinkingAndSplitsToolResult(t *testing.T) { - messages := []ClaudeMessage{ - { - Role: "assistant", - Content: json.RawMessage(`[ - {"type":"text","text":"before"}, - {"type":"thinking","thinking":"deep thought","signature":"sig-1"}, - {"type":"tool_use","id":"tool-2","name":"Bash","input":{"command":"ls"}}, - {"type":"text","text":"(no content)"} - ]`), - }, - { - Role: "user", - Content: json.RawMessage(`[ - {"type":"tool_result","tool_use_id":"tool-2","content":[{"type":"text","text":"ok"}]}, - {"type":"text","text":"下一步"} - ]`), - }, - } - - normalized, err := normalizeClaudeMessagesForAntigravity(messages) - require.NoError(t, err) - require.Len(t, normalized, 3) - - var assistantBlocks []map[string]any - require.NoError(t, json.Unmarshal(normalized[0].Content, &assistantBlocks)) - require.Len(t, assistantBlocks, 3) - require.Equal(t, "thinking", assistantBlocks[0]["type"]) - require.Equal(t, "text", assistantBlocks[1]["type"]) - require.Equal(t, "tool_use", assistantBlocks[2]["type"]) - - var toolResultBlocks []map[string]any - require.NoError(t, json.Unmarshal(normalized[1].Content, &toolResultBlocks)) - require.Len(t, toolResultBlocks, 1) - require.Equal(t, "tool_result", toolResultBlocks[0]["type"]) - - var userTextBlocks []map[string]any - require.NoError(t, json.Unmarshal(normalized[2].Content, &userTextBlocks)) - require.Len(t, userTextBlocks, 1) - require.Equal(t, "text", userTextBlocks[0]["type"]) - require.Equal(t, "下一步", userTextBlocks[0]["text"]) -} - -func TestParseToolResultContent_PreservesStructuredBlocks(t *testing.T) { - content := json.RawMessage(`[ - {"type":"text","text":"hello"}, - {"type":"image","source":{"type":"base64","media_type":"image/png","data":"AAAA"}} - ]`) - - result := parseToolResultContent(content, false) - blocks, ok := result.([]map[string]any) - require.True(t, ok) - require.Len(t, blocks, 1) - require.Equal(t, "text", blocks[0]["type"]) - require.Equal(t, "hello", blocks[0]["text"]) -} - -func TestBuildTools_CompactsDescriptions(t *testing.T) { - longLine := strings.Repeat("schema detail ", 40) - result := buildTools([]ClaudeTool{ - { - Name: "describe", - Description: strings.Repeat("tool description\n", 20), - InputSchema: map[string]any{ - "type": "object", - "properties": map[string]any{ - "query": map[string]any{ - "type": "string", - "description": longLine, - }, - }, - }, - }, - }) - - require.Len(t, result, 1) - require.Len(t, result[0].FunctionDeclarations, 1) - - decl := result[0].FunctionDeclarations[0] - require.LessOrEqual(t, len(decl.Description), maxAntigravityToolDescriptionChars+32) - - props, ok := decl.Parameters["properties"].(map[string]any) - require.True(t, ok) - query, ok := props["query"].(map[string]any) - require.True(t, ok) - description, ok := query["description"].(string) - require.True(t, ok) - require.LessOrEqual(t, len(description), maxAntigravitySchemaDescriptionChars+32) -} diff --git a/backend/internal/pkg/antigravity/response_transformer.go b/backend/internal/pkg/antigravity/response_transformer.go index 0688d7f9..bc1fd32e 100644 --- a/backend/internal/pkg/antigravity/response_transformer.go +++ b/backend/internal/pkg/antigravity/response_transformer.go @@ -121,20 +121,17 @@ func (p *NonStreamingProcessor) processPart(part *GeminiPart) { p.hasToolCall = true - toolName := normalizeClaudeCodeToolName(part.FunctionCall.Name) - // 生成 tool_use id toolID := part.FunctionCall.ID if toolID == "" { - toolID = fmt.Sprintf("%s-%s", toolName, generateRandomID()) + toolID = fmt.Sprintf("%s-%s", part.FunctionCall.Name, generateRandomID()) } item := ClaudeContentItem{ - Type: "tool_use", - ID: toolID, - Name: toolName, - Input: part.FunctionCall.Args, - Caller: &ToolCaller{Type: "direct"}, + Type: "tool_use", + ID: toolID, + Name: part.FunctionCall.Name, + Input: part.FunctionCall.Args, } if signature != "" { diff --git a/backend/internal/pkg/antigravity/stream_transformer.go b/backend/internal/pkg/antigravity/stream_transformer.go index 8dec839c..4a68f3a9 100644 --- a/backend/internal/pkg/antigravity/stream_transformer.go +++ b/backend/internal/pkg/antigravity/stream_transformer.go @@ -362,21 +362,17 @@ func (p *StreamingProcessor) processFunctionCall(fc *GeminiFunctionCall, signatu var result bytes.Buffer p.usedTool = true - toolName := normalizeClaudeCodeToolName(fc.Name) toolID := fc.ID if toolID == "" { - toolID = fmt.Sprintf("%s-%s", toolName, generateRandomID()) + toolID = fmt.Sprintf("%s-%s", fc.Name, generateRandomID()) } toolUse := map[string]any{ "type": "tool_use", "id": toolID, - "name": toolName, + "name": fc.Name, "input": map[string]any{}, - "caller": map[string]any{ - "type": "direct", - }, } if signature != "" { diff --git a/backend/internal/pkg/windsurf/discovery_test.go b/backend/internal/pkg/windsurf/discovery_test.go index 41d257ed..77c58a90 100644 --- a/backend/internal/pkg/windsurf/discovery_test.go +++ b/backend/internal/pkg/windsurf/discovery_test.go @@ -27,7 +27,7 @@ func stubUserHome(t *testing.T, home string) { func TestDiscoverBinary_EnvOverrideWins(t *testing.T) { stubStatFn(t, map[string]bool{ - "/tmp/my-ls": true, + "/tmp/my-ls": true, "/opt/windsurf/language_server_linux_x64": true, // should not be picked }) got, err := discoverBinaryFor(Platform{"linux", "amd64"}, "/tmp/my-ls", "/opt/windsurf/language_server_linux_x64") diff --git a/backend/internal/server/http.go b/backend/internal/server/http.go index 22971cef..023e40bb 100644 --- a/backend/internal/server/http.go +++ b/backend/internal/server/http.go @@ -39,8 +39,6 @@ func ProvideRouter( opsService *service.OpsService, settingService *service.SettingService, redisClient *redis.Client, - langServerService *service.LanguageServerService, - lsrpcHandler *service.LSRPCHandler, ) *gin.Engine { if cfg.Server.Mode == "release" { gin.SetMode(gin.ReleaseMode) @@ -97,7 +95,7 @@ func ProvideRouter( service.SetWebSearchManager(websearch.NewManager(configs, redisClient)) }) - return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient, langServerService, lsrpcHandler) + return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient) } // ProvideHTTPServer 提供 HTTP 服务器 diff --git a/backend/internal/server/router.go b/backend/internal/server/router.go index f457062c..79eef637 100644 --- a/backend/internal/server/router.go +++ b/backend/internal/server/router.go @@ -7,7 +7,6 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" - "github.com/Wei-Shaw/sub2api/internal/gen/language_server_pbconnect" "github.com/Wei-Shaw/sub2api/internal/handler" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/server/routes" @@ -33,8 +32,6 @@ func SetupRouter( settingService *service.SettingService, cfg *config.Config, redisClient *redis.Client, - langServerService *service.LanguageServerService, - lsrpcHandler *service.LSRPCHandler, ) *gin.Engine { // 缓存 iframe 页面的 origin 列表,用于动态注入 CSP frame-src var cachedFrameOrigins atomic.Pointer[[]string] @@ -84,7 +81,7 @@ func SetupRouter( } // 注册路由 - registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient, langServerService, lsrpcHandler) + registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient) return r } @@ -102,8 +99,6 @@ func registerRoutes( settingService *service.SettingService, cfg *config.Config, redisClient *redis.Client, - langServerService *service.LanguageServerService, - lsrpcHandler *service.LSRPCHandler, ) { // 通用路由(健康检查、状态等) routes.RegisterCommonRoutes(r) @@ -120,15 +115,5 @@ func registerRoutes( // Windsurf gateway routes routes.RegisterWindsurfGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg) - // 注册 Antigravity HTTP API 路由 - routes.RegisterAntigravityHTTPRoutes(v1, langServerService) - - // 挂载 connectrpc LanguageServerService 路由 - // Claude Code 客户端通过 /exa.language_server_pb.LanguageServerService/* 路径访问 - if lsrpcHandler != nil { - lsrpcPath, lsrpcHTTPHandler := language_server_pbconnect.NewLanguageServerServiceHandler(lsrpcHandler) - r.Any(lsrpcPath+"*action", gin.WrapH(lsrpcHTTPHandler)) - } - routes.RegisterPaymentRoutes(v1, h.Payment, h.PaymentWebhook, h.Admin.Payment, jwtAuth, adminAuth, settingService) } diff --git a/backend/internal/server/routes/antigravity_http.go b/backend/internal/server/routes/antigravity_http.go deleted file mode 100644 index f25fda21..00000000 --- a/backend/internal/server/routes/antigravity_http.go +++ /dev/null @@ -1,192 +0,0 @@ -package routes - -import ( - "encoding/json" - "log/slog" - "net/http" - - "github.com/Wei-Shaw/sub2api/internal/service" - "github.com/gin-gonic/gin" -) - -// RegisterAntigravityHTTPRoutes 注册 Antigravity HTTP API 路由 -func RegisterAntigravityHTTPRoutes(v1 *gin.RouterGroup, langServerService *service.LanguageServerService) { - logger := slog.Default() - - // 创建处理器 - cascadeGroup := v1.Group("/cascade") - { - // 启动 Cascade 会话 - cascadeGroup.POST("/start", func(c *gin.Context) { - handleStartCascade(c, langServerService, logger) - }) - - // 发送消息到 Cascade(流式响应) - cascadeGroup.POST("/message", func(c *gin.Context) { - handleSendMessage(c, langServerService, logger) - }) - - // 取消 Cascade 会话 - cascadeGroup.POST("/cancel", func(c *gin.Context) { - handleCancelCascade(c, langServerService, logger) - }) - } - - // 模型列表 - v1.GET("/models", func(c *gin.Context) { - handleGetModels(c, langServerService, logger) - }) - - // 健康检查 - v1.GET("/health", func(c *gin.Context) { - handleHealth(c, logger) - }) -} - -// handleStartCascade 处理启动 Cascade 请求 -func handleStartCascade(c *gin.Context, svc *service.LanguageServerService, logger *slog.Logger) { - type StartCascadeRequest struct { - Model string `json:"model" binding:"required"` - SystemPrompt string `json:"system_prompt"` - Metadata map[string]string `json:"metadata"` - } - - var req StartCascadeRequest - if err := c.ShouldBindJSON(&req); err != nil { - logger.Error("invalid start cascade request", "error", err) - c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"}) - return - } - - // 获取 OAuth token - token := c.GetHeader("Authorization") - if token == "" { - c.JSON(http.StatusUnauthorized, gin.H{"error": "missing authorization header"}) - return - } - - // 调用服务 - cascadeID, err := svc.StartCascade( - c.Request.Context(), - req.Model, - req.SystemPrompt, - req.Metadata, - token, - ) - if err != nil { - logger.Error("start cascade failed", "error", err) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - c.JSON(http.StatusOK, gin.H{"cascade_id": cascadeID}) -} - -// handleSendMessage 处理发送消息请求(流式) -func handleSendMessage(c *gin.Context, svc *service.LanguageServerService, logger *slog.Logger) { - type SendMessageRequest struct { - CascadeID string `json:"cascade_id" binding:"required"` - Message string `json:"message" binding:"required"` - } - - var req SendMessageRequest - if err := c.ShouldBindJSON(&req); err != nil { - logger.Error("invalid send message request", "error", err) - c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"}) - return - } - - // 获取 OAuth token - token := c.GetHeader("Authorization") - if token == "" { - c.JSON(http.StatusUnauthorized, gin.H{"error": "missing authorization header"}) - return - } - - // 调用服务并获取流式更新通道 - updateChan, err := svc.SendUserMessage(c.Request.Context(), req.CascadeID, req.Message, token) - if err != nil { - logger.Error("send message failed", "error", err) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - // 设置 SSE 响应头 - c.Header("Content-Type", "text/event-stream") - c.Header("Cache-Control", "no-cache") - c.Header("Connection", "keep-alive") - c.Header("X-Accel-Buffering", "no") - c.Status(http.StatusOK) - - // 流式发送更新到客户端 - flusher, ok := c.Writer.(http.Flusher) - if !ok { - logger.Error("response writer does not support flushing") - return - } - - for event := range updateChan { - if event == nil { - break - } - - // 将事件序列化为 JSON - eventJSON, err := marshalJSON(event) - if err != nil { - logger.Error("failed to marshal event", "error", err) - continue - } - - // 发送 SSE 格式的数据 - _, _ = c.Writer.WriteString("data: " + string(eventJSON) + "\n\n") - flusher.Flush() - } -} - -// handleCancelCascade 处理取消 Cascade 请求 -func handleCancelCascade(c *gin.Context, svc *service.LanguageServerService, logger *slog.Logger) { - type CancelRequest struct { - CascadeID string `json:"cascade_id" binding:"required"` - } - - var req CancelRequest - if err := c.ShouldBindJSON(&req); err != nil { - logger.Error("invalid cancel request", "error", err) - c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"}) - return - } - - err := svc.CancelCascade(c.Request.Context(), req.CascadeID) - if err != nil { - logger.Error("cancel cascade failed", "error", err) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - c.JSON(http.StatusOK, gin.H{"message": "cascade cancelled"}) -} - -// handleGetModels 处理获取模型列表请求 -func handleGetModels(c *gin.Context, svc *service.LanguageServerService, logger *slog.Logger) { - models, err := svc.GetAvailableModels(c.Request.Context()) - if err != nil { - logger.Error("get models failed", "error", err) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - c.JSON(http.StatusOK, gin.H{ - "models": models, - "default_model": "claude-opus-4-6", - }) -} - -// handleHealth 处理健康检查请求 -func handleHealth(c *gin.Context, logger *slog.Logger) { - c.JSON(http.StatusOK, gin.H{"status": "healthy"}) -} - -// marshalJSON 辅助函数用于序列化事件 -func marshalJSON(v interface{}) ([]byte, error) { - return json.Marshal(v) -} diff --git a/backend/internal/server/routes/antigravity_http_test.go b/backend/internal/server/routes/antigravity_http_test.go deleted file mode 100644 index 636f22f9..00000000 --- a/backend/internal/server/routes/antigravity_http_test.go +++ /dev/null @@ -1,365 +0,0 @@ -package routes - -import ( - "bytes" - "encoding/json" - "net/http" - "net/http/httptest" - "sync" - "testing" - "time" - - "github.com/Wei-Shaw/sub2api/internal/service" - "github.com/gin-gonic/gin" - "log/slog" -) - -func TestAntigravityHTTPRoutes(t *testing.T) { - gin.SetMode(gin.TestMode) - - // 创建模拟的 LanguageServerService - mockService := service.NewLanguageServerService(slog.Default(), nil, nil, nil) - defer mockService.Stop() - - // 创建路由 - r := gin.New() - v1 := r.Group("/api/v1") - - // 注册 Antigravity 路由 - RegisterAntigravityHTTPRoutes(v1, mockService) - - // 测试 1: GET /health - t.Run("HealthCheck", func(t *testing.T) { - w := httptest.NewRecorder() - req, _ := http.NewRequest("GET", "/api/v1/health", nil) - r.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Fatalf("Expected 200, got %d", w.Code) - } - - var result map[string]string - json.Unmarshal(w.Body.Bytes(), &result) - if result["status"] != "healthy" { - t.Fatalf("Expected status=healthy, got %v", result) - } - t.Log("✅ 健康检查端点") - }) - - // 测试 2: GET /models - t.Run("GetModels", func(t *testing.T) { - w := httptest.NewRecorder() - req, _ := http.NewRequest("GET", "/api/v1/models", nil) - r.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Fatalf("Expected 200, got %d", w.Code) - } - - var result map[string]interface{} - json.Unmarshal(w.Body.Bytes(), &result) - if result["default_model"] != "claude-opus-4-6" { - t.Fatalf("Expected default_model, got %v", result) - } - t.Log("✅ 获取模型列表") - }) - - // 测试 3: POST /cascade/start - var cascadeID string - t.Run("StartCascade", func(t *testing.T) { - body, _ := json.Marshal(map[string]string{ - "model": "claude-opus-4-6", - }) - - w := httptest.NewRecorder() - req, _ := http.NewRequest("POST", "/api/v1/cascade/start", bytes.NewBuffer(body)) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer test-token") - r.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Fatalf("Expected 200, got %d", w.Code) - } - - var result map[string]string - json.Unmarshal(w.Body.Bytes(), &result) - cascadeID = result["cascade_id"] - if cascadeID == "" { - t.Fatalf("Expected cascade_id, got empty") - } - t.Logf("✅ 启动会话 (cascade_id=%s)", cascadeID) - }) - - // 测试 4: POST /cascade/cancel(使用从第3个测试获取的真实会话ID) - t.Run("CancelCascade", func(t *testing.T) { - body, _ := json.Marshal(map[string]string{ - "cascade_id": cascadeID, - }) - - w := httptest.NewRecorder() - req, _ := http.NewRequest("POST", "/api/v1/cascade/cancel", bytes.NewBuffer(body)) - req.Header.Set("Content-Type", "application/json") - r.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Fatalf("Expected 200, got %d", w.Code) - } - - var result map[string]string - json.Unmarshal(w.Body.Bytes(), &result) - if result["message"] != "cascade cancelled" { - t.Fatalf("Expected cascade cancelled message, got %v", result) - } - t.Log("✅ 取消会话") - }) - - // 测试 5: POST /cascade/message (SSE) - 验证响应头格式 - t.Run("SendMessage", func(t *testing.T) { - body, _ := json.Marshal(map[string]string{ - "cascade_id": cascadeID, - "message": "Hello, world!", - }) - - w := httptest.NewRecorder() - req, _ := http.NewRequest("POST", "/api/v1/cascade/message", bytes.NewBuffer(body)) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer test-token") - r.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Fatalf("Expected 200, got %d", w.Code) - } - - contentType := w.Header().Get("Content-Type") - if contentType != "text/event-stream" { - t.Fatalf("Expected text/event-stream, got %s", contentType) - } - t.Log("✅ 发送消息(SSE流式响应)") - }) - - t.Log("\n✅ 所有 Antigravity HTTP API 路由测试通过!") -} - -func TestStartCascadeValidation(t *testing.T) { - gin.SetMode(gin.TestMode) - - mockService := service.NewLanguageServerService(slog.Default(), nil, nil, nil) - defer mockService.Stop() - - r := gin.New() - v1 := r.Group("/api/v1") - RegisterAntigravityHTTPRoutes(v1, mockService) - - t.Run("MissingModel", func(t *testing.T) { - w := httptest.NewRecorder() - body := []byte(`{"system_prompt":"test"}`) - req, _ := http.NewRequest("POST", "/api/v1/cascade/start", bytes.NewBuffer(body)) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer test-token") - r.ServeHTTP(w, req) - - if w.Code != http.StatusBadRequest { - t.Errorf("Expected 400 for invalid request, got %d", w.Code) - } - t.Log("✅ 缺少必需字段验证") - }) - - t.Run("MissingAuthorization", func(t *testing.T) { - w := httptest.NewRecorder() - body := []byte(`{"model":"claude-opus-4-6"}`) - req, _ := http.NewRequest("POST", "/api/v1/cascade/start", bytes.NewBuffer(body)) - req.Header.Set("Content-Type", "application/json") - // 不设置 Authorization 头 - r.ServeHTTP(w, req) - - if w.Code != http.StatusUnauthorized { - t.Errorf("Expected 401 for missing auth, got %d", w.Code) - } - t.Log("✅ 缺少授权令牌验证") - }) - - t.Log("\n✅ 所有验证测试通过!") -} - -// TestRateLimiting 测试速率限制(改进 1) -func TestRateLimiting(t *testing.T) { - gin.SetMode(gin.TestMode) - - mockService := service.NewLanguageServerService(slog.Default(), nil, nil, nil) - defer mockService.Stop() - - r := gin.New() - v1 := r.Group("/api/v1") - RegisterAntigravityHTTPRoutes(v1, mockService) - - // 创建一个会话 - startBody, _ := json.Marshal(map[string]string{"model": "claude-opus-4-6"}) - w := httptest.NewRecorder() - req, _ := http.NewRequest("POST", "/api/v1/cascade/start", bytes.NewBuffer(startBody)) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer test-token") - r.ServeHTTP(w, req) - - var startResult map[string]string - json.Unmarshal(w.Body.Bytes(), &startResult) - cascadeID := startResult["cascade_id"] - - // 并发发送 150 个消息,应该有的超过限制 - var wg sync.WaitGroup - results := make([]int, 0) - var resultsMutex sync.Mutex - - for i := 0; i < 150; i++ { - wg.Add(1) - go func(idx int) { - defer wg.Done() - - body, _ := json.Marshal(map[string]string{ - "cascade_id": cascadeID, - "message": "Test message " + string(rune(idx)), - }) - - w := httptest.NewRecorder() - req, _ := http.NewRequest("POST", "/api/v1/cascade/message", bytes.NewBuffer(body)) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer test-token") - r.ServeHTTP(w, req) - - resultsMutex.Lock() - results = append(results, w.Code) - resultsMutex.Unlock() - }(i) - } - - wg.Wait() - - // 统计结果 - successCount := 0 - timeoutCount := 0 - for _, code := range results { - if code == 200 || code == 500 { // 500 可能是上游 API 错误 - successCount++ - } else if code == 504 { // 网关超时 - timeoutCount++ - } - } - - // 预期:大部分请求成功(因为有速率限制),但速率限制应该生效 - // 限制是 100 并发,所以 150 个请求中应该都能处理(只是可能有等待) - if successCount < 140 { - t.Logf("⚠️ 仅 %d/150 个请求成功(超过限制被拒绝)- 这是预期的速率限制行为", successCount) - } - - t.Logf("✅ 速率限制测试完成:成功=%d, 超时=%d", successCount, timeoutCount) -} - -// TestSessionCleanup 测试会话超时清理(改进 3) -func TestSessionCleanup(t *testing.T) { - gin.SetMode(gin.TestMode) - - mockService := service.NewLanguageServerService(slog.Default(), nil, nil, nil) - mockService.SetSessionTTL(2) // 设置 2 秒过期,便于测试 - defer mockService.Stop() - - r := gin.New() - v1 := r.Group("/api/v1") - RegisterAntigravityHTTPRoutes(v1, mockService) - - // 创建 5 个会话 - cascadeIDs := make([]string, 5) - for i := 0; i < 5; i++ { - body, _ := json.Marshal(map[string]string{"model": "claude-opus-4-6"}) - w := httptest.NewRecorder() - req, _ := http.NewRequest("POST", "/api/v1/cascade/start", bytes.NewBuffer(body)) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer test-token") - r.ServeHTTP(w, req) - - var result map[string]string - json.Unmarshal(w.Body.Bytes(), &result) - cascadeIDs[i] = result["cascade_id"] - } - - // 验证所有会话存在 - sessions := mockService.GetCascadeSessions() - if len(sessions) != 5 { - t.Fatalf("Expected 5 sessions, got %d", len(sessions)) - } - t.Log("✅ 创建了 5 个会话") - - // 等待清理周期 + TTL - time.Sleep(3 * time.Second) - - // 验证会话被清理 - sessions = mockService.GetCascadeSessions() - sessionCount := len(sessions) - - if sessionCount != 0 { - t.Logf("⚠️ 预期 0 个会话,但仍有 %d 个(可能清理还未执行)", sessionCount) - } else { - t.Log("✅ 过期会话成功清理") - } -} - -// TestConcurrentMessageAppend 测试并发安全的消息追加(改进 2) -func TestConcurrentMessageAppend(t *testing.T) { - gin.SetMode(gin.TestMode) - - mockService := service.NewLanguageServerService(slog.Default(), nil, nil, nil) - defer mockService.Stop() - - r := gin.New() - v1 := r.Group("/api/v1") - RegisterAntigravityHTTPRoutes(v1, mockService) - - // 创建会话 - body, _ := json.Marshal(map[string]string{"model": "claude-opus-4-6"}) - w := httptest.NewRecorder() - req, _ := http.NewRequest("POST", "/api/v1/cascade/start", bytes.NewBuffer(body)) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer test-token") - r.ServeHTTP(w, req) - - var result map[string]string - json.Unmarshal(w.Body.Bytes(), &result) - cascadeID := result["cascade_id"] - - // 并发追加 50 个消息 - var wg sync.WaitGroup - for i := 0; i < 50; i++ { - wg.Add(1) - go func(idx int) { - defer wg.Done() - - body, _ := json.Marshal(map[string]string{ - "cascade_id": cascadeID, - "message": "Concurrent message " + string(rune(idx)), - }) - - w := httptest.NewRecorder() - req, _ := http.NewRequest("POST", "/api/v1/cascade/message", bytes.NewBuffer(body)) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer test-token") - r.ServeHTTP(w, req) - - // 不关心返回值,只关心不 panic - }(i) - } - - wg.Wait() - - // 验证会话中的消息数量 - sessions := mockService.GetCascadeSessions() - messageCount := 0 - if session, exists := sessions[cascadeID]; exists { - messageCount = len(session.Messages) - } - - // 预期:1 个初始消息(如果没有 system_prompt,则为 0)+ 最多 50 个用户消息 - // 但由于速率限制,可能不是所有 50 个都会被处理 - if messageCount > 0 { - t.Logf("✅ 并发消息追加成功,共 %d 条消息", messageCount) - } else { - t.Log("⚠️ 由于速率限制或其他原因,部分消息未被追加") - } -} diff --git a/backend/internal/service/antigravity_account68_e2e_test.go b/backend/internal/service/antigravity_account68_e2e_test.go deleted file mode 100644 index ec1dbd2a..00000000 --- a/backend/internal/service/antigravity_account68_e2e_test.go +++ /dev/null @@ -1,254 +0,0 @@ -package service - -import ( - "bytes" - "context" - "encoding/json" - "net/http" - "net/http/httptest" - "testing" - - "github.com/gin-gonic/gin" -) - -// TestAccount68FullE2E 测试账号 68 的完整端到端流程 -// 模拟: curl POST /api/v1/admin/accounts/68/test -func TestAccount68FullE2E(t *testing.T) { - t.Log("🔥 测试账号 68 的完整认证流程...") - t.Log("") - - // 准备账号数据(与云端数据一致) - account := &Account{ - ID: 68, - Name: "PriesJosephe139@gmail.com", - Platform: PlatformAntigravity, - Type: "oauth", - Credentials: map[string]interface{}{ - "_token_version": 1775902256706, - "access_token": "ya29.a0Aa7MYipSteGdNdr486LvE0xu_RrcbFjSSFZa5jGTf94nPv6NLKEnnRziPSVA_3ncadMlWnUQN8el05uvYac3rk9rOuaEC3jAUq02ejAcayg8tBn9CJT2IGuMsFDRPbfvHwXVHvY-hPGaklubxMIgfckRYsGC7YTpJPprH8kNGG-7ZWf3PvcVGcSrLWhi8FX6Yq1at5OdC1deNAaCgYKAVASARMSFQHGX2Mi2yEN9AChtlJFBwZ_spYEoQ0213", - "email": "priesjosephe139@gmail.com", - "expires_at": "1775907556", - "model_mapping": map[string]interface{}{ - "claude-opus-*": "claude-opus-4-6-thinking", - "claude-sonnet-*": "claude-sonnet-4-6-thinking", - }, - "plan_type": "Free", - "project_id": "kinetic-sum-r3tp7", - "refresh_token": "1//06QXt2rakQERPCgYIARAAGAYSNwF-L9IrR672cwDMnyJS128asGMnBbrrdiN39XoS-FN6TUrG7pPxnDSEHYUV4WHDntB7qd2EPwo", - "token_type": "Bearer", - }, - Extra: map[string]interface{}{ - "allow_overages": true, - "privacy_mode": "privacy_set", - }, - ProxyID: ptrInt64(9), - Concurrency: 100, - Priority: 1, - Status: "active", - } - - t.Log("📌 账号信息:") - t.Logf(" ID: %d", account.ID) - t.Logf(" Name: %s", account.Name) - t.Logf(" Platform: %s", account.Platform) - t.Logf(" Project ID: %v", account.GetCredential("project_id")) - t.Log("") - - // 步骤 1: 验证凭证 - t.Run("Step1_ValidateCredentials", func(t *testing.T) { - t.Log("步骤 1: 验证账号凭证...") - - accessToken := account.GetCredential("access_token") - if accessToken == "" { - t.Fatalf("❌ Access token 为空") - } - t.Logf(" ✓ Access Token 存在 (长度: %d)", len(accessToken)) - - projectID := account.GetCredential("project_id") - if projectID == "" { - t.Fatalf("❌ Project ID 为空") - } - t.Logf(" ✓ Project ID 存在: %s", projectID) - - t.Log("") - }) - - // 步骤 2: 测试 API 调用(通过 SOCKS5 代理) - t.Run("Step2_CallUpstreamAPI", func(t *testing.T) { - t.Log("步骤 2: 通过 SOCKS5 代理调用上游 API...") - t.Log("") - - ctx, cancel := context.WithTimeout(context.Background(), 30) - defer cancel() - - // 使用之前测试过的配置 - proxyAddr := "socks5://gostuser:fastapipwd@216.167.89.210:8760" - accessTokenStr := account.GetCredential("access_token") - - t.Logf(" 📤 API 请求:") - t.Logf(" URL: https://daily-cloudcode-pa.googleapis.com/v1internal:loadCodeAssist") - t.Logf(" Token: %s... (长度: %d)", accessTokenStr[:30], len(accessTokenStr)) - t.Logf(" Proxy: %s", proxyAddr) - t.Log("") - - // 创建 HTTP 客户端(使用 SOCKS5 代理) - transport := &http.Transport{} - - httpClient := &http.Client{ - Transport: transport, - Timeout: 30, - } - - req, err := http.NewRequestWithContext(ctx, "POST", - "https://daily-cloudcode-pa.googleapis.com/v1internal:loadCodeAssist", - bytes.NewReader([]byte(`{}`))) - if err != nil { - t.Fatalf("❌ 创建请求失败: %v", err) - } - - req.Header.Set("Authorization", "Bearer "+accessTokenStr) - req.Header.Set("Content-Type", "application/json") - - resp, err := httpClient.Do(req) - if err != nil { - t.Logf("❌ API 调用失败: %v", err) - t.Logf(" (可能是网络问题,但凭证本身没问题)") - return - } - defer resp.Body.Close() - - t.Logf(" ✓ 收到响应") - t.Logf(" HTTP Status: %d", resp.StatusCode) - t.Logf(" Content-Type: %s", resp.Header.Get("Content-Type")) - t.Log("") - - // 读取响应 - respBody := make([]byte, 2048) - n, _ := resp.Body.Read(respBody) - respText := string(respBody[:n]) - - if resp.StatusCode == 200 { - t.Log(" ✅ API 调用成功!") - var result map[string]interface{} - if err := json.Unmarshal(respBody[:n], &result); err == nil { - if _, ok := result["cloudaicompanionProject"]; ok { - t.Logf(" ✓ 获得 Project: %v", result["cloudaicompanionProject"]) - } - } - } else { - t.Logf(" ❌ API 返回错误 (HTTP %d)", resp.StatusCode) - t.Logf(" 响应: %s", respText) - } - t.Log("") - }) - - // 步骤 3: 模拟 SSE 响应流(本地) - t.Run("Step3_SimulateSSEResponse", func(t *testing.T) { - t.Log("步骤 3: 模拟 SSE 响应流...") - t.Log("") - - gin.SetMode(gin.TestMode) - router := gin.New() - - // 模拟成功的 API 响应 - successResponse := map[string]interface{}{ - "cloudaicompanionProject": "kinetic-sum-r3tp7", - "currentTier": map[string]interface{}{ - "id": "free-tier", - "name": "Antigravity", - }, - } - - router.POST("/test", func(c *gin.Context) { - // 设置 SSE 头 - c.Header("Content-Type", "text/event-stream") - c.Header("Cache-Control", "no-cache") - c.Header("Connection", "keep-alive") - c.Status(200) - - // 发送测试开始 - event1 := map[string]interface{}{ - "type": "test_start", - "model": "claude-opus-4-6", - } - data1, _ := json.Marshal(event1) - c.Writer.WriteString("data: " + string(data1) + "\n\n") - c.Writer.Flush() - - // 发送内容(成功的 API 响应) - event2 := map[string]interface{}{ - "type": "content", - "text": "✅ 账号验证成功!", - } - data2, _ := json.Marshal(event2) - c.Writer.WriteString("data: " + string(data2) + "\n\n") - c.Writer.Flush() - - // 发送完成 - event3 := map[string]interface{}{ - "type": "test_complete", - "success": true, - } - data3, _ := json.Marshal(event3) - c.Writer.WriteString("data: " + string(data3) + "\n\n") - c.Writer.Flush() - - t.Logf(" 📤 服务器已发送 SSE 事件:") - t.Logf(" 1. test_start (model=%v)", successResponse["cloudaicompanionProject"]) - t.Logf(" 2. content (text: ✅ 账号验证成功!)") - t.Logf(" 3. test_complete (success=true)") - }) - - // 发送请求 - req := httptest.NewRequest("POST", "/test", bytes.NewReader([]byte(`{}`))) - w := httptest.NewRecorder() - router.ServeHTTP(w, req) - - // 验证响应 - t.Log("") - t.Log(" 📥 客户端收到的响应:") - body := w.Body.String() - lines := bytes.Split([]byte(body), []byte("\n\n")) - for i, line := range lines { - if len(line) == 0 { - continue - } - if bytes.HasPrefix(line, []byte("data: ")) { - data := bytes.TrimPrefix(line, []byte("data: ")) - var event map[string]interface{} - if err := json.Unmarshal(data, &event); err == nil { - t.Logf(" 事件 %d: type=%v", i, event["type"]) - if content, ok := event["content"]; ok { - t.Logf(" content=%v", content) - } - if success, ok := event["success"]; ok { - t.Logf(" success=%v", success) - } - } - } - } - t.Log("") - }) - - // 步骤 4: 总结 - t.Run("Step4_Summary", func(t *testing.T) { - t.Log("步骤 4: 总结...") - t.Log("") - t.Log("✅ 账号 68 测试完成!") - t.Log("") - t.Log("🎯 关键发现:") - t.Log(" 1. Access Token 已刷新成功 ✅") - t.Log(" 2. Project ID 有效: kinetic-sum-r3tp7 ✅") - t.Log(" 3. 上游 Google API 返回 200 成功 ✅") - t.Log(" 4. SSE 事件正确传递 ✅") - t.Log("") - t.Log("📊 预期结果:") - t.Log(" - 云端测试应该也能成功") - t.Log(" - 不再看到 'IT' 错误") - t.Log("") - }) -} - -func ptrInt64(i int64) *int64 { - return &i -} diff --git a/backend/internal/service/antigravity_direct_upstream_test.go b/backend/internal/service/antigravity_direct_upstream_test.go deleted file mode 100644 index 193ac051..00000000 --- a/backend/internal/service/antigravity_direct_upstream_test.go +++ /dev/null @@ -1,91 +0,0 @@ -package service - -import ( - "context" - "encoding/json" - "testing" - "time" - - "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" -) - -// TestDirectUpstreamCall 直接调用真实的 Google API,看返回什么 -func TestDirectUpstreamCall(t *testing.T) { - t.Log("🔥 直接调用 Google API,观察真实返回值...") - t.Log("") - - accessToken := "ya29.a0Aa7MYioHycPKQ7xWQguns0VlftxfCwTqn2OY8zVosNMagLLGd5DXWFXpySKgfroGkqihr4Yrwauy1AXfQyvWB-F_4qt46DiEw1sCmaCNmDwjruUiWK7Km7vh7djBONbgruyL0N9_b3aSLi-Zf3llY5FbWZqcNky13gaVUaW0ioxEDVOZuKxYw82yVXvVEqPRXF7cetjUJbLdzwaCgYKAZwSARMSFQHGX2MiqNlICLPPA-_u6WHPBLiUJQ0213" - - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - - // 步骤 1: 创建客户端 - t.Log("步骤 1: 创建 Antigravity 客户端...") - client, err := antigravity.NewClient("") - if err != nil { - t.Fatalf("❌ 创建客户端失败: %v", err) - } - t.Log("✅ 客户端创建成功") - t.Log("") - - // 步骤 2: 直接调用 LoadCodeAssist - t.Log("步骤 2: 调用 client.LoadCodeAssist(ctx, accessToken)...") - t.Logf(" AccessToken: %s... (长度: %d)", accessToken[:30], len(accessToken)) - t.Log("") - - resp, rawResp, err := client.LoadCodeAssist(ctx, accessToken) - - // 步骤 3: 分析返回值 - t.Log("步骤 3: 分析返回值...") - t.Log("") - - if err != nil { - t.Logf("❌ 调用失败") - t.Logf(" 错误类型: %T", err) - t.Logf(" 错误信息: %v", err) - t.Logf(" 错误字符串: %s", err.Error()) - t.Logf(" 错误长度: %d 字符", len(err.Error())) - t.Log("") - - // 分析错误信息的前几个字符 - errStr := err.Error() - if len(errStr) >= 2 { - t.Logf("📊 错误信息的前 5 个字符: '%s'", errStr[:min(5, len(errStr))]) - } - t.Log("") - - t.Logf("🎯 这就是导致 'IT' 错误的真实原因!") - t.Logf(" 错误完整内容: %q", errStr) - t.Log("") - - // 尝试找出 "IT" 的来源 - if len(errStr) >= 2 { - first2 := errStr[:2] - t.Logf("📌 错误的前两个字符: '%s'", first2) - if first2 == "IT" { - t.Logf(" ✓ 确认: 'IT' 就是从这个错误截断来的") - } else { - t.Logf(" ⚠️ 前两个字符不是 'IT',可能被其他方式处理了") - } - } - return - } - - // 成功的情况 - t.Log("✅ 调用成功!") - t.Log("") - - if resp != nil { - t.Logf("📋 响应信息:") - t.Logf(" CloudAICompanionProject: %s", resp.CloudAICompanionProject) - t.Logf(" Response 类型: %T", resp) - t.Log("") - - // 打印原始响应 - if rawResp != nil { - t.Log("📄 原始 API 响应 JSON:") - jsonBytes, _ := json.MarshalIndent(rawResp, " ", " ") - t.Logf("%s", string(jsonBytes)) - } - } -} diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index be33bf93..a76e59fb 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -44,10 +44,9 @@ const ( // MODEL_CAPACITY_EXHAUSTED 专用重试参数 // 模型容量不足时,所有账号共享同一容量池,切换账号无意义 - // 使用指数退避策略重试,最多重试 10 次(而非 60 次) - antigravityModelCapacityRetryMaxAttempts = 10 + // 使用固定 1s 间隔重试,最多重试 60 次 + antigravityModelCapacityRetryMaxAttempts = 60 antigravityModelCapacityRetryWait = 1 * time.Second - antigravityModelCapacityRetryMaxWait = 32 * time.Second // 指数退避上限 // Google RPC 状态和类型常量 googleRPCStatusResourceExhausted = "RESOURCE_EXHAUSTED" @@ -113,62 +112,6 @@ func IsAntigravityAccountSwitchError(err error) (*AntigravityAccountSwitchError, return nil, false } -func isGoogleOneAICreditsEntry(entry map[string]any) bool { - creditType, _ := firstPresent(entry, "CreditType", "credit_type", "creditType").(string) - creditType = strings.TrimSpace(strings.ToUpper(creditType)) - return creditType == "" || creditType == "GOOGLE_ONE_AI" -} - -func firstPresent(entry map[string]any, keys ...string) any { - for _, key := range keys { - if value, ok := entry[key]; ok { - return value - } - } - return nil -} - -func parseAICreditsInt32(raw any) (int32, bool) { - switch v := raw.(type) { - case int: - return int32(v), true - case int32: - return v, true - case int64: - return int32(v), true - case float32: - return int32(v), true - case float64: - return int32(v), true - case json.Number: - parsed, err := v.Int64() - if err != nil { - floatVal, floatErr := strconv.ParseFloat(v.String(), 64) - if floatErr != nil { - return 0, false - } - return int32(floatVal), true - } - return int32(parsed), true - case string: - trimmed := strings.TrimSpace(v) - if trimmed == "" { - return 0, false - } - parsed, err := strconv.ParseInt(trimmed, 10, 32) - if err == nil { - return int32(parsed), true - } - floatVal, floatErr := strconv.ParseFloat(trimmed, 64) - if floatErr != nil { - return 0, false - } - return int32(floatVal), true - default: - return 0, false - } -} - // PromptTooLongError 表示上游明确返回 prompt too long type PromptTooLongError struct { StatusCode int @@ -206,28 +149,17 @@ type antigravityRetryLoopResult struct { } // resolveAntigravityForwardBaseURL 解析转发用 base URL。 -// 根据账号类型选择优先 URL:企业账号(isGcpTos=true)→ prod;个人账号 → daily(与真实 IDE 一致)。 -// 可通过环境变量 GATEWAY_ANTIGRAVITY_FORWARD_BASE_URL=daily 或 =prod 强制覆盖。 -func resolveAntigravityForwardBaseURL(account *Account) string { - mode := strings.ToLower(strings.TrimSpace(os.Getenv(antigravityForwardBaseURLEnv))) - if mode == "daily" { - return "https://daily-cloudcode-pa.googleapis.com" - } - if mode == "prod" { - return "https://cloudcode-pa.googleapis.com" - } - // 按账号类型选择优先 URL - isGcpTos := account != nil && account.GetCredentialAsBool("is_gcp_tos") - urls := antigravity.BaseURLsForAccount(isGcpTos) - if len(urls) == 0 { +// 默认使用 daily(ForwardBaseURLs 的首个地址);当环境变量为 prod 时使用第二个地址。 +func resolveAntigravityForwardBaseURL() string { + baseURLs := antigravity.ForwardBaseURLs() + if len(baseURLs) == 0 { return "" } - // 返回可用列表中的第一个(URLAvailability 动态优先级在调用方处理) - available := antigravity.DefaultURLAvailability.GetAvailableURLsWithBase(urls) - if len(available) > 0 { - return available[0] + mode := strings.ToLower(strings.TrimSpace(os.Getenv(antigravityForwardBaseURLEnv))) + if mode == "prod" && len(baseURLs) > 1 { + return baseURLs[1] } - return urls[0] + return baseURLs[0] } // smartRetryAction 智能重试的处理结果 @@ -319,7 +251,7 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam var lastRetryResp *http.Response var lastRetryBody []byte - // MODEL_CAPACITY_EXHAUSTED 使用独立的重试参数(10 次,指数退避) + // MODEL_CAPACITY_EXHAUSTED 使用独立的重试参数(60 次,固定 1s 间隔) maxAttempts := antigravitySmartRetryMaxAttempts if isModelCapacityExhausted { maxAttempts = antigravityModelCapacityRetryMaxAttempts @@ -346,29 +278,10 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam } for attempt := 1; attempt <= maxAttempts; attempt++ { - // 计算本次重试的等待时间 - var currentWaitDuration time.Duration - if isModelCapacityExhausted { - // 使用指数退避:1s, 2s, 4s, 8s, 16s, 32s, ... - currentWaitDuration = waitDuration * time.Duration(1<<(attempt-1)) - if currentWaitDuration > antigravityModelCapacityRetryMaxWait { - currentWaitDuration = antigravityModelCapacityRetryMaxWait - } - // 添加随机抖动(±10%)避免羊群效应 - jitter := time.Duration(mathrand.Int63n(int64(currentWaitDuration / 5))) - if mathrand.Intn(2) == 0 { - currentWaitDuration += jitter - } else { - currentWaitDuration -= jitter - } - } else { - currentWaitDuration = waitDuration - } - log.Printf("%s status=%d oauth_smart_retry attempt=%d/%d delay=%v model=%s account=%d", - p.prefix, resp.StatusCode, attempt, maxAttempts, currentWaitDuration, modelName, p.account.ID) + p.prefix, resp.StatusCode, attempt, maxAttempts, waitDuration, modelName, p.account.ID) - timer := time.NewTimer(currentWaitDuration) + timer := time.NewTimer(waitDuration) select { case <-p.ctx.Done(): timer.Stop() @@ -678,7 +591,7 @@ func (s *AntigravityGatewayService) antigravityRetryLoop(p antigravityRetryLoopP } } - baseURL := resolveAntigravityForwardBaseURL(p.account) + baseURL := resolveAntigravityForwardBaseURL() if baseURL == "" { return nil, errors.New("no antigravity forward base url configured") } @@ -1084,20 +997,13 @@ func (s *AntigravityGatewayService) getMappedModel(account *Account, requestedMo return mapAntigravityModel(account, requestedModel) } -// applyThinkingModelSuffix 根据 thinking 配置和模型可用性调整模型名。 -// Google v1internal API 上部分 Claude 模型只有 -thinking 后缀版本存在, -// 非 -thinking 版本会返回 404。 +// applyThinkingModelSuffix 根据 thinking 配置调整模型名 +// 当映射结果是 claude-sonnet-4-5 且请求开启了 thinking 时,改为 claude-sonnet-4-5-thinking func applyThinkingModelSuffix(mappedModel string, thinkingEnabled bool) string { - // claude-opus-4-6: Google API 上只有 -thinking 版本,始终加后缀 - if mappedModel == "claude-opus-4-6" { - return "claude-opus-4-6-thinking" - } - // 其他模型仅在 thinking 开启时加后缀 if !thinkingEnabled { return mappedModel } - switch mappedModel { - case "claude-sonnet-4-5": + if mappedModel == "claude-sonnet-4-5" { return "claude-sonnet-4-5-thinking" } return mappedModel @@ -1139,10 +1045,6 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account return nil, fmt.Errorf("model %s not in whitelist", modelID) } - // 应用 thinking 后缀(claude-opus-4-6 → claude-opus-4-6-thinking) - // TestConnection 与主请求路径保持一致:Google API 只支持 -thinking 后缀版本的部分模型 - mappedModel = applyThinkingModelSuffix(mappedModel, false) - // 构建请求体 var requestBody []byte if strings.HasPrefix(modelID, "gemini-") { @@ -1246,17 +1148,17 @@ func (s *AntigravityGatewayService) buildGeminiTestRequest(projectID, model stri } // buildClaudeTestRequest 构建 Claude 格式测试请求并转换为 Gemini 格式 -// 使用最小 token 消耗:输入 "." + MaxTokens: 10(足够验证连接) +// 使用最小 token 消耗:输入 "." + MaxTokens: 1 func (s *AntigravityGatewayService) buildClaudeTestRequest(projectID, mappedModel string) ([]byte, error) { claudeReq := &antigravity.ClaudeRequest{ Model: mappedModel, Messages: []antigravity.ClaudeMessage{ { Role: "user", - Content: json.RawMessage(`"Test connection"`), + Content: json.RawMessage(`"."`), }, }, - MaxTokens: 10, + MaxTokens: 1, Stream: false, } return antigravity.TransformClaudeToGemini(claudeReq, projectID, mappedModel) @@ -1387,19 +1289,9 @@ func injectIdentityPatchToGeminiRequest(body []byte) ([]byte, error) { } // wrapV1InternalRequest 包装请求为 v1internal 格式 -func (s *AntigravityGatewayService) wrapV1InternalRequest(projectID, model string, originalBody []byte, preferredSessionID ...string) ([]byte, error) { - sessionID := "" - if len(preferredSessionID) > 0 { - sessionID = preferredSessionID[0] - } - - bodyWithSessionID, err := antigravity.EnsureGeminiRequestSessionID(originalBody, sessionID) - if err != nil { - return nil, fmt.Errorf("补全 sessionId 失败: %w", err) - } - +func (s *AntigravityGatewayService) wrapV1InternalRequest(projectID, model string, originalBody []byte) ([]byte, error) { var request any - if err := json.Unmarshal(bodyWithSessionID, &request); err != nil { + if err := json.Unmarshal(originalBody, &request); err != nil { return nil, fmt.Errorf("解析请求体失败: %w", err) } @@ -1477,6 +1369,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, if mappedModel == "" { return nil, s.writeClaudeError(c, http.StatusForbidden, "permission_error", fmt.Sprintf("model %s not in whitelist", claudeReq.Model)) } + // 应用 thinking 模式自动后缀:如果 thinking 开启且目标是 claude-sonnet-4-5,自动改为 thinking 版本 thinkingEnabled := claudeReq.Thinking != nil && (claudeReq.Thinking.Type == "enabled" || claudeReq.Thinking.Type == "adaptive") mappedModel = applyThinkingModelSuffix(mappedModel, thinkingEnabled) billingModel := mappedModel @@ -1503,9 +1396,9 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, } // 获取转换选项 + // Antigravity 上游要求必须包含身份提示词,否则会返回 429 transformOpts := s.getClaudeTransformOptions(ctx) - transformOpts.EnableIdentityPatch = true - transformOpts.PreferredSessionID = sessionID + transformOpts.EnableIdentityPatch = true // 强制启用,Antigravity 上游必需 // 转换 Claude 请求为 Gemini 格式 geminiBody, err := antigravity.TransformClaudeToGeminiWithOptions(&claudeReq, projectID, mappedModel, transformOpts) @@ -1513,8 +1406,11 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", "Invalid request") } + // Antigravity 上游只支持流式请求,统一使用 streamGenerateContent + // 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后转换返回 action := "streamGenerateContent" + // 执行带重试的请求 result, err := s.antigravityRetryLoop(antigravityRetryLoopParams{ ctx: ctx, prefix: prefix, @@ -1529,17 +1425,19 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, accountRepo: s.accountRepo, handleError: s.handleUpstreamError, requestedModel: originalModel, - isStickySession: isStickySession, - groupID: 0, - sessionHash: "", + isStickySession: isStickySession, // Forward 由上层判断粘性会话 + groupID: 0, // Forward 方法没有 groupID,由上层处理粘性会话清除 + sessionHash: "", // Forward 方法没有 sessionHash,由上层处理粘性会话清除 }) if err != nil { + // 检查是否是账号切换信号,转换为 UpstreamFailoverError 让 Handler 切换账号 if switchErr, ok := IsAntigravityAccountSwitchError(err); ok { return nil, &UpstreamFailoverError{ StatusCode: http.StatusServiceUnavailable, ForceCacheBilling: switchErr.IsStickySession, } } + // 区分客户端取消和真正的上游失败,返回更准确的错误消息 if c.Request.Context().Err() != nil { return nil, s.writeClaudeError(c, http.StatusBadGateway, "client_disconnected", "Client disconnected before upstream response") } @@ -1551,6 +1449,9 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, if resp.StatusCode >= 400 { respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + // 优先检测 thinking block 的 signature 相关错误(400)并重试一次: + // Antigravity /v1internal 链路在部分场景会对 thought/thinking signature 做严格校验, + // 当历史消息携带的 signature 不合法时会直接 400;去除 thinking 后可继续完成请求。 if resp.StatusCode == http.StatusBadRequest && isSignatureRelatedError(respBody) && s.settingService.IsSignatureRectifierEnabled(ctx) { upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) @@ -1567,6 +1468,10 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, Detail: upstreamDetail, }) + // Conservative two-stage fallback: + // 1) Disable top-level thinking + thinking->text + // 2) Only if still signature-related 400: also downgrade tool_use/tool_result to text. + retryStages := []struct { name string strip func(*antigravity.ClaudeRequest) (bool, error) @@ -1586,7 +1491,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, logger.LegacyPrintf("service.antigravity_gateway", "Antigravity account %d: detected signature-related 400, retrying once (%s)", account.ID, stage.name) - retryGeminiBody, txErr := antigravity.TransformClaudeToGeminiWithOptions(&retryClaudeReq, projectID, mappedModel, transformOpts) + retryGeminiBody, txErr := antigravity.TransformClaudeToGeminiWithOptions(&retryClaudeReq, projectID, mappedModel, s.getClaudeTransformOptions(ctx)) if txErr != nil { continue } @@ -1605,8 +1510,8 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, handleError: s.handleUpstreamError, requestedModel: originalModel, isStickySession: isStickySession, - groupID: 0, - sessionHash: "", + groupID: 0, // Forward 方法没有 groupID,由上层处理粘性会话清除 + sessionHash: "", // Forward 方法没有 sessionHash,由上层处理粘性会话清除 }) if retryErr != nil { appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ @@ -1659,6 +1564,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, Detail: retryUpstreamDetail, }) + // If this stage fixed the signature issue, we stop; otherwise we may try the next stage. if retryResp.StatusCode != http.StatusBadRequest || !isSignatureRelatedError(retryBody) { respBody = retryBody resp = &http.Response{ @@ -1669,6 +1575,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, break } + // Still signature-related; capture context and allow next stage. respBody = retryBody resp = &http.Response{ StatusCode: retryResp.StatusCode, @@ -1678,7 +1585,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, } } - // Budget 整流 + // Budget 整流:检测 budget_tokens 约束错误并自动修正重试 if resp.StatusCode == http.StatusBadRequest && respBody != nil && !isSignatureRelatedError(respBody) { errMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) if isThinkingBudgetConstraintError(errMsg) && s.settingService.IsBudgetRectifierEnabled(ctx) { @@ -1693,9 +1600,11 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, Detail: s.getUpstreamErrorDetail(respBody), }) + // 修正 claudeReq 的 thinking 参数(adaptive 模式不修正) if claudeReq.Thinking == nil || claudeReq.Thinking.Type != "adaptive" { retryClaudeReq := claudeReq retryClaudeReq.Messages = append([]antigravity.ClaudeMessage(nil), claudeReq.Messages...) + // 创建新的 ThinkingConfig 避免修改原始 claudeReq.Thinking 指针 retryClaudeReq.Thinking = &antigravity.ThinkingConfig{ Type: "enabled", BudgetTokens: BudgetRectifyBudgetTokens, @@ -1750,7 +1659,9 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, } } + // 处理错误响应(重试后仍失败或不触发重试) if resp.StatusCode >= 400 { + // 检测 prompt too long 错误,返回特殊错误类型供上层 fallback if resp.StatusCode == http.StatusBadRequest && isPromptTooLongError(respBody) { upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) @@ -1778,6 +1689,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, originalModel, 0, "", isStickySession) + // 精确匹配服务端配置类 400 错误,触发同账号重试 + failover if resp.StatusCode == http.StatusBadRequest { msg := strings.ToLower(strings.TrimSpace(extractAntigravityErrorMessage(respBody))) if isGoogleProjectConfigError(msg) { @@ -1828,6 +1740,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, var firstTokenMs *int var clientDisconnect bool if claudeReq.Stream { + // 客户端要求流式,直接透传转换 streamRes, err := s.handleClaudeStreamingResponse(c, resp, startTime, originalModel) if err != nil { logger.LegacyPrintf("service.antigravity_gateway", "%s status=stream_error error=%v", prefix, err) @@ -1837,6 +1750,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, firstTokenMs = streamRes.firstTokenMs clientDisconnect = streamRes.clientDisconnect } else { + // 客户端要求非流式,收集流式响应后转换返回 streamRes, err := s.handleClaudeStreamToNonStreaming(c, resp, startTime, originalModel) if err != nil { logger.LegacyPrintf("service.antigravity_gateway", "%s status=stream_collect_error error=%v", prefix, err) @@ -1846,13 +1760,6 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, firstTokenMs = streamRes.firstTokenMs } - // DEBUG: 追踪 OAuth Claude 路径的 Usage 在 Forward 返回点的值。 - // 若这里 output>0 而 DB 记录为 0,说明 bug 在下游(billing/record 层); - // 若这里 output=0,说明 bug 在 handleClaudeStreamingResponse 或更上游。 - logger.LegacyPrintf("service.antigravity_gateway", - "%s DEBUG_USAGE_FORWARD_RETURN input=%d output=%d cache_read=%d cache_creation=%d stream=%v model=%s account=%d", - prefix, usage.InputTokens, usage.OutputTokens, usage.CacheReadInputTokens, usage.CacheCreationInputTokens, claudeReq.Stream, originalModel, account.ID) - return &ForwardResult{ RequestID: requestID, Usage: *usage, @@ -2249,7 +2156,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co } // 包装请求 - wrappedBody, err := s.wrapV1InternalRequest(projectID, mappedModel, injectedBody, sessionID) + wrappedBody, err := s.wrapV1InternalRequest(projectID, mappedModel, injectedBody) if err != nil { return nil, s.writeGoogleError(c, http.StatusInternalServerError, "Failed to build upstream request") } @@ -2313,7 +2220,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co if fallbackModel != "" && fallbackModel != mappedModel { logger.LegacyPrintf("service.antigravity_gateway", "[Antigravity] Model not found (%s), retrying with fallback model %s (account: %s)", mappedModel, fallbackModel, account.Name) - fallbackWrapped, err := s.wrapV1InternalRequest(projectID, fallbackModel, injectedBody, sessionID) + fallbackWrapped, err := s.wrapV1InternalRequest(projectID, fallbackModel, injectedBody) if err == nil { fallbackReq, err := antigravity.NewAPIRequest(ctx, upstreamAction, accessToken, fallbackWrapped) if err == nil { @@ -2356,7 +2263,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co logger.LegacyPrintf("service.antigravity_gateway", "Antigravity Gemini account %d: detected signature-related 400, retrying with cleaned thought signatures", account.ID) cleanedInjectedBody := CleanGeminiNativeThoughtSignatures(injectedBody) - retryWrappedBody, wrapErr := s.wrapV1InternalRequest(projectID, mappedModel, cleanedInjectedBody, sessionID) + retryWrappedBody, wrapErr := s.wrapV1InternalRequest(projectID, mappedModel, cleanedInjectedBody) if wrapErr == nil { retryResult, retryErr := s.antigravityRetryLoop(antigravityRetryLoopParams{ ctx: ctx, @@ -3124,45 +3031,6 @@ func handleStreamReadError(err error, clientDisconnected bool, prefix string) (d return false, false } -func googleStatusTextForHTTP(status int) string { - switch status { - case http.StatusBadRequest: - return "INVALID_ARGUMENT" - case http.StatusNotFound: - return "NOT_FOUND" - case http.StatusTooManyRequests: - return "RESOURCE_EXHAUSTED" - case http.StatusServiceUnavailable: - return "UNAVAILABLE" - default: - return "UNKNOWN" - } -} - -func buildAnthropicStreamErrorEvent(errType, message string) string { - payload := map[string]any{ - "type": "error", - "error": map[string]any{ - "type": errType, - "message": message, - }, - } - data, _ := json.Marshal(payload) - return "event: error\ndata: " + string(data) + "\n\n" -} - -func buildGeminiStreamErrorEvent(status int, message string) string { - payload := map[string]any{ - "error": map[string]any{ - "code": status, - "message": message, - "status": googleStatusTextForHTTP(status), - }, - } - data, _ := json.Marshal(payload) - return "event: error\ndata: " + string(data) + "\n\n" -} - func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time) (*antigravityStreamResult, error) { c.Status(resp.StatusCode) c.Header("Cache-Control", "no-cache") @@ -3258,12 +3126,12 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context // 仅发送一次错误事件,避免多次写入导致协议混乱 errorEventSent := false - sendErrorEvent := func(status int, message string) { + sendErrorEvent := func(reason string) { if errorEventSent || cw.Disconnected() { return } errorEventSent = true - _, _ = fmt.Fprint(c.Writer, buildGeminiStreamErrorEvent(status, message)) + _, _ = fmt.Fprintf(c.Writer, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason) flusher.Flush() } @@ -3279,10 +3147,10 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context } if errors.Is(ev.err, bufio.ErrTooLong) { logger.LegacyPrintf("service.antigravity_gateway", "SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err) - sendErrorEvent(http.StatusBadGateway, "Response too large") + sendErrorEvent("response_too_large") return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err } - sendErrorEvent(http.StatusServiceUnavailable, "Upstream stream read failed") + sendErrorEvent("stream_read_error") return nil, ev.err } @@ -3345,7 +3213,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil } logger.LegacyPrintf("service.antigravity_gateway", "Stream data interval timeout (antigravity)") - sendErrorEvent(http.StatusServiceUnavailable, "Upstream stream timeout") + sendErrorEvent("stream_timeout") return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") case <-keepaliveCh: @@ -4105,12 +3973,12 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context // 仅发送一次错误事件,避免多次写入导致协议混乱 errorEventSent := false - sendErrorEvent := func(errType, message string) { + sendErrorEvent := func(reason string) { if errorEventSent || cw.Disconnected() { return } errorEventSent = true - _, _ = fmt.Fprint(c.Writer, buildAnthropicStreamErrorEvent(errType, message)) + _, _ = fmt.Fprintf(c.Writer, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason) flusher.Flush() } @@ -4126,9 +3994,6 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context if !ok { // 上游完成,发送结束事件 finalEvents, agUsage := processor.Finish() - logger.LegacyPrintf("service.antigravity_gateway", - "DEBUG_USAGE_PROCESSOR_FINISH input=%d output=%d cache_read=%d image_output=%d final_events_len=%d", - agUsage.InputTokens, agUsage.OutputTokens, agUsage.CacheReadInputTokens, agUsage.ImageOutputTokens, len(finalEvents)) if len(finalEvents) > 0 { cw.Write(finalEvents) } else if !processor.MessageStartSent() && !cw.Disconnected() { @@ -4145,15 +4010,14 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context } if ev.err != nil { if disconnect, handled := handleStreamReadError(ev.err, cw.Disconnected(), "antigravity claude"); handled { - logger.LegacyPrintf("service.antigravity_gateway", "DEBUG_USAGE_CLAUDE_STREAM_EARLY_RETURN path=handleStreamReadError disconnect=%v", disconnect) return &antigravityStreamResult{usage: finishUsage(), firstTokenMs: firstTokenMs, clientDisconnect: disconnect}, nil } if errors.Is(ev.err, bufio.ErrTooLong) { - logger.LegacyPrintf("service.antigravity_gateway", "DEBUG_USAGE_CLAUDE_STREAM_EARLY_RETURN path=ErrTooLong max_size=%d error=%v (usage WILL BE ZEROED)", maxLineSize, ev.err) - sendErrorEvent("api_error", "Response too large") + logger.LegacyPrintf("service.antigravity_gateway", "SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err) + sendErrorEvent("response_too_large") return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, ev.err } - sendErrorEvent("api_error", "Upstream stream read failed") + sendErrorEvent("stream_read_error") return nil, fmt.Errorf("stream read error: %w", ev.err) } @@ -4179,7 +4043,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context return &antigravityStreamResult{usage: finishUsage(), firstTokenMs: firstTokenMs, clientDisconnect: true}, nil } logger.LegacyPrintf("service.antigravity_gateway", "Stream data interval timeout (antigravity)") - sendErrorEvent("api_error", "Upstream stream timeout") + sendErrorEvent("stream_timeout") return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") case <-keepaliveCh: @@ -4672,61 +4536,3 @@ func (s *AntigravityGatewayService) extractClaudeUsage(body []byte) *ClaudeUsage } return usage } - -// ForwardRaw 转发 Claude 格式请求并返回原始上游响应体(调用者负责关闭)。 -// 不依赖 gin.Context,供内部服务(如 LanguageServerService)调用。 -// 复用完整的 token 刷新、模型映射、TLS 指纹和重试逻辑。 -func (s *AntigravityGatewayService) ForwardRaw(ctx context.Context, account *Account, body []byte) (io.ReadCloser, int, error) { - var claudeReq antigravity.ClaudeRequest - if err := json.Unmarshal(body, &claudeReq); err != nil { - return nil, http.StatusBadRequest, fmt.Errorf("invalid request body: %w", err) - } - if strings.TrimSpace(claudeReq.Model) == "" { - return nil, http.StatusBadRequest, fmt.Errorf("missing model") - } - - mappedModel := s.getMappedModel(account, claudeReq.Model) - if mappedModel == "" { - return nil, http.StatusForbidden, fmt.Errorf("model %s not in whitelist", claudeReq.Model) - } - thinkingEnabled := claudeReq.Thinking != nil && (claudeReq.Thinking.Type == "enabled" || claudeReq.Thinking.Type == "adaptive") - mappedModel = applyThinkingModelSuffix(mappedModel, thinkingEnabled) - - if s.tokenProvider == nil { - return nil, http.StatusBadGateway, fmt.Errorf("antigravity token provider not configured") - } - accessToken, err := s.tokenProvider.GetAccessToken(ctx, account) - if err != nil { - return nil, http.StatusBadGateway, fmt.Errorf("failed to get access token: %w", err) - } - - projectID := strings.TrimSpace(account.GetCredential("project_id")) - proxyURL := "" - if account.ProxyID != nil && account.Proxy != nil { - proxyURL = account.Proxy.URL() - } - - transformOpts := s.getClaudeTransformOptions(ctx) - transformOpts.EnableIdentityPatch = true - geminiBody, err := antigravity.TransformClaudeToGeminiWithOptions(&claudeReq, projectID, mappedModel, transformOpts) - if err != nil { - return nil, http.StatusBadRequest, fmt.Errorf("failed to transform request: %w", err) - } - - wrappedBody, err := s.wrapV1InternalRequest(projectID, mappedModel, geminiBody) - if err != nil { - return nil, http.StatusInternalServerError, fmt.Errorf("failed to wrap request: %w", err) - } - - upstreamReq, err := antigravity.NewAPIRequest(ctx, "streamGenerateContent", accessToken, wrappedBody) - if err != nil { - return nil, http.StatusInternalServerError, fmt.Errorf("failed to build upstream request: %w", err) - } - - resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) - if err != nil { - return nil, http.StatusBadGateway, fmt.Errorf("upstream request failed: %w", err) - } - - return resp.Body, resp.StatusCode, nil -} diff --git a/backend/internal/service/antigravity_gateway_service_test.go b/backend/internal/service/antigravity_gateway_service_test.go index bba6d1ee..1eb1451e 100644 --- a/backend/internal/service/antigravity_gateway_service_test.go +++ b/backend/internal/service/antigravity_gateway_service_test.go @@ -600,120 +600,6 @@ func TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel(t *testing require.Equal(t, mappedModel, result.UpstreamModel) } -func TestAntigravityGatewayService_ForwardGemini_InjectsSessionIDIntoWrappedRequest(t *testing.T) { - gin.SetMode(gin.TestMode) - writer := httptest.NewRecorder() - c, _ := gin.CreateTestContext(writer) - - body, err := json.Marshal(map[string]any{ - "contents": []map[string]any{ - {"role": "user", "parts": []map[string]any{{"text": "hello"}}}, - }, - }) - require.NoError(t, err) - - req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body)) - req.Header.Set("session_id", "session-header-1") - c.Request = req - - upstreamBody := []byte("data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"}]},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":8,\"candidatesTokenCount\":3}}}\n\n") - upstream := &queuedHTTPUpstreamStub{ - responses: []*http.Response{ - { - StatusCode: http.StatusOK, - Header: http.Header{"X-Request-Id": []string{"req-session-1"}}, - Body: io.NopCloser(bytes.NewReader(upstreamBody)), - }, - }, - } - - svc := &AntigravityGatewayService{ - settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}), - tokenProvider: &AntigravityTokenProvider{}, - httpUpstream: upstream, - } - - account := &Account{ - ID: 16, - Name: "acc-gemini-session", - Platform: PlatformAntigravity, - Type: AccountTypeOAuth, - Status: StatusActive, - Concurrency: 1, - Credentials: map[string]any{ - "access_token": "token", - }, - } - - result, err := svc.ForwardGemini(context.Background(), c, account, "gemini-2.5-flash", "generateContent", false, body, false) - require.NoError(t, err) - require.NotNil(t, result) - require.Len(t, upstream.requestBodies, 1) - - var wrapped map[string]any - require.NoError(t, json.Unmarshal(upstream.requestBodies[0], &wrapped)) - requestNode, ok := wrapped["request"].(map[string]any) - require.True(t, ok) - require.Equal(t, "session-header-1", requestNode["sessionId"]) -} - -func TestAntigravityGatewayService_Forward_PropagatesSessionHeaderIntoClaudeTransform(t *testing.T) { - gin.SetMode(gin.TestMode) - writer := httptest.NewRecorder() - c, _ := gin.CreateTestContext(writer) - - body := []byte(`{ - "model":"claude-sonnet-4-5", - "max_tokens":64, - "messages":[ - {"role":"user","content":[{"type":"text","text":"hello"}]} - ] - }`) - - req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) - req.Header.Set("session_id", "session-header-1") - c.Request = req - - upstreamBody := []byte("data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"}]},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":8,\"candidatesTokenCount\":3}}}\n\n") - upstream := &queuedHTTPUpstreamStub{ - responses: []*http.Response{ - { - StatusCode: http.StatusOK, - Header: http.Header{"X-Request-Id": []string{"req-session-claude-1"}}, - Body: io.NopCloser(bytes.NewReader(upstreamBody)), - }, - }, - } - - svc := &AntigravityGatewayService{ - settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}), - tokenProvider: &AntigravityTokenProvider{}, - httpUpstream: upstream, - } - - account := &Account{ - ID: 17, - Name: "acc-claude-session", - Platform: PlatformAntigravity, - Type: AccountTypeOAuth, - Status: StatusActive, - Concurrency: 1, - Credentials: map[string]any{ - "access_token": "token", - "project_id": "project-1", - }, - } - - result, err := svc.Forward(context.Background(), c, account, body, false) - require.NoError(t, err) - require.NotNil(t, result) - require.Len(t, upstream.requestBodies, 1) - - var wrapped antigravity.V1InternalRequest - require.NoError(t, json.Unmarshal(upstream.requestBodies[0], &wrapped)) - require.Equal(t, "session-header-1", wrapped.Request.SessionID) -} - func TestAntigravityGatewayService_ForwardGemini_RetriesCorruptedThoughtSignature(t *testing.T) { gin.SetMode(gin.TestMode) writer := httptest.NewRecorder() diff --git a/backend/internal/service/antigravity_oauth_service.go b/backend/internal/service/antigravity_oauth_service.go index 99081424..3a4600db 100644 --- a/backend/internal/service/antigravity_oauth_service.go +++ b/backend/internal/service/antigravity_oauth_service.go @@ -29,9 +29,8 @@ type AntigravityAuthURLResult struct { State string `json:"state"` } -// GenerateAuthURL 生成 Google OAuth 授权链接。 -// isEnterprise=true 时生成企业账号授权链接(使用企业 client_id)。 -func (s *AntigravityOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64, isEnterprise bool) (*AntigravityAuthURLResult, error) { +// GenerateAuthURL 生成 Google OAuth 授权链接 +func (s *AntigravityOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64) (*AntigravityAuthURLResult, error) { state, err := antigravity.GenerateState() if err != nil { return nil, fmt.Errorf("生成 state 失败: %w", err) @@ -59,13 +58,12 @@ func (s *AntigravityOAuthService) GenerateAuthURL(ctx context.Context, proxyID * State: state, CodeVerifier: codeVerifier, ProxyURL: proxyURL, - IsEnterprise: isEnterprise, CreatedAt: time.Now(), } s.sessionStore.Set(sessionID, session) codeChallenge := antigravity.GenerateCodeChallenge(codeVerifier) - authURL := antigravity.BuildAuthorizationURL(state, codeChallenge, isEnterprise) + authURL := antigravity.BuildAuthorizationURL(state, codeChallenge) return &AntigravityAuthURLResult{ AuthURL: authURL, @@ -91,7 +89,6 @@ type AntigravityTokenInfo struct { TokenType string `json:"token_type"` Email string `json:"email,omitempty"` ProjectID string `json:"project_id,omitempty"` - IsEnterprise bool `json:"is_enterprise,omitempty"` ProjectIDMissing bool `json:"-"` PlanType string `json:"-"` PrivacyMode string `json:"-"` @@ -122,8 +119,8 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig return nil, fmt.Errorf("create antigravity client failed: %w", err) } - // 交换 token(使用 session 中记录的账号类型) - tokenResp, err := client.ExchangeCode(ctx, input.Code, session.CodeVerifier, session.IsEnterprise) + // 交换 token + tokenResp, err := client.ExchangeCode(ctx, input.Code, session.CodeVerifier) if err != nil { return nil, fmt.Errorf("token 交换失败: %w", err) } @@ -140,7 +137,6 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig ExpiresIn: tokenResp.ExpiresIn, ExpiresAt: expiresAt, TokenType: tokenResp.TokenType, - IsEnterprise: session.IsEnterprise, } // 获取用户信息 @@ -170,9 +166,8 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig return result, nil } -// RefreshToken 刷新 token。 -// isEnterprise=true 时使用企业 OAuth client_id/secret。 -func (s *AntigravityOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string, isEnterprise bool) (*AntigravityTokenInfo, error) { +// RefreshToken 刷新 token +func (s *AntigravityOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*AntigravityTokenInfo, error) { var lastErr error for attempt := 0; attempt <= 3; attempt++ { @@ -188,7 +183,7 @@ func (s *AntigravityOAuthService) RefreshToken(ctx context.Context, refreshToken if err != nil { return nil, fmt.Errorf("create antigravity client failed: %w", err) } - tokenResp, err := client.RefreshToken(ctx, refreshToken, isEnterprise) + tokenResp, err := client.RefreshToken(ctx, refreshToken) if err == nil { now := time.Now() expiresAt := now.Unix() + tokenResp.ExpiresIn - 300 @@ -200,7 +195,6 @@ func (s *AntigravityOAuthService) RefreshToken(ctx context.Context, refreshToken ExpiresIn: tokenResp.ExpiresIn, ExpiresAt: expiresAt, TokenType: tokenResp.TokenType, - IsEnterprise: isEnterprise, }, nil } @@ -217,9 +211,8 @@ func (s *AntigravityOAuthService) RefreshToken(ctx context.Context, refreshToken return nil, fmt.Errorf("token 刷新失败 (重试后): %w", lastErr) } -// ValidateRefreshToken 用 refresh token 验证并获取完整的 token 信息(含 email 和 project_id)。 -// isEnterprise=true 时使用企业 OAuth client 刷新。 -func (s *AntigravityOAuthService) ValidateRefreshToken(ctx context.Context, refreshToken string, proxyID *int64, isEnterprise bool) (*AntigravityTokenInfo, error) { +// ValidateRefreshToken 用 refresh token 验证并获取完整的 token 信息(含 email 和 project_id) +func (s *AntigravityOAuthService) ValidateRefreshToken(ctx context.Context, refreshToken string, proxyID *int64) (*AntigravityTokenInfo, error) { var proxyURL string if proxyID != nil { proxy, err := s.proxyRepo.GetByID(ctx, *proxyID) @@ -228,8 +221,8 @@ func (s *AntigravityOAuthService) ValidateRefreshToken(ctx context.Context, refr } } - // 刷新 token:先按调用方指定类型刷新;若报 client 不匹配再尝试另一侧。 - tokenInfo, err := s.refreshTokenAutoFallback(ctx, refreshToken, proxyURL, isEnterprise) + // 刷新 token + tokenInfo, err := s.RefreshToken(ctx, refreshToken, proxyURL) if err != nil { return nil, err } @@ -281,32 +274,6 @@ func isNonRetryableAntigravityOAuthError(err error) bool { return false } -// isClientMismatchOAuthError 判断是否为 OAuth client 不匹配错误(用于触发个人/企业切换)。 -// 与 isNonRetryableAntigravityOAuthError 不同:这里只识别 client 相关错误,不包含 invalid_grant。 -func isClientMismatchOAuthError(err error) bool { - if err == nil { - return false - } - msg := err.Error() - return strings.Contains(msg, "invalid_client") || - strings.Contains(msg, "unauthorized_client") -} - -// refreshTokenAutoFallback 先按指定类型刷新;若遇 client 不匹配错误则切换到另一侧。 -// 本函数不承担网络层重试(由内部 RefreshToken 处理)。 -func (s *AntigravityOAuthService) refreshTokenAutoFallback(ctx context.Context, refreshToken, proxyURL string, preferEnterprise bool) (*AntigravityTokenInfo, error) { - tokenInfo, err := s.RefreshToken(ctx, refreshToken, proxyURL, preferEnterprise) - if err == nil { - return tokenInfo, nil - } - if !isClientMismatchOAuthError(err) { - return nil, err - } - // 切换另一侧账号类型重试 - fmt.Printf("[AntigravityOAuth] client 不匹配,切换账号类型重试:%v → %v\n", preferEnterprise, !preferEnterprise) - return s.RefreshToken(ctx, refreshToken, proxyURL, !preferEnterprise) -} - // RefreshAccountToken 刷新账户的 token func (s *AntigravityOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*AntigravityTokenInfo, error) { if account.Platform != PlatformAntigravity || account.Type != AccountTypeOAuth { @@ -318,8 +285,6 @@ func (s *AntigravityOAuthService) RefreshAccountToken(ctx context.Context, accou return nil, fmt.Errorf("无可用的 refresh_token") } - isEnterprise := account.GetCredentialAsBool("is_gcp_tos") - var proxyURL string if account.ProxyID != nil { proxy, err := s.proxyRepo.GetByID(ctx, *account.ProxyID) @@ -328,7 +293,7 @@ func (s *AntigravityOAuthService) RefreshAccountToken(ctx context.Context, accou } } - tokenInfo, err := s.RefreshToken(ctx, refreshToken, proxyURL, isEnterprise) + tokenInfo, err := s.RefreshToken(ctx, refreshToken, proxyURL) if err != nil { return nil, err } @@ -495,7 +460,6 @@ func (s *AntigravityOAuthService) BuildAccountCredentials(tokenInfo *Antigravity creds := map[string]any{ "access_token": tokenInfo.AccessToken, "expires_at": strconv.FormatInt(tokenInfo.ExpiresAt, 10), - "is_gcp_tos": tokenInfo.IsEnterprise, } if tokenInfo.RefreshToken != "" { creds["refresh_token"] = tokenInfo.RefreshToken diff --git a/backend/internal/service/antigravity_smart_retry_test.go b/backend/internal/service/antigravity_smart_retry_test.go index e3b60a27..6ac6b8fa 100644 --- a/backend/internal/service/antigravity_smart_retry_test.go +++ b/backend/internal/service/antigravity_smart_retry_test.go @@ -5,12 +5,13 @@ package service import ( "bytes" "context" - "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint" - "github.com/stretchr/testify/require" "io" "net/http" "strings" "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint" + "github.com/stretchr/testify/require" ) // stubSmartRetryCache 用于 handleSmartRetry 测试的 GatewayCache mock diff --git a/backend/internal/service/antigravity_test_full_flow_test.go b/backend/internal/service/antigravity_test_full_flow_test.go deleted file mode 100644 index 363744e7..00000000 --- a/backend/internal/service/antigravity_test_full_flow_test.go +++ /dev/null @@ -1,187 +0,0 @@ -package service - -import ( - "testing" -) - -// TestAntigravityFullFlow 完整流程测试 -// 模拟从 HTTP 处理器到最终响应的完整路径 -func TestAntigravityFullFlow(t *testing.T) { - t.Log("🔥 启动 Antigravity 完整流程测试...") - t.Log("") - - // 构造测试账号数据(使用提供的凭证) - proxyID := int64(9) - account := &Account{ - ID: 68, - Name: "PriesJosephe139@gmail.com", - Platform: PlatformAntigravity, - Type: AccountTypeOAuth, - Credentials: map[string]any{ - "access_token": "ya29.a0Aa7MYioHycPKQ7xWQguns0VlftxfCwTqn2OY8zVosNMagLLGd5DXWFXpySKgfroGkqihr4Yrwauy1AXfQyvWB-F_4qt46DiEw1sCmaCNmDwjruUiWK7Km7vh7djBONbgruyL0N9_b3aSLi-Zf3llY5FbWZqcNky13gaVUaW0ioxEDVOZuKxYw82yVXvVEqPRXF7cetjUJbLdzwaCgYKAZwSARMSFQHGX2MiqNlICLPPA-_u6WHPBLiUJQ0213", - "refresh_token": "1//06QXt2rakQERPCgYIARAAGAYSNwF-L9IrR672cwDMnyJS128asGMnBbrrdiN39XoS-FN6TUrG7pPxnDSEHYUV4WHDntB7qd2EPwo", - "email": "priesjosephe139@gmail.com", - "expires_at": "1775903154", - "project_id": "kinetic-sum-r3tp7", - "plan_type": "Free", - }, - ProxyID: &proxyID, - Concurrency: 100, - } - - // 测试路由决策逻辑 - t.Run("RouteAntigravityTest", func(t *testing.T) { - // 验证账号类型,决定使用哪条路径 - t.Logf("📌 账号类型判断:") - t.Logf(" Platform: %s (期望: antigravity)", account.Platform) - t.Logf(" Type: %s (期望: oauth)", account.Type) - t.Logf("") - - // 模拟 routeAntigravityTest 的决策逻辑 - var testPath string - if account.Type == AccountTypeAPIKey { - testPath = "APIKey 路径 (Claude/Gemini 直接连接)" - } else if account.Platform == PlatformAntigravity { - testPath = "OAuth/Upstream 路径 (使用 AntigravityGatewayService.TestConnection)" - } else { - testPath = "未知路径 (❌ 错误)" - } - - t.Logf("✅ 将使用: %s", testPath) - t.Logf("") - }) - - // 测试完整的错误处理流程 - t.Run("ErrorHandlingPathway", func(t *testing.T) { - t.Logf("📋 错误处理流程图:") - t.Logf("") - t.Logf("1️⃣ HTTP Handler (account_handler.go:671)") - t.Logf(" ↓") - t.Logf(" accountTestService.TestAccountConnection()") - t.Logf(" ↓") - t.Logf("2️⃣ AccountTestService.routeAntigravityTest()") - t.Logf(" ├─ Platform check: antigravity ✓") - t.Logf(" ├─ Type check: oauth ✓") - t.Logf(" └─ Call: testAntigravityAccountConnection()") - t.Logf(" ↓") - t.Logf("3️⃣ AccountTestService.testAntigravityAccountConnection()") - t.Logf(" ├─ Send SSE 'test_start' event") - t.Logf(" ├─ Call: AntigravityGatewayService.TestConnection()") - t.Logf(" │ ├─ Get access token") - t.Logf(" │ ├─ Get project_id") - t.Logf(" │ ├─ Build request body") - t.Logf(" │ ├─ Call: antigravityRetryLoop()") - t.Logf(" │ │ ├─ Execute HTTP request to Google API") - t.Logf(" │ │ ├─ Parse response") - t.Logf(" │ │ └─ Handle errors (rate limit, auth, etc.)") - t.Logf(" │ └─ Return result or error") - t.Logf(" ├─ If error: sendErrorAndEnd(error_message)") - t.Logf(" ├─ If success: sendEvent('content', response_text)") - t.Logf(" └─ Send SSE 'test_complete' event") - t.Logf(" ↓") - t.Logf("4️⃣ Response to Client (SSE 流)") - t.Logf(" ├─ Content-Type: text/event-stream") - t.Logf(" ├─ Event: test_start") - t.Logf(" ├─ Event: content (或 error)") - t.Logf(" └─ Event: test_complete") - t.Logf("") - }) - - // 诊断 "IT" 错误的可能来源 - t.Run("DiagnoseITError", func(t *testing.T) { - t.Logf("🔍 分析 'IT' 错误可能的来源:") - t.Logf("") - t.Logf("❓ 场景 1: 错误被截断") - t.Logf(" 原始错误可能是:") - t.Logf(" - 'INVALID_TOKEN' → truncated to 'IT'") - t.Logf(" - 'INTERNAL_ERROR' → truncated to 'IT'") - t.Logf(" - 'INVALID_GRANT' → truncated to 'IT'") - t.Logf(" - 'INTERNAL_ERROR...' → first 2 chars 'IN' not 'IT'") - t.Logf("") - t.Logf("❓ 场景 2: 错误来自特定的代码点") - t.Logf(" 可能出现 'IT' 的地方:") - t.Logf(" - SSE stream 中的错误字符") - t.Logf(" - HTTP response body 中的 JSON 解析错误") - t.Logf(" - Google API 返回的错误代码 (如果 Google API 返回 'IT' 作为错误)") - t.Logf("") - t.Logf("❓ 场景 3: 特殊的错误代码") - t.Logf(" 需要检查:") - t.Logf(" - 是否存在名为 'IT' 的错误常量?") - t.Logf(" - Google RPC 状态码中是否有 'IT'?") - t.Logf(" - 特定的错误处理中是否会生成 'IT'?") - t.Logf("") - }) - - // 完整的调试检查清单 - t.Run("DebugChecklist", func(t *testing.T) { - t.Logf("✅ 完整的调试检查清单:") - t.Logf("") - t.Logf("1. 验证账号信息:") - t.Logf(" [ ] Account ID: %d", account.ID) - t.Logf(" [ ] Platform: %s", account.Platform) - t.Logf(" [ ] Type: %s", account.Type) - t.Logf(" [ ] Access Token: %s... (长度: %d)", - account.GetCredential("access_token")[:20], - len(account.GetCredential("access_token"))) - t.Logf(" [ ] Project ID: %s", account.GetCredential("project_id")) - t.Logf("") - t.Logf("2. 验证请求路径:") - t.Logf(" [ ] routeAntigravityTest 选择了正确的路径") - t.Logf(" [ ] testAntigravityAccountConnection 被调用") - t.Logf(" [ ] AntigravityGatewayService.TestConnection 被调用") - t.Logf("") - t.Logf("3. 捕获详细错误信息:") - t.Logf(" [ ] 错误的完整字符串(不仅仅是 'IT')") - t.Logf(" [ ] 错误的类型(type)") - t.Logf(" [ ] 错误发生的确切代码行") - t.Logf(" [ ] HTTP 状态码(如有)") - t.Logf(" [ ] HTTP 响应体(如有)") - t.Logf("") - t.Logf("4. 验证 SSE 流处理:") - t.Logf(" [ ] 错误事件的 type 字段") - t.Logf(" [ ] 错误事件的 error 字段内容") - t.Logf(" [ ] 是否有 UTF-8 编码问题") - t.Logf("") - }) - - // 建议的实际代码改进 - t.Run("SuggestedCodeFixes", func(t *testing.T) { - t.Logf("🔧 建议的代码改进:") - t.Logf("") - t.Logf("1. 在 testAntigravityAccountConnection 中增加日志:") - t.Logf(" ```go") - t.Logf(" result, err := s.antigravityGatewayService.TestConnection(ctx, account, testModelID)") - t.Logf(" if err != nil {") - t.Logf(" log.Printf(\"[ERROR] TestConnection failed: type=%%T, error=%%v, msg='%%s'\", err, err, err.Error())") - t.Logf(" return s.sendErrorAndEnd(c, err.Error())") - t.Logf(" }") - t.Logf(" ```") - t.Logf("") - t.Logf("2. 在 sendErrorAndEnd 中增加详细日志:") - t.Logf(" ```go") - t.Logf(" func (s *AccountTestService) sendErrorAndEnd(c *gin.Context, msg string) error {") - t.Logf(" log.Printf(\"[SEND_ERROR] msg='%%s' (len=%%d, bytes=%%v)\", msg, len(msg), []byte(msg))") - t.Logf(" s.sendEvent(c, TestEvent{Type: \"test_error\", Error: msg, Success: false})") - t.Logf(" return nil") - t.Logf(" }") - t.Logf(" ```") - t.Logf("") - t.Logf("3. 检查 TestConnection 中的错误处理:") - t.Logf(" 在 antigravity_gateway_service.go 的 TestConnection 函数中") - t.Logf(" 追踪每个错误返回点的错误信息") - t.Logf("") - }) - - // 最后的总结 - t.Log("") - t.Log("📊 测试摘要:") - t.Log("✅ 账号凭证验证: 通过") - t.Log("✅ 路由逻辑验证: 通过") - t.Log("⚠️ 实际错误诊断: 需要在完整环境中运行") - t.Log("") - t.Log("下一步:") - t.Log("1. 添加建议的代码日志") - t.Log("2. 重新运行 HTTP 测试") - t.Log("3. 收集完整的错误信息") - t.Log("4. 分析并修复根本原因") -} diff --git a/backend/internal/service/antigravity_test_http_flow_test.go b/backend/internal/service/antigravity_test_http_flow_test.go deleted file mode 100644 index 4d71dd7a..00000000 --- a/backend/internal/service/antigravity_test_http_flow_test.go +++ /dev/null @@ -1,188 +0,0 @@ -package service - -import ( - "bytes" - "encoding/json" - "net/http" - "net/http/httptest" - "testing" - - "github.com/gin-gonic/gin" -) - -// TestHTTPResponseFlow 测试完整的 HTTP 请求-响应流,看客户端会收到什么 -func TestHTTPResponseFlow(t *testing.T) { - t.Log("🔥 模拟完整的 HTTP 请求-响应流...") - t.Log("") - - // 创建一个模拟的服务 - gin.SetMode(gin.TestMode) - router := gin.New() - - // 模拟账号测试端点 - router.POST("/api/v1/admin/accounts/:id/test", func(c *gin.Context) { - // 模拟返回错误的情况 - - // 设置 SSE 头 - c.Header("Content-Type", "text/event-stream") - c.Header("Cache-Control", "no-cache") - c.Header("Connection", "keep-alive") - c.Header("X-Accel-Buffering", "no") - c.Status(http.StatusOK) - - // 发送测试开始事件 - event1 := map[string]interface{}{ - "type": "test_start", - "model": "claude-opus-4-6", - } - jsonData1, _ := json.Marshal(event1) - c.Writer.WriteString("data: " + string(jsonData1) + "\n\n") - c.Writer.Flush() - - // 模拟一个错误:比如 "INVALID_TOKEN" 或其他上游错误 - // 这里我们故意测试不同的错误信息来看 curl 会显示什么 - - errorMessages := []string{ - "INVALID_TOKEN", - "INTERNAL_ERROR", - "Invalid authentication credentials", - "Th", // 测试短错误 - "IT", // 直接测试 "IT" - } - - selectedError := errorMessages[3] // 选择第 4 个:这应该显示为 "Th" 而不是 "IT" - - event2 := map[string]interface{}{ - "type": "error", - "error": selectedError, - "success": false, - } - jsonData2, _ := json.Marshal(event2) - c.Writer.WriteString("data: " + string(jsonData2) + "\n\n") - c.Writer.Flush() - - // 发送完成事件 - event3 := map[string]interface{}{ - "type": "test_complete", - "success": false, - } - jsonData3, _ := json.Marshal(event3) - c.Writer.WriteString("data: " + string(jsonData3) + "\n\n") - c.Writer.Flush() - - t.Logf("📤 服务器发送的错误: '%s'", selectedError) - }) - - // 测试 1: 发送 HTTP 请求 - t.Run("SendRequestAndCheckResponse", func(t *testing.T) { - t.Log("步骤 1: 发送 HTTP 请求...") - - req := httptest.NewRequest("POST", "/api/v1/admin/accounts/68/test", - bytes.NewReader([]byte(`{"model_id":"claude-opus-4-6"}`))) - req.Header.Set("Content-Type", "application/json") - - w := httptest.NewRecorder() - router.ServeHTTP(w, req) - - t.Log("✅ 请求已发送") - t.Log("") - - // 步骤 2: 检查响应 - t.Log("步骤 2: 分析 HTTP 响应...") - t.Logf(" HTTP Status: %d", w.Code) - t.Logf(" Content-Type: %s", w.Header().Get("Content-Type")) - t.Log("") - - // 步骤 3: 读取 SSE 响应 - t.Log("步骤 3: 读取 SSE 事件...") - body := w.Body.String() - t.Logf(" 响应总长度: %d 字节", len(body)) - t.Log("") - - // 解析 SSE 事件 - lines := bytes.Split([]byte(body), []byte("\n\n")) - for i, line := range lines { - if len(line) == 0 { - continue - } - - // 去掉 "data: " 前缀 - if bytes.HasPrefix(line, []byte("data: ")) { - data := bytes.TrimPrefix(line, []byte("data: ")) - - var event map[string]interface{} - err := json.Unmarshal(data, &event) - if err != nil { - t.Logf(" 事件 %d: [解析失败] %v", i, err) - continue - } - - t.Logf(" 事件 %d:", i) - t.Logf(" type: %v", event["type"]) - - if errMsg, ok := event["error"]; ok { - t.Logf(" error: %v (长度: %d)", errMsg, len(errMsg.(string))) - - // 这就是 curl 会看到的错误信息 - errStr := errMsg.(string) - if errStr == "IT" { - t.Logf(" ✓ 发现 'IT' 错误!") - } else if errStr == "Th" { - t.Logf(" ℹ️ 这是 'Th' 而不是 'IT'") - } else { - t.Logf(" ℹ️ 实际错误: '%s'", errStr) - } - } - - if model, ok := event["model"]; ok { - t.Logf(" model: %v", model) - } - } - } - - t.Log("") - t.Log("📋 完整的原始响应:") - t.Logf("%s", body) - }) - - // 测试 2: 模拟真实的 curl 请求 - t.Run("SimulateRealCurlRequest", func(t *testing.T) { - t.Log("步骤: 模拟真实 curl 命令...") - t.Log("") - - // 发送请求 - req := httptest.NewRequest("POST", "/api/v1/admin/accounts/68/test", - bytes.NewReader([]byte(`{"model_id":"claude-opus-4-6","prompt":""}`))) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer test-token") - - w := httptest.NewRecorder() - router.ServeHTTP(w, req) - - // 模拟 curl 读取响应 - body := w.Body.String() - - t.Log("curl 会看到:") - t.Log("```") - t.Log(body) - t.Log("```") - }) -} - -// 辅助函数:提取 SSE 事件中的错误信息 -func extractErrorFromSSE(sseBody string) string { - lines := bytes.Split([]byte(sseBody), []byte("\n\n")) - for _, line := range lines { - if bytes.HasPrefix(line, []byte("data: ")) { - data := bytes.TrimPrefix(line, []byte("data: ")) - var event map[string]interface{} - if err := json.Unmarshal(data, &event); err != nil { - continue - } - if errMsg, ok := event["error"]; ok { - return errMsg.(string) - } - } - } - return "" -} diff --git a/backend/internal/service/antigravity_test_singleton_test.go b/backend/internal/service/antigravity_test_singleton_test.go deleted file mode 100644 index 21ebd87d..00000000 --- a/backend/internal/service/antigravity_test_singleton_test.go +++ /dev/null @@ -1,213 +0,0 @@ -package service - -import ( - "encoding/json" - "strconv" - "testing" - "time" -) - -// TestAntigravityCredentialsValidation 单例测试:验证给定的 Antigravity 账号凭证有效性 -// 本测试使用服务器的真实代码函数,不依赖 HTTP 层,模拟云端场景 -func TestAntigravityCredentialsValidation(t *testing.T) { - // 测试数据:来自你提供的账号信息 - // ID: 68, 平台: antigravity, 类型: oauth - proxyID := int64(9) - testAccount := &Account{ - ID: 68, - Name: "PriesJosephe139@gmail.com", - Platform: PlatformAntigravity, - Type: AccountTypeOAuth, - Credentials: map[string]any{ - "access_token": "ya29.a0Aa7MYioHycPKQ7xWQguns0VlftxfCwTqn2OY8zVosNMagLLGd5DXWFXpySKgfroGkqihr4Yrwauy1AXfQyvWB-F_4qt46DiEw1sCmaCNmDwjruUiWK7Km7vh7djBONbgruyL0N9_b3aSLi-Zf3llY5FbWZqcNky13gaVUaW0ioxEDVOZuKxYw82yVXvVEqPRXF7cetjUJbLdzwaCgYKAZwSARMSFQHGX2MiqNlICLPPA-_u6WHPBLiUJQ0213", - "refresh_token": "1//06QXt2rakQERPCgYIARAAGAYSNwF-L9IrR672cwDMnyJS128asGMnBbrrdiN39XoS-FN6TUrG7pPxnDSEHYUV4WHDntB7qd2EPwo", - "email": "priesjosephe139@gmail.com", - "expires_at": "1775903154", - "project_id": "kinetic-sum-r3tp7", - "plan_type": "Free", - }, - ProxyID: &proxyID, - Concurrency: 100, - } - - // 测试 1: 验证账号凭证完整性 - t.Run("ValidateAccountCredentials", func(t *testing.T) { - if testAccount.ID == 0 { - t.Fatal("Account ID is missing") - } - if testAccount.Platform != PlatformAntigravity { - t.Fatalf("Expected platform %s, got %s", PlatformAntigravity, testAccount.Platform) - } - if testAccount.Type != AccountTypeOAuth { - t.Fatalf("Expected type %s, got %s", AccountTypeOAuth, testAccount.Type) - } - - // 验证必要的凭证字段 - accessToken := testAccount.GetCredential("access_token") - if accessToken == "" { - t.Fatal("Access token is missing") - } - refreshToken := testAccount.GetCredential("refresh_token") - if refreshToken == "" { - t.Fatal("Refresh token is missing") - } - projectID := testAccount.GetCredential("project_id") - if projectID == "" { - t.Fatal("Project ID is missing") - } - - t.Log("✅ 账号凭证完整性验证通过") - t.Logf(" Account ID: %d, Email: %s, ProjectID: %s", testAccount.ID, testAccount.GetCredential("email"), projectID) - }) - - // 测试 2: 测试 token 映射和模型验证 - t.Run("ValidateModelMapping", func(t *testing.T) { - testModels := []string{ - "claude-opus-4-6", - "claude-sonnet-4-6", - "gemini-3-pro-preview", - } - - for _, model := range testModels { - t.Logf("✓ Model %s is supported for account", model) - } - - t.Log("✅ 模型映射验证通过") - }) - - // 测试 3: 构建测试请求(不实际发送,只验证格式) - t.Run("BuildTestRequest", func(t *testing.T) { - projectID := testAccount.GetCredential("project_id") - if projectID == "" { - t.Skip("Project ID not available, skipping request building") - } - - // 构建 Claude 测试请求的简化版本 - claudeReq := map[string]any{ - "model": "claude-opus-4-6", - "messages": []map[string]any{ - { - "role": "user", - "content": []map[string]any{ - { - "type": "text", - "text": ".", - }, - }, - }, - }, - "max_tokens": 1, - "stream": true, - } - - requestBody, err := json.Marshal(claudeReq) - if err != nil { - t.Fatalf("Failed to marshal request: %v", err) - } - - t.Logf("✅ 请求体构建成功,大小: %d bytes", len(requestBody)) - if len(requestBody) > 200 { - t.Logf(" 请求格式: %s...", string(requestBody[:200])) - } else { - t.Logf(" 请求格式: %s", string(requestBody)) - } - }) - - // 测试 4: 验证 Token 信息格式 - t.Run("ValidateTokenInfo", func(t *testing.T) { - expiresAtStr := testAccount.GetCredential("expires_at") - if expiresAtStr == "" { - t.Log("⚠️ No expires_at timestamp found") - return - } - - // 尝试解析时间戳 - expiresAtUnix, err := strconv.ParseInt(expiresAtStr, 10, 64) - if err == nil { - expiresAt := time.Unix(expiresAtUnix, 0) - now := time.Now() - if expiresAt.After(now) { - remainingTime := expiresAt.Sub(now) - t.Logf("✅ Token 有效期检查通过") - t.Logf(" 过期时间: %s (还有 %v)", expiresAt.Format("2006-01-02 15:04:05 MST"), remainingTime) - } else { - t.Logf("⚠️ Token 已过期: %s", expiresAt.Format("2006-01-02 15:04:05 MST")) - t.Log(" 预期行为: 应该刷新 refresh_token") - } - } - }) - - // 测试 5: 创建 Antigravity 客户端并验证连接(如果可行) - t.Run("InitializeAntigravityClient", func(t *testing.T) { - // 使用账号的代理信息初始化客户端 - if testAccount.ProxyID != nil { - t.Logf("Account uses proxy ID: %d", *testAccount.ProxyID) - } - - t.Log("📌 Antigravity 客户端初始化代码路径:") - t.Log(" 1. 使用 accessToken 创建 antigravity.NewClient(proxyURL)") - t.Log(" 2. 调用 client.LoadCodeAssist(ctx, accessToken) 验证凭证") - t.Log(" 3. 检查响应中的 CloudAICompanionProject 字段") - t.Log("") - t.Log(" 预期行为:") - t.Log(" ✓ projectID == 'kinetic-sum-r3tp7'") - t.Log(" ✓ statusCode 200") - t.Log(" ✓ 无错误返回") - }) - - // 测试 6: 验证账号支持的操作 - t.Run("VerifyAccountOperations", func(t *testing.T) { - operations := []string{ - "GetAccessToken", - "RefreshToken", - "LoadCodeAssist", - "GetUserInfo", - "SetPrivacy", - } - - for _, op := range operations { - t.Logf("✓ Operation supported: %s", op) - } - - t.Log("") - t.Log("✅ 账号支持的操作列表验证通过") - }) - - // 测试 7: 文档化测试流程(实际调用时的步骤) - t.Run("DocumentTestFlow", func(t *testing.T) { - t.Log("📝 本地测试 Antigravity 账号的完整流程:") - t.Log("") - t.Log("步骤 1: 初始化服务") - t.Log(" - accountRepo: 从数据库获取账号") - t.Log(" - tokenProvider: Antigravity Token 提供者") - t.Log(" - httpUpstream: HTTP 请求执行器") - t.Log(" - gatewayService: Antigravity 网关服务") - t.Log("") - t.Log("步骤 2: 验证账号凭证") - t.Log(" account := accountRepo.GetByID(ctx, 68)") - t.Log(" accessToken := account.GetCredential('access_token')") - t.Log(" projectID := account.GetCredential('project_id')") - t.Log("") - t.Log("步骤 3: 构建测试请求") - t.Log(" requestBody := gatewayService.buildClaudeTestRequest(projectID, 'claude-opus-4-6')") - t.Log("") - t.Log("步骤 4: 执行请求") - t.Log(" result := gatewayService.TestConnection(ctx, account, 'claude-opus-4-6')") - t.Log("") - t.Log("步骤 5: 处理结果") - t.Log(" if err != nil {") - t.Log(" // 记录错误详情") - t.Log(" }") - t.Log("") - t.Log("⚠️ 当前问题:返回了 'IT' 错误") - t.Log(" 这可能表示:") - t.Log(" 1. 错误消息被截断或编码错误") - t.Log(" 2. HTTP 响应体包含不完整的错误文本") - t.Log(" 3. 上游 API 返回的错误被不正确地处理") - }) - - t.Log("") - t.Log("✅ 所有本地验证测试完成!") - t.Log("") - t.Log("下一步:在实际环境中运行完整测试") -} diff --git a/backend/internal/service/antigravity_test_socks5_proxy_test.go b/backend/internal/service/antigravity_test_socks5_proxy_test.go deleted file mode 100644 index 9ddbec50..00000000 --- a/backend/internal/service/antigravity_test_socks5_proxy_test.go +++ /dev/null @@ -1,194 +0,0 @@ -package service - -import ( - "context" - "encoding/json" - "io" - "net/http" - "net/url" - "testing" - "time" - - "golang.org/x/net/proxy" -) - -// TestWithSOCKS5Proxy 使用指定的 SOCKS5 代理调用上游 API -func TestWithSOCKS5Proxy(t *testing.T) { - t.Log("🔥 使用 SOCKS5 代理调用 Google API...") - t.Log("") - - // SOCKS5 代理配置 - proxyAddr := "socks5://gostuser:fastapipwd@216.167.89.210:8760" - accessToken := "ya29.a0Aa7MYipSteGdNdr486LvE0xu_RrcbFjSSFZa5jGTf94nPv6NLKEnnRziPSVA_3ncadMlWnUQN8el05uvYac3rk9rOuaEC3jAUq02ejAcayg8tBn9CJT2IGuMsFDRPbfvHwXVHvY-hPGaklubxMIgfckRYsGC7YTpJPprH8kNGG-7ZWf3PvcVGcSrLWhi8FX6Yq1at5OdC1deNAaCgYKAVASARMSFQHGX2Mi2yEN9AChtlJFBwZ_spYEoQ0213" - - t.Log("📌 代理信息:") - t.Logf(" 代理地址: %s", proxyAddr) - t.Logf(" 访问令牌: %s... (长度: %d)", accessToken[:30], len(accessToken)) - t.Log("") - - // 创建上下文和超时 - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - - // 步骤 1: 设置 SOCKS5 代理 - t.Run("SetupSOCKS5Proxy", func(t *testing.T) { - t.Log("步骤 1: 配置 SOCKS5 代理...") - - // 解析代理 URL - proxyURL, err := url.Parse(proxyAddr) - if err != nil { - t.Fatalf("❌ 解析代理 URL 失败: %v", err) - } - t.Logf(" ✓ 代理 URL 解析成功") - t.Logf(" Scheme: %s", proxyURL.Scheme) - t.Logf(" Host: %s", proxyURL.Host) - t.Logf(" User: %s", proxyURL.User.Username()) - t.Log("") - - // 创建代理拨号器 - dialer, err := proxy.FromURL(proxyURL, proxy.Direct) - if err != nil { - t.Fatalf("❌ 创建代理拨号器失败: %v", err) - } - t.Log(" ✓ 代理拨号器创建成功") - t.Log("") - - // 创建自定义传输 - transport := &http.Transport{ - Dial: dialer.Dial, - } - - // 创建自定义 HTTP 客户端 - httpClient := &http.Client{ - Transport: transport, - Timeout: 30 * time.Second, - } - - t.Log(" ✓ HTTP 客户端创建成功") - t.Log("") - - // 步骤 2: 测试代理连接 - t.Log("步骤 2: 测试代理连接...") - - // 尝试一个简单的 HTTP 请求来测试代理 - req, err := http.NewRequestWithContext(ctx, "GET", "https://www.google.com", nil) - if err != nil { - t.Logf("❌ 创建测试请求失败: %v", err) - return - } - - resp, err := httpClient.Do(req) - if err != nil { - t.Logf("❌ 通过代理访问 Google 失败: %v", err) - t.Log(" (这可能表示代理配置或网络连接有问题)") - return - } - defer resp.Body.Close() - - t.Logf(" ✓ 代理连接成功!") - t.Logf(" HTTP Status: %d", resp.StatusCode) - t.Log("") - }) - - // 步骤 3: 使用代理调用 Antigravity API - t.Run("CallAntigravityWithProxy", func(t *testing.T) { - t.Log("步骤 3: 通过代理调用 Antigravity API...") - t.Log("") - - // 解析代理 URL - proxyURL, err := url.Parse(proxyAddr) - if err != nil { - t.Fatalf("❌ 解析代理 URL 失败: %v", err) - } - - // 创建代理拨号器 - dialer, err := proxy.FromURL(proxyURL, proxy.Direct) - if err != nil { - t.Fatalf("❌ 创建代理拨号器失败: %v", err) - } - - // 创建自定义传输 - transport := &http.Transport{ - Dial: dialer.Dial, - } - - // 这里我们需要修改 antigravity.Client 来使用自定义的 HTTP 客户端 - // 但由于 antigravity.NewClient 可能不支持自定义客户端, - // 我们直接创建一个 HTTP 客户端来调用 API - - httpClient := &http.Client{ - Transport: transport, - Timeout: 30 * time.Second, - } - - t.Log(" 正在调用 Google Cloud Code API...") - t.Log("") - - // 直接构造 API 请求 - apiURL := "https://daily-cloudcode-pa.googleapis.com/v1internal:loadCodeAssist" - - req, err := http.NewRequestWithContext(ctx, "POST", apiURL, nil) - if err != nil { - t.Fatalf("❌ 创建请求失败: %v", err) - } - - // 添加认证头 - req.Header.Set("Authorization", "Bearer "+accessToken) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", "Antigravity Client") - - t.Logf(" 📤 请求信息:") - t.Logf(" URL: %s", apiURL) - t.Logf(" Method: POST") - t.Logf(" Auth: Bearer %s...", accessToken[:30]) - t.Log("") - - // 发送请求 - t.Log(" ⏳ 正在等待响应...") - resp, err := httpClient.Do(req) - if err != nil { - t.Logf("❌ API 调用失败:") - t.Logf(" 错误类型: %T", err) - t.Logf(" 错误信息: %v", err) - t.Logf(" 错误字符串: %s", err.Error()) - t.Log("") - - // 分析错误 - errStr := err.Error() - if len(errStr) >= 2 { - t.Logf("📊 错误的前 5 个字符: '%s'", errStr[:min(5, len(errStr))]) - if errStr[:2] == "IT" { - t.Logf(" ✓ 找到了! 这就是 'IT' 错误的来源!") - } - } - return - } - defer resp.Body.Close() - - t.Logf("✅ API 调用成功!") - t.Logf(" HTTP Status: %d", resp.StatusCode) - t.Logf(" Content-Type: %s", resp.Header.Get("Content-Type")) - t.Log("") - - // 读取响应体 - respBody, err := io.ReadAll(resp.Body) - if err != nil { - t.Logf("❌ 读取响应失败: %v", err) - return - } - - t.Log("📋 API 响应:") - if resp.StatusCode == 200 { - var result map[string]interface{} - if err := json.Unmarshal(respBody, &result); err == nil { - jsonBytes, _ := json.MarshalIndent(result, " ", " ") - t.Logf(" %s", string(jsonBytes)) - } else { - t.Logf(" %s", string(respBody)) - } - } else { - t.Logf(" 状态码: %d", resp.StatusCode) - t.Logf(" 错误响应: %s", string(respBody)) - } - }) -} diff --git a/backend/internal/service/antigravity_token_provider_requestpath_test.go b/backend/internal/service/antigravity_token_provider_requestpath_test.go deleted file mode 100644 index 3a430175..00000000 --- a/backend/internal/service/antigravity_token_provider_requestpath_test.go +++ /dev/null @@ -1,20 +0,0 @@ -package service - -import ( - "errors" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestShouldMarkTempUnschedulableForRefreshError(t *testing.T) { - t.Run("skip global oauth client secret missing", func(t *testing.T) { - err := errors.New(`token 刷新失败 (重试后): error: code=400 reason="ANTIGRAVITY_OAUTH_CLIENT_SECRET_MISSING" message="missing antigravity oauth client_secret; set ANTIGRAVITY_OAUTH_CLIENT_SECRET" metadata=map[]`) - require.False(t, shouldMarkTempUnschedulableForRefreshError(err)) - }) - - t.Run("allow account specific refresh error", func(t *testing.T) { - err := errors.New("token 刷新失败 (重试后): invalid_grant") - require.True(t, shouldMarkTempUnschedulableForRefreshError(err)) - }) -} diff --git a/backend/internal/service/antigravity_warmup.go b/backend/internal/service/antigravity_warmup.go deleted file mode 100644 index 6da9f7e3..00000000 --- a/backend/internal/service/antigravity_warmup.go +++ /dev/null @@ -1,83 +0,0 @@ -package service - -import ( - "context" - "log/slog" - "time" - - "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" -) - -// WarmupAntigravityAccount 预热新的 Antigravity 账号 -// 在账号创建后立即调用,避免首次请求的 503 延迟 -// -// 预热流程: -// 1. GetUserInfo - 验证 token 有效性 -// 2. LoadCodeAssist - 初始化项目信息 -// 3. FetchAvailableModels - 初始化模型列表 -// -// 总耗时通常 4-6 秒,预热期间的失败不影响账号创建结果(非阻塞) -func (s *AntigravityOAuthService) WarmupAntigravityAccount(ctx context.Context, accessToken, projectID, proxyURL string) { - logger := slog.Default() - - // 5 秒超时预热(防止卡住其他操作) - warmupCtx, cancel := context.WithTimeout(ctx, 5*time.Second) - defer cancel() - - client, err := antigravity.NewClient(proxyURL) - if err != nil { - logger.Warn("antigravity_warmup_client_creation_failed", "error", err) - return - } - - start := time.Now() - defer func() { - elapsed := time.Since(start) - logger.Info("antigravity_account_warmup_completed", "elapsed_ms", elapsed.Milliseconds()) - }() - - // Step 1: 验证 token - _, err = client.GetUserInfo(warmupCtx, accessToken) - if err != nil { - logger.Warn("antigravity_warmup_get_user_info_failed", "error", err) - // 继续后续步骤(部分失败不中止) - } - - // Step 2: 初始化项目信息 - _, _, err = client.LoadCodeAssist(warmupCtx, accessToken) - if err != nil { - logger.Warn("antigravity_warmup_load_code_assist_failed", "error", err) - } - - // Step 3: 初始化模型列表 - if projectID != "" { - _, _, err := client.FetchAvailableModels(warmupCtx, accessToken, projectID) - if err != nil { - logger.Warn("antigravity_warmup_fetch_available_models_failed", "error", err) - } - } -} - -// WarmupOptions 预热选项 -type WarmupOptions struct { - // Async 为 true 时在后台预热(推荐) - Async bool - // Timeout 单次预热操作的超时时间 - Timeout time.Duration -} - -// WarmupAntigravityAccountAsync 异步预热账号(推荐用法) -func (s *AntigravityOAuthService) WarmupAntigravityAccountAsync(ctx context.Context, accessToken, projectID, proxyURL string, opts *WarmupOptions) { - if opts == nil { - opts = &WarmupOptions{ - Async: true, - Timeout: 5 * time.Second, - } - } - - if opts.Async { - go s.WarmupAntigravityAccount(ctx, accessToken, projectID, proxyURL) - } else { - s.WarmupAntigravityAccount(ctx, accessToken, projectID, proxyURL) - } -} diff --git a/backend/internal/service/channel_monitor_template_types.go b/backend/internal/service/channel_monitor_template_types.go index e5bf7568..06b4f3ab 100644 --- a/backend/internal/service/channel_monitor_template_types.go +++ b/backend/internal/service/channel_monitor_template_types.go @@ -1,8 +1,9 @@ package service import ( - infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "time" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" ) // ChannelMonitorRequestTemplate 请求模板(service 层模型)。 diff --git a/backend/internal/service/gateway_attribution.go b/backend/internal/service/gateway_attribution.go deleted file mode 100644 index 9e2d99bc..00000000 --- a/backend/internal/service/gateway_attribution.go +++ /dev/null @@ -1,284 +0,0 @@ -package service - -import ( - "crypto/sha256" - "encoding/hex" - "fmt" - "regexp" - "strings" - "sync" - "time" - - claude "github.com/Wei-Shaw/sub2api/internal/pkg/claude" - "github.com/google/uuid" - "github.com/tidwall/gjson" - "github.com/tidwall/sjson" - - "github.com/Wei-Shaw/sub2api/internal/pkg/logger" -) - -// Attribution block constants matching real Claude Code 2.1.89. -// Source: src/constants/system.ts + src/utils/fingerprint.ts -const ( - // fingerprintSalt must match the hardcoded salt in the real CLI. - // Source: extracted/src/utils/fingerprint.ts:8 - fingerprintSalt = "59cf53e54c78" -) - -type attributionBlockOptions struct { - Entrypoint string - Workload string - OmitCCH bool -} - -// computeAttributionFingerprint computes a 3-character hex fingerprint -// matching the algorithm in the real Claude Code CLI. -// -// Algorithm: SHA256(SALT + msg[4] + msg[7] + msg[20] + version)[:3] -// Source: extracted/src/utils/fingerprint.ts:50-63 -func computeAttributionFingerprint(firstUserMessageText, cliVersion string) string { - indices := [3]int{4, 7, 20} - chars := make([]byte, 0, 3) - for _, i := range indices { - if i < len(firstUserMessageText) { - chars = append(chars, firstUserMessageText[i]) - } else { - chars = append(chars, '0') - } - } - - input := fmt.Sprintf("%s%s%s", fingerprintSalt, string(chars), cliVersion) - hash := sha256.Sum256([]byte(input)) - return hex.EncodeToString(hash[:])[:3] -} - -// extractFirstUserMessageText extracts text from the first user message in the body. -// Handles both string content and array content (text blocks). -func extractFirstUserMessageText(body []byte) string { - messages := gjson.GetBytes(body, "messages") - if !messages.Exists() || !messages.IsArray() { - return "" - } - - var firstText string - messages.ForEach(func(_, msg gjson.Result) bool { - if msg.Get("role").String() != "user" { - return true // continue - } - content := msg.Get("content") - if content.Type == gjson.String { - firstText = content.String() - return false // break - } - if content.IsArray() { - content.ForEach(func(_, block gjson.Result) bool { - if block.Get("type").String() == "text" { - firstText = block.Get("text").String() - return false - } - return true - }) - return false - } - return true - }) - return firstText -} - -// buildAttributionBlock builds the x-anthropic-billing-header attribution string -// that real Claude Code injects as the first system text block. -// -// Format: x-anthropic-billing-header: cc_version=.; cc_entrypoint=cli; cch=00000; -// Source: extracted/src/constants/system.ts:73-95 -func buildAttributionBlock(cliVersion, fingerprint string, opts attributionBlockOptions) string { - if claude.AttributionHeaderDisabled() { - return "" - } - version := cliVersion + "." + fingerprint - entrypoint := strings.TrimSpace(opts.Entrypoint) - if entrypoint == "" { - entrypoint = claude.CurrentEntrypoint() - } - workload := strings.TrimSpace(opts.Workload) - if workload == "" { - workload = claude.CurrentWorkload() - } - - var b strings.Builder - b.Grow(96) - fmt.Fprintf(&b, "x-anthropic-billing-header: cc_version=%s; cc_entrypoint=%s;", version, entrypoint) - if !opts.OmitCCH { - // 2.1.89+ 的 Claude Code 在 1P / standard providers 下保留 cch=00000 占位符, - // 由下游 attestation / signing 逻辑在需要时替换。 - b.WriteString(" cch=00000;") - } - if workload != "" { - fmt.Fprintf(&b, " cc_workload=%s;", workload) - } - return b.String() -} - -// injectAttributionBlock prepends the x-anthropic-billing-header attribution block -// as the very first system text block in the request body. -// This must come BEFORE the "You are Claude Code" block. -// -// The real CLI injects this as system[0] with cache_control: {type: "ephemeral"}. -func injectAttributionBlock(body []byte, cliVersion string, opts attributionBlockOptions) []byte { - // Compute fingerprint from the first user message - firstMsgText := extractFirstUserMessageText(body) - fingerprint := computeAttributionFingerprint(firstMsgText, cliVersion) - attribution := buildAttributionBlock(cliVersion, fingerprint, opts) - if attribution == "" { - return body - } - - // Build the attribution text block as JSON - attrBlock, err := marshalAnthropicSystemTextBlock(attribution, true) - if err != nil { - logger.LegacyPrintf("service.gateway", "Warning: failed to build attribution block: %v", err) - return body - } - - systemResult := gjson.GetBytes(body, "system") - - // Handle the different system formats - switch { - case !systemResult.Exists() || systemResult.Type == gjson.Null: - // No system field — inject just the attribution block - newBody, err := sjson.SetRawBytes(body, "system", buildJSONArrayRaw([][]byte{attrBlock})) - if err != nil { - return body - } - return newBody - - case systemResult.Type == gjson.String: - // String system — convert to array: [attribution, original] - origBlock, err := marshalAnthropicSystemTextBlock(systemResult.String(), false) - if err != nil { - return body - } - newBody, setErr := sjson.SetRawBytes(body, "system", buildJSONArrayRaw([][]byte{attrBlock, origBlock})) - if setErr != nil { - return body - } - return newBody - - case systemResult.IsArray(): - // Array system — check if attribution already exists, prepend if not - var items [][]byte - alreadyHasAttribution := false - systemResult.ForEach(func(_, item gjson.Result) bool { - if item.Get("type").String() == "text" { - text := item.Get("text").String() - if len(text) > 30 && text[:30] == "x-anthropic-billing-header: cc" { - alreadyHasAttribution = true - } - } - return true - }) - if alreadyHasAttribution { - return body - } - - items = append(items, attrBlock) - systemResult.ForEach(func(_, item gjson.Result) bool { - items = append(items, []byte(item.Raw)) - return true - }) - newBody, setErr := sjson.SetRawBytes(body, "system", buildJSONArrayRaw(items)) - if setErr != nil { - return body - } - return newBody - - default: - return body - } -} - -// cliSessionEntry holds a cached session UUID with an expiration time. -type cliSessionEntry struct { - id string - expiresAt time.Time -} - -// cliSessionCache stores per-account session UUIDs that rotate on a TTL. -// Real CLI creates a new random UUID per process invocation; we approximate -// this by rotating every 30-60 minutes (jittered per account). -var ( - cliSessionCache = make(map[int64]cliSessionEntry) - cliSessionCacheMu sync.Mutex -) - -// sessionTTLBase is the base TTL for session ID rotation. -const sessionTTLBase = 30 * time.Minute - -// generateSessionIDForAccount returns a per-account session UUID that rotates -// periodically. Each account gets a random TTL jitter (0-30 min on top of -// the 30 min base) so accounts don't all rotate simultaneously. -func generateSessionIDForAccount(instanceSalt string, accountID int64) string { - cliSessionCacheMu.Lock() - defer cliSessionCacheMu.Unlock() - - now := time.Now() - if entry, ok := cliSessionCache[accountID]; ok && now.Before(entry.expiresAt) { - return entry.id - } - - // Compute per-account jitter from a hash so the same account always gets - // the same jitter within a process (avoids re-rolling on every rotation). - jitterSeed := fmt.Sprintf("jitter:%s:%d", instanceSalt, accountID) - h := sha256.Sum256([]byte(jitterSeed)) - jitterMinutes := int(h[0]) % 31 // 0-30 minutes - ttl := sessionTTLBase + time.Duration(jitterMinutes)*time.Minute - - newID := uuid.New().String() - cliSessionCache[accountID] = cliSessionEntry{ - id: newID, - expiresAt: now.Add(ttl), - } - return newID -} - -// reUserHome matches /Users// or /home// path segments. -// Captures the prefix (/Users/ or /home/) so we can preserve it while replacing the username. -var reUserHome = regexp.MustCompile(`(/(Users|home)/)[^/\s"']+/`) - -// reEnvLine matches lines of the form "Key: value" for the environment block -// fields injected by Claude Code's CLAUDE.md / sysprompt machinery. -var reEnvLine = regexp.MustCompile(`(?m)^(Platform|Shell|OS Version|Working directory):.*$`) - -// canonicalEnvValues maps environment block keys to their canonical replacements. -// Values mirror cc-gateway's prompt_env config and represent a stock macOS dev machine. -var canonicalEnvValues = map[string]string{ - "Platform": "Platform: darwin", - "Shell": "Shell: zsh", - "OS Version": "OS Version: Darwin 24.4.0", - "Working directory": "Working directory: /Users/user/project", -} - -// NormalizeSystemPromptEnv rewrites environment-specific fields in a system -// prompt text block to canonical values, preventing real machine fingerprinting. -// -// Handles two classes of leakage (matching cc-gateway rewriter.ts:rewritePromptText): -// 1. "Platform: Windows / Linux / Darwin 25.x" → canonical darwin/zsh/Darwin 24.4.0 -// 2. "/Users/alice/" or "/home/bob/" → "/Users/user/" -// -// Only called on system prompt text blocks, never on user message content. -func NormalizeSystemPromptEnv(text string) string { - // Replace env-info lines with canonical values - text = reEnvLine.ReplaceAllStringFunc(text, func(line string) string { - for key, canonical := range canonicalEnvValues { - if len(line) >= len(key) && line[:len(key)] == key { - return canonical - } - } - return line - }) - - // Redact real usernames in home directory paths - // e.g. /Users/alice/project -> /Users/user/project - text = reUserHome.ReplaceAllString(text, "${1}user/") - - return text -} diff --git a/backend/internal/service/gateway_attribution_test.go b/backend/internal/service/gateway_attribution_test.go deleted file mode 100644 index f946c1c3..00000000 --- a/backend/internal/service/gateway_attribution_test.go +++ /dev/null @@ -1,81 +0,0 @@ -package service - -import ( - "net/http" - "testing" -) - -func TestApplyClaudeRuntimeOptionalHeaders(t *testing.T) { - t.Setenv("CLAUDE_CODE_CONTAINER_ID", "ctr-123") - t.Setenv("CLAUDE_CODE_REMOTE_SESSION_ID", "remote-456") - t.Setenv("CLAUDE_AGENT_SDK_CLIENT_APP", "desktop") - t.Setenv("CLAUDE_CODE_ADDITIONAL_PROTECTION", "true") - - req, err := http.NewRequest(http.MethodPost, "https://api.anthropic.com/v1/messages", nil) - if err != nil { - t.Fatalf("NewRequest() error = %v", err) - } - - applyClaudeRuntimeOptionalHeaders(req) - - if got := getHeaderRaw(req.Header, "x-claude-remote-container-id"); got != "ctr-123" { - t.Fatalf("x-claude-remote-container-id = %q", got) - } - if got := getHeaderRaw(req.Header, "x-claude-remote-session-id"); got != "remote-456" { - t.Fatalf("x-claude-remote-session-id = %q", got) - } - if got := getHeaderRaw(req.Header, "x-client-app"); got != "desktop" { - t.Fatalf("x-client-app = %q", got) - } - if got := getHeaderRaw(req.Header, "x-anthropic-additional-protection"); got != "true" { - t.Fatalf("x-anthropic-additional-protection = %q", got) - } -} - -func TestBuildAttributionBlock_UsesEntrypointAndWorkload(t *testing.T) { - t.Setenv("CLAUDE_CODE_ATTRIBUTION_HEADER", "") - - got := buildAttributionBlock("2.1.104", "abc", attributionBlockOptions{ - Entrypoint: "sdk-cli", - Workload: "cron", - }) - want := "x-anthropic-billing-header: cc_version=2.1.104.abc; cc_entrypoint=sdk-cli; cch=00000; cc_workload=cron;" - if got != want { - t.Fatalf("buildAttributionBlock() = %q, want %q", got, want) - } -} - -func TestBuildAttributionBlock_OmitsCCHForBedrockLikeProviders(t *testing.T) { - t.Setenv("CLAUDE_CODE_ATTRIBUTION_HEADER", "") - - got := buildAttributionBlock("2.1.104", "abc", attributionBlockOptions{ - Entrypoint: "cli", - OmitCCH: true, - }) - want := "x-anthropic-billing-header: cc_version=2.1.104.abc; cc_entrypoint=cli;" - if got != want { - t.Fatalf("buildAttributionBlock() = %q, want %q", got, want) - } -} - -func TestInjectAttributionBlock_DisabledByEnv(t *testing.T) { - t.Setenv("CLAUDE_CODE_ATTRIBUTION_HEADER", "false") - - body := []byte(`{"messages":[{"role":"user","content":"hello"}]}`) - got := injectAttributionBlock(body, "2.1.104", attributionBlockOptions{}) - if string(got) != string(body) { - t.Fatalf("injectAttributionBlock() should keep body unchanged when attribution disabled") - } -} - -func TestShouldOmitAttributionCCH(t *testing.T) { - if !shouldOmitAttributionCCH(&Account{Type: AccountTypeBedrock}, "") { - t.Fatal("expected bedrock account to omit cch") - } - if !shouldOmitAttributionCCH(&Account{Extra: map[string]any{"provider": "mantle"}}, "") { - t.Fatal("expected mantle provider to omit cch") - } - if shouldOmitAttributionCCH(&Account{Type: AccountTypeOAuth}, "oauth") { - t.Fatal("expected oauth account to keep cch") - } -} diff --git a/backend/internal/service/gateway_claude_runtime_headers.go b/backend/internal/service/gateway_claude_runtime_headers.go deleted file mode 100644 index d8d3dd9a..00000000 --- a/backend/internal/service/gateway_claude_runtime_headers.go +++ /dev/null @@ -1,47 +0,0 @@ -package service - -import ( - "net/http" - "strings" - - claude "github.com/Wei-Shaw/sub2api/internal/pkg/claude" -) - -func applyClaudeRuntimeOptionalHeaders(req *http.Request) { - if req == nil { - return - } - for key, value := range claude.OptionalAPIHeaders() { - if strings.TrimSpace(value) == "" { - continue - } - setHeaderRaw(req.Header, resolveWireCasing(key), value) - } -} - -func attributionOptionsForRequest(account *Account, tokenType string) attributionBlockOptions { - return attributionBlockOptions{ - Entrypoint: claude.CurrentEntrypoint(), - Workload: claude.CurrentWorkload(), - OmitCCH: shouldOmitAttributionCCH(account, tokenType), - } -} - -func shouldOmitAttributionCCH(account *Account, tokenType string) bool { - if strings.EqualFold(strings.TrimSpace(tokenType), "bedrock") { - return true - } - if account == nil { - return false - } - if account.Type == AccountTypeBedrock { - return true - } - for _, key := range []string{"provider", "upstream_provider"} { - switch strings.ToLower(strings.TrimSpace(account.GetExtraString(key))) { - case "bedrock", "anthropicaws", "anthropic_aws", "mantle": - return true - } - } - return false -} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index f6185ff2..bd4e18cf 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -901,9 +901,6 @@ func sanitizeSystemText(text string) string { "You are OpenCode, the best coding agent on the planet.", strings.TrimSpace(claudeCodeSystemPrompt), ) - // Normalize environment block fields (Platform/Shell/OS Version/Working directory) - // to canonical values so different client machines don't create fingerprint divergence. - text = NormalizeSystemPromptEnv(text) return text } @@ -4230,22 +4227,6 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A } } - // 注入 x-anthropic-billing-header attribution block(所有 OAuth 账号) - // 真实 CLI 在 system prompt 的第一个 text block 注入此 billing header。 - // 用于 Anthropic 后端路由和验证。 - // 跳过条件:system 已被 rewriteSystemForNonClaudeCode 重写(claudeCodeSystemPrompt 在 system[0]); - // 注入会将其移到 system[1],破坏伪装结构及 system[0] 断言。 - if account.IsOAuth() && !strings.Contains(strings.ToLower(reqModel), "haiku") && !systemRewritten { - // 获取 CLI 版本:优先用指纹中的版本,回退到默认 - attrCLIVersion := claude.DefaultCLIVersion - if fp := getHeaderRaw(c.Request.Header, "User-Agent"); fp != "" { - if v := ExtractCLIVersion(fp); v != "" { - attrCLIVersion = v - } - } - body = injectAttributionBlock(body, attrCLIVersion, attributionOptionsForRequest(account, "oauth")) - } - // 强制执行 cache_control 块数量限制(最多 4 个) body = enforceCacheControlLimit(body) @@ -5926,35 +5907,19 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex // 1. 客户端已提供 → 同步为 body 中 metadata.user_id 的 session_id // 2. 客户端未提供(mimic 模式)→ 生成确定性 per-account session UUID // 真实 CLI 每个请求都携带此 header(per-process UUID)。 + // 同步 X-Claude-Code-Session-Id 头:取 body 中已处理的 metadata.user_id 的 session_id 覆盖 if sessionHeader := getHeaderRaw(req.Header, "X-Claude-Code-Session-Id"); sessionHeader != "" { if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" { if parsed := ParseMetadataUserID(uid); parsed != nil { setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID) } } - } else if tokenType == "oauth" { - // mimic 模式:生成 session-id - var sessionID string - if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" { - if parsed := ParseMetadataUserID(uid); parsed != nil { - sessionID = parsed.SessionID - } - } - if sessionID == "" { - salt := "" - if s.cfg != nil { - salt = s.cfg.Gateway.InstanceSalt - } - sessionID = generateSessionIDForAccount(salt, account.ID) - } - setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", sessionID) } // x-client-request-id: 真实 CLI 每个请求生成新 UUID(仅 1P)。 if getHeaderRaw(req.Header, "x-client-request-id") == "" && tokenType == "oauth" { setHeaderRaw(req.Header, "x-client-request-id", uuid.New().String()) } - applyClaudeRuntimeOptionalHeaders(req) // === DEBUG: 打印上游转发请求(headers + body 摘要),与 CLIENT_ORIGINAL 对比 === s.debugLogGatewaySnapshot("UPSTREAM_FORWARD", req.Header, body, map[string]string{ @@ -8984,35 +8949,19 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con } } - // X-Claude-Code-Session-Id 头处理(count_tokens 路径) + // 同步 X-Claude-Code-Session-Id 头:取 body 中已处理的 metadata.user_id 的 session_id 覆盖 if sessionHeader := getHeaderRaw(req.Header, "X-Claude-Code-Session-Id"); sessionHeader != "" { if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" { if parsed := ParseMetadataUserID(uid); parsed != nil { setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID) } } - } else if tokenType == "oauth" { - var sessionID string - if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" { - if parsed := ParseMetadataUserID(uid); parsed != nil { - sessionID = parsed.SessionID - } - } - if sessionID == "" { - salt := "" - if s.cfg != nil { - salt = s.cfg.Gateway.InstanceSalt - } - sessionID = generateSessionIDForAccount(salt, account.ID) - } - setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", sessionID) } // x-client-request-id(count_tokens 路径) if getHeaderRaw(req.Header, "x-client-request-id") == "" && tokenType == "oauth" { setHeaderRaw(req.Header, "x-client-request-id", uuid.New().String()) } - applyClaudeRuntimeOptionalHeaders(req) if c != nil && tokenType == "oauth" { c.Set(claudeMimicDebugInfoKey, buildClaudeMimicDebugLine(req, body, account, tokenType, mimicClaudeCode)) diff --git a/backend/internal/service/identity_service.go b/backend/internal/service/identity_service.go index 52b89fc8..45a627a1 100644 --- a/backend/internal/service/identity_service.go +++ b/backend/internal/service/identity_service.go @@ -13,7 +13,6 @@ import ( "strings" "time" - "github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/tidwall/gjson" "github.com/tidwall/sjson" diff --git a/backend/internal/service/identity_service_antigravity.go b/backend/internal/service/identity_service_antigravity.go deleted file mode 100644 index 01077331..00000000 --- a/backend/internal/service/identity_service_antigravity.go +++ /dev/null @@ -1,28 +0,0 @@ -package service - -import "github.com/Wei-Shaw/sub2api/internal/pkg/claude" - -// ============================================================== -// antigravity — identity_service 扩展 -// -// 此文件包含 Antigravity fork 对 IdentityService 的扩展, -// 新增了实例级隔离盐值和指纹默认值覆盖功能。 -// -// 对上游文件 identity_service.go 的最小化改动: -// - defaultFingerprint 版本号更新 -// - IdentityService struct 新增 instanceSalt 字段 -// ============================================================== - -// ApplyDefaultFingerprintOverrides 用配置覆盖 identity_service 的默认指纹 -// 允许不同部署实例设置不同的 CLI/SDK 版本号,避免所有实例指纹相同 -func ApplyDefaultFingerprintOverrides(cliVersion, pkgVersion, runtimeVersion, os_, arch string) { - claude.ApplyFingerprintOverrides(cliVersion, pkgVersion, runtimeVersion, os_, arch) - defaultFingerprint = defaultIdentityFingerprint() -} - -// NewIdentityServiceWithSalt 创建带实例盐值的 IdentityService -// 实例盐值用于 user_id 重写时的 session hash 混淆, -// 使不同 sub2api 实例对相同输入产生不同的 hash 输出,增加隔离性 -func NewIdentityServiceWithSalt(cache IdentityCache, salt string) *IdentityService { - return &IdentityService{cache: cache, instanceSalt: salt} -} diff --git a/backend/internal/service/language_server_service.go b/backend/internal/service/language_server_service.go deleted file mode 100644 index 986c2574..00000000 --- a/backend/internal/service/language_server_service.go +++ /dev/null @@ -1,530 +0,0 @@ -package service - -import ( - "bufio" - "context" - "encoding/json" - "fmt" - "io" - "log/slog" - "strings" - "sync" - "time" - - "github.com/google/uuid" -) - -// CascadeSession 代表一个 Cascade Agent 会话 -type CascadeSession struct { - ID string - ModelName string - Messages []map[string]interface{} // {role, content} - Metadata map[string]string // 设备指纹、User-Agent 等 - Token string // OAuth token - CreatedAt int64 -} - -// LanguageServerService 业务逻辑层 -// 处理 Cascade Agent 流程,通过 AntigravityGatewayService 转发到上游 API -type LanguageServerService struct { - // 会话管理 - cascadeSessions map[string]*CascadeSession - sessionMutex sync.RWMutex - - // 上游 HTTP 服务(用于发送请求) - httpUpstream HTTPUpstream - - // Antigravity 网关(账号池调度 + TLS 指纹 + token 刷新) - antigravitySvc *AntigravityGatewayService - accountRepo AccountRepository - - // 日志 - logger *slog.Logger - - // 改进 1: 速率限制 (令牌桶) - // 限制并发消息处理数量,保护上游 API - rateLimiter chan struct{} - - // 改进 3: 会话过期时间 (秒) - sessionTTLSeconds int64 - - // 改进 3: 定期清理后台任务 - cleanupTicker *time.Ticker - stopCleanup chan struct{} -} - -func NewLanguageServerService( - logger *slog.Logger, - httpUpstream HTTPUpstream, - antigravitySvc *AntigravityGatewayService, - accountRepo AccountRepository, -) *LanguageServerService { - svc := &LanguageServerService{ - cascadeSessions: make(map[string]*CascadeSession), - logger: logger, - httpUpstream: httpUpstream, - antigravitySvc: antigravitySvc, - accountRepo: accountRepo, - rateLimiter: make(chan struct{}, 100), // 改进 1: 限制 100 个并发消息 - sessionTTLSeconds: 3600, // 改进 3: 会话默认 1 小时过期 - stopCleanup: make(chan struct{}), - } - - // 改进 3: 启动后台清理任务 - svc.startSessionCleanup() - - return svc -} - -// startSessionCleanup 启动会话定期清理任务 -func (svc *LanguageServerService) startSessionCleanup() { - svc.cleanupTicker = time.NewTicker(1 * time.Minute) - - go func() { - for { - select { - case <-svc.cleanupTicker.C: - svc.cleanupExpiredSessions() - case <-svc.stopCleanup: - svc.cleanupTicker.Stop() - return - } - } - }() -} - -// cleanupExpiredSessions 清理过期的会话 -func (svc *LanguageServerService) cleanupExpiredSessions() { - now := getCurrentTimeMS() - ttlMs := svc.sessionTTLSeconds * 1000 - - svc.sessionMutex.Lock() - defer svc.sessionMutex.Unlock() - - deletedCount := 0 - for id, session := range svc.cascadeSessions { - if now-session.CreatedAt > ttlMs { - delete(svc.cascadeSessions, id) - deletedCount++ - } - } - - if deletedCount > 0 { - svc.logger.Info("expired sessions cleaned up", - "deleted_count", deletedCount, - "remaining_sessions", len(svc.cascadeSessions), - ) - } -} - -// Stop 优雅关闭服务 -func (svc *LanguageServerService) Stop() { - select { - case svc.stopCleanup <- struct{}{}: - default: - } -} - -// SetSessionTTL sets the session TTL for testing purposes -func (svc *LanguageServerService) SetSessionTTL(ttlSeconds int64) { - svc.sessionTTLSeconds = ttlSeconds -} - -// GetCascadeSessions returns the current cascade sessions map (for testing) -func (svc *LanguageServerService) GetCascadeSessions() map[string]*CascadeSession { - svc.sessionMutex.RLock() - defer svc.sessionMutex.RUnlock() - return svc.cascadeSessions -} - -// ============================================================================ -// Cascade 业务逻辑 -// ============================================================================ - -// StartCascade 启动新的 Cascade Agent 会话 -func (svc *LanguageServerService) StartCascade( - ctx context.Context, - model string, - systemPrompt string, - metadata map[string]string, - token string, -) (string, error) { - // 1. 验证输入 - if model == "" { - return "", fmt.Errorf("model is required") - } - - if token == "" { - return "", fmt.Errorf("oauth token is required") - } - - // 2. 生成会话 ID - sessionID := uuid.New().String() - - // 3. 创建会话 - session := &CascadeSession{ - ID: sessionID, - ModelName: model, - Messages: make([]map[string]interface{}, 0), - Metadata: metadata, - Token: token, - CreatedAt: getCurrentTimeMS(), - } - - // 如果提供了系统提示,添加为初始消息 - if systemPrompt != "" { - session.Messages = append(session.Messages, map[string]interface{}{ - "role": "user", - "content": systemPrompt, - }) - } - - // 4. 保存会话 - svc.sessionMutex.Lock() - svc.cascadeSessions[sessionID] = session - svc.sessionMutex.Unlock() - - svc.logger.Info("cascade session started", - "session_id", sessionID, - "model", model, - "has_system_prompt", systemPrompt != "") - - return sessionID, nil -} - -// SendUserMessage 发送用户消息到 Cascade -// 返回流式更新通道 -func (svc *LanguageServerService) SendUserMessage( - ctx context.Context, - cascadeID string, - userMessage string, - token string, -) (<-chan interface{}, error) { - // 改进 1: 获取速率限制令牌 - select { - case svc.rateLimiter <- struct{}{}: - // 获得令牌,继续 - case <-ctx.Done(): - return nil, fmt.Errorf("context cancelled") - default: - // 没有令牌,需要等待 - select { - case svc.rateLimiter <- struct{}{}: - // 获得令牌 - case <-ctx.Done(): - return nil, fmt.Errorf("context cancelled while waiting for rate limit") - case <-time.After(30 * time.Second): - return nil, fmt.Errorf("rate limit timeout: too many concurrent messages") - } - } - - // 1. 获取会话 - svc.sessionMutex.RLock() - session, exists := svc.cascadeSessions[cascadeID] - svc.sessionMutex.RUnlock() - - if !exists { - // 释放令牌 - <-svc.rateLimiter - return nil, fmt.Errorf("cascade session not found: %s", cascadeID) - } - - // 2. 验证 token - if token != session.Token { - // 释放令牌 - <-svc.rateLimiter - return nil, fmt.Errorf("invalid token for session") - } - - // 改进 2: 并发安全的消息追加(深拷贝消息列表) - svc.sessionMutex.Lock() - newMessages := make([]map[string]interface{}, len(session.Messages)+1) - copy(newMessages, session.Messages) - newMessages[len(newMessages)-1] = map[string]interface{}{ - "role": "user", - "content": userMessage, - } - session.Messages = newMessages - svc.sessionMutex.Unlock() - - // 4. 创建响应通道 - updateChan := make(chan interface{}, 100) - - // 5. 启动后台 goroutine 处理 API 调用 - go func() { - defer func() { - // 关闭通道 - close(updateChan) - // 改进 1: 释放速率限制令牌 - <-svc.rateLimiter - }() - - // 调用上游 API(关键!这里需要伪装) - svc.callUpstreamAPI(ctx, session, updateChan) - }() - - svc.logger.Info("user message sent to cascade", - "session_id", cascadeID, - "message_length", len(userMessage), - "concurrent_requests", 100-len(svc.rateLimiter), // 显示当前并发数 - ) - - return updateChan, nil -} - -// CancelCascade 取消 Cascade 会话 -func (svc *LanguageServerService) CancelCascade( - ctx context.Context, - cascadeID string, -) error { - svc.sessionMutex.Lock() - _, exists := svc.cascadeSessions[cascadeID] - svc.sessionMutex.Unlock() - - if !exists { - return fmt.Errorf("cascade session not found: %s", cascadeID) - } - - // TODO: 取消正在进行的 API 调用 - - svc.logger.Info("cascade cancelled", "session_id", cascadeID) - return nil -} - -// ============================================================================ -// 模型配置 -// ============================================================================ - -// ModelConfig 模型配置 -type ModelConfig struct { - Name string - DisplayName string - MaxTokens int - SupportsThinking bool - ThinkingBudget int - SupportsImages bool - Provider string -} - -// GetAvailableModels 获取可用模型列表 -func (svc *LanguageServerService) GetAvailableModels(ctx context.Context) ([]ModelConfig, error) { - models := []ModelConfig{ - { - Name: "claude-opus-4-7", - DisplayName: "Claude Opus 4.7", - MaxTokens: 200000, - SupportsThinking: true, - ThinkingBudget: 32000, - SupportsImages: true, - Provider: "anthropic", - }, - { - Name: "claude-sonnet-4-7", - DisplayName: "Claude Sonnet 4.7", - MaxTokens: 200000, - SupportsThinking: true, - ThinkingBudget: 16000, - SupportsImages: true, - Provider: "anthropic", - }, - { - Name: "claude-opus-4-6", - DisplayName: "Claude Opus 4.6", - MaxTokens: 200000, - SupportsThinking: true, - ThinkingBudget: 32000, - SupportsImages: true, - Provider: "anthropic", - }, - { - Name: "claude-sonnet-4-6", - DisplayName: "Claude Sonnet 4.6", - MaxTokens: 200000, - SupportsThinking: false, - SupportsImages: true, - Provider: "anthropic", - }, - { - Name: "claude-haiku-4-5", - DisplayName: "Claude Haiku 4.5", - MaxTokens: 200000, - SupportsThinking: false, - SupportsImages: true, - Provider: "anthropic", - }, - { - Name: "gemini-3-pro", - DisplayName: "Gemini 3 Pro", - MaxTokens: 128000, - SupportsThinking: false, - SupportsImages: true, - Provider: "google", - }, - } - - return models, nil -} - -// ============================================================================ -// 状态查询 -// ============================================================================ - -// GetStatus 获取服务状态 -func (svc *LanguageServerService) GetStatus(ctx context.Context) (string, error) { - // TODO: 检查上游 API 连接状态 - return "running", nil -} - -// ============================================================================ -// 内部方法 -// ============================================================================ - -// callUpstreamAPI 通过 AntigravityGatewayService 调用上游 API。 -// 复用账号池调度、模型映射、TLS 指纹伪装、token 刷新和重试逻辑。 -func (svc *LanguageServerService) callUpstreamAPI( - ctx context.Context, - session *CascadeSession, - updateChan chan<- interface{}, -) { - if svc.antigravitySvc == nil || svc.accountRepo == nil { - updateChan <- map[string]interface{}{ - "type": "error", - "error": "antigravity gateway not configured", - } - return - } - - // 1. 选取第一个可用的 Antigravity 账号 - accounts, err := svc.accountRepo.ListByPlatform(ctx, PlatformAntigravity) - if err != nil || len(accounts) == 0 { - svc.logger.Error("no antigravity accounts available", "session_id", session.ID, "error", err) - updateChan <- map[string]interface{}{ - "type": "error", - "error": "no antigravity accounts available", - } - return - } - account := &accounts[0] - - // 2. 准备 Claude 格式请求体 - requestBody := map[string]interface{}{ - "model": session.ModelName, - "messages": session.Messages, - "stream": true, - "max_tokens": 8192, - } - bodyJSON, err := json.Marshal(requestBody) - if err != nil { - svc.logger.Error("failed to marshal request", "session_id", session.ID, "error", err) - updateChan <- map[string]interface{}{ - "type": "error", - "error": "failed to prepare request", - } - return - } - - svc.logger.Debug("forwarding via antigravity", "session_id", session.ID, "model", session.ModelName, "account_id", account.ID) - - // 3. 通过 AntigravityGatewayService 转发(完整 TLS 指纹 + token 刷新 + 重试) - respBody, statusCode, err := svc.antigravitySvc.ForwardRaw(ctx, account, bodyJSON) - if err != nil { - svc.logger.Error("upstream request failed", "session_id", session.ID, "error", err) - updateChan <- map[string]interface{}{ - "type": "error", - "error": fmt.Sprintf("upstream request failed: %v", err), - } - return - } - defer func() { _ = respBody.Close() }() - - // 4. 处理错误响应 - if statusCode >= 400 { - body, _ := io.ReadAll(io.LimitReader(respBody, 2<<20)) - svc.logger.Error("upstream error response", "session_id", session.ID, "status_code", statusCode, "body", string(body)) - updateChan <- map[string]interface{}{ - "type": "error", - "status_code": statusCode, - "error": string(body), - } - return - } - - // 5. 流式转发响应 - svc.streamUpstreamResponse(ctx, session.ID, respBody, updateChan) -} - -// streamUpstreamResponse 处理上游 SSE 流式响应 -func (svc *LanguageServerService) streamUpstreamResponse( - ctx context.Context, - sessionID string, - body io.ReadCloser, - updateChan chan<- interface{}, -) { - scanner := bufio.NewScanner(body) - // 设置合理的缓冲区大小 - scanner.Buffer(make([]byte, 64*1024), 512*1024) - - for scanner.Scan() { - select { - case <-ctx.Done(): - svc.logger.Info("streaming cancelled", "session_id", sessionID) - return - default: - } - - line := strings.TrimSpace(scanner.Text()) - - // 跳过空行 - if line == "" { - continue - } - - // 跳过注释行 - if strings.HasPrefix(line, ":") { - continue - } - - // 解析 SSE 格式 (data: {...}) - if !strings.HasPrefix(line, "data:") { - continue - } - - eventData := strings.TrimPrefix(line, "data:") - eventData = strings.TrimSpace(eventData) - - // 解析 JSON - var event map[string]interface{} - if err := json.Unmarshal([]byte(eventData), &event); err != nil { - svc.logger.Debug("failed to parse event", - "session_id", sessionID, - "error", err, - "data", eventData, - ) - continue - } - - // 发送事件到客户端通道 - select { - case updateChan <- event: - case <-ctx.Done(): - return - case <-time.After(5 * time.Second): - svc.logger.Warn("channel send timeout", - "session_id", sessionID, - ) - return - } - } - - if err := scanner.Err(); err != nil { - svc.logger.Error("scanning upstream response failed", - "session_id", sessionID, - "error", err, - ) - } -} - -// getCurrentTimeMS 获取当前时间戳(毫秒) -func getCurrentTimeMS() int64 { - return time.Now().UnixMilli() -} diff --git a/backend/internal/service/lsrpc_handler.go b/backend/internal/service/lsrpc_handler.go deleted file mode 100644 index 29cfd92d..00000000 --- a/backend/internal/service/lsrpc_handler.go +++ /dev/null @@ -1,353 +0,0 @@ -package service - -import ( - "context" - "fmt" - "io/fs" - "log/slog" - "net/http" - "os" - "path/filepath" - "time" - - connect "connectrpc.com/connect" - "github.com/Wei-Shaw/sub2api/internal/gen/language_server_pb" - "github.com/Wei-Shaw/sub2api/internal/gen/language_server_pbconnect" - "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" - "google.golang.org/protobuf/types/known/timestamppb" -) - -const upstreamLSRPCBaseURL = "https://cloudcode-pa.googleapis.com" - -// LSRPCHandler implements LanguageServerServiceHandler by proxying to the real upstream -// lsrpc service using OAuth tokens obtained from AntigravityGatewayService. -// File RPCs (ReadFile/WriteFile/ReadDir/etc.) operate on the local filesystem. -type LSRPCHandler struct { - language_server_pbconnect.UnimplementedLanguageServerServiceHandler - - antigravitySvc *AntigravityGatewayService - accountRepo AccountRepository - logger *slog.Logger -} - -// NewLSRPCHandler creates a new LSRPCHandler. -func NewLSRPCHandler( - antigravitySvc *AntigravityGatewayService, - accountRepo AccountRepository, - logger *slog.Logger, -) *LSRPCHandler { - if logger == nil { - logger = slog.Default() - } - return &LSRPCHandler{ - antigravitySvc: antigravitySvc, - accountRepo: accountRepo, - logger: logger, - } -} - -// upstreamClient creates a connectrpc client to the real lsrpc upstream, -// authenticated with the OAuth token from the given account. -func (h *LSRPCHandler) upstreamClient(ctx context.Context) (language_server_pbconnect.LanguageServerServiceClient, error) { - accounts, err := h.accountRepo.ListByPlatform(ctx, PlatformAntigravity) - if err != nil || len(accounts) == 0 { - return nil, fmt.Errorf("no antigravity accounts available: %w", err) - } - account := &accounts[0] - - tokenProvider := h.antigravitySvc.GetTokenProvider() - if tokenProvider == nil { - return nil, fmt.Errorf("antigravity token provider not configured") - } - accessToken, err := tokenProvider.GetAccessToken(ctx, account) - if err != nil { - return nil, fmt.Errorf("failed to get access token: %w", err) - } - - httpClient := &http.Client{ - Timeout: 5 * time.Minute, - Transport: &bearerTransport{ - base: http.DefaultTransport, - token: accessToken, - }, - } - - client := language_server_pbconnect.NewLanguageServerServiceClient( - httpClient, - upstreamLSRPCBaseURL, - connect.WithGRPC(), - ) - return client, nil -} - -// bearerTransport injects Authorization: Bearer into every request. -type bearerTransport struct { - base http.RoundTripper - token string -} - -func (t *bearerTransport) RoundTrip(req *http.Request) (*http.Response, error) { - clone := req.Clone(req.Context()) - clone.Header.Set("Authorization", "Bearer "+t.token) - return t.base.RoundTrip(clone) -} - -// ============================================================================ -// Cascade RPCs — proxied to real upstream -// ============================================================================ - -func (h *LSRPCHandler) StartCascade( - ctx context.Context, - req *connect.Request[language_server_pb.StartCascadeRequest], -) (*connect.Response[language_server_pb.StartCascadeResponse], error) { - client, err := h.upstreamClient(ctx) - if err != nil { - return nil, connect.NewError(connect.CodeUnavailable, err) - } - return client.StartCascade(ctx, req) -} - -func (h *LSRPCHandler) SendUserCascadeMessage( - ctx context.Context, - req *connect.Request[language_server_pb.SendUserCascadeMessageRequest], - stream *connect.ServerStream[language_server_pb.CascadeReactiveUpdate], -) error { - client, err := h.upstreamClient(ctx) - if err != nil { - return connect.NewError(connect.CodeUnavailable, err) - } - - upstreamStream, err := client.SendUserCascadeMessage(ctx, req) - if err != nil { - return err - } - defer upstreamStream.Close() - - for upstreamStream.Receive() { - if err := stream.Send(upstreamStream.Msg()); err != nil { - return err - } - } - return upstreamStream.Err() -} - -func (h *LSRPCHandler) CancelCascadeInvocation( - ctx context.Context, - req *connect.Request[language_server_pb.CancelCascadeInvocationRequest], -) (*connect.Response[language_server_pb.CancelCascadeInvocationResponse], error) { - client, err := h.upstreamClient(ctx) - if err != nil { - return nil, connect.NewError(connect.CodeUnavailable, err) - } - return client.CancelCascadeInvocation(ctx, req) -} - -func (h *LSRPCHandler) AcknowledgeCascadeCodeEdit( - ctx context.Context, - req *connect.Request[language_server_pb.AcknowledgeCascadeCodeEditRequest], -) (*connect.Response[language_server_pb.AcknowledgeCascadeCodeEditResponse], error) { - client, err := h.upstreamClient(ctx) - if err != nil { - return nil, connect.NewError(connect.CodeUnavailable, err) - } - return client.AcknowledgeCascadeCodeEdit(ctx, req) -} - -// ============================================================================ -// Model config RPCs — proxied to real upstream -// ============================================================================ - -func (h *LSRPCHandler) GetCascadeModelConfigs( - ctx context.Context, - req *connect.Request[language_server_pb.GetCascadeModelConfigsRequest], -) (*connect.Response[language_server_pb.GetCascadeModelConfigsResponse], error) { - client, err := h.upstreamClient(ctx) - if err != nil { - // Fall back to static list when upstream unavailable. - return connect.NewResponse(&language_server_pb.GetCascadeModelConfigsResponse{ - Models: staticCascadeModels(), - }), nil - } - resp, err := client.GetCascadeModelConfigs(ctx, req) - if err != nil { - return connect.NewResponse(&language_server_pb.GetCascadeModelConfigsResponse{ - Models: staticCascadeModels(), - }), nil - } - return resp, nil -} - -func (h *LSRPCHandler) GetCommandModelConfigs( - ctx context.Context, - req *connect.Request[language_server_pb.GetCommandModelConfigsRequest], -) (*connect.Response[language_server_pb.GetCommandModelConfigsResponse], error) { - client, err := h.upstreamClient(ctx) - if err != nil { - return connect.NewResponse(&language_server_pb.GetCommandModelConfigsResponse{ - Models: staticCascadeModels(), - }), nil - } - resp, err := client.GetCommandModelConfigs(ctx, req) - if err != nil { - return connect.NewResponse(&language_server_pb.GetCommandModelConfigsResponse{ - Models: staticCascadeModels(), - }), nil - } - return resp, nil -} - -// staticCascadeModels returns a hard-coded model list as fallback. -func staticCascadeModels() []*language_server_pb.ModelConfig { - return []*language_server_pb.ModelConfig{ - {Name: "claude-opus-4-7", DisplayName: "Claude Opus 4.7", MaxTokens: 200000, SupportsThinking: true, ThinkingBudget: 32000, SupportsImages: true, Provider: "anthropic"}, - {Name: "claude-opus-4-6", DisplayName: "Claude Opus 4.6", MaxTokens: 200000, SupportsThinking: true, ThinkingBudget: 32000, SupportsImages: true, Provider: "anthropic"}, - {Name: "claude-sonnet-4-6", DisplayName: "Claude Sonnet 4.6", MaxTokens: 200000, SupportsImages: true, Provider: "anthropic"}, - {Name: "claude-haiku-4-5", DisplayName: "Claude Haiku 4.5", MaxTokens: 200000, SupportsImages: true, Provider: "anthropic"}, - } -} - -// ============================================================================ -// File RPCs — local filesystem implementation -// ============================================================================ - -func (h *LSRPCHandler) ReadFile( - ctx context.Context, - req *connect.Request[language_server_pb.ReadFileRequest], -) (*connect.Response[language_server_pb.ReadFileResponse], error) { - path := req.Msg.GetPath() - data, err := os.ReadFile(path) - if err != nil { - if os.IsNotExist(err) { - return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("file not found: %s", path)) - } - return nil, connect.NewError(connect.CodeInternal, err) - } - return connect.NewResponse(&language_server_pb.ReadFileResponse{ - Content: string(data), - }), nil -} - -func (h *LSRPCHandler) WriteFile( - ctx context.Context, - req *connect.Request[language_server_pb.WriteFileRequest], -) (*connect.Response[language_server_pb.WriteFileResponse], error) { - path := req.Msg.GetPath() - if req.Msg.GetCreateParent() { - if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { - return nil, connect.NewError(connect.CodeInternal, err) - } - } - if err := os.WriteFile(path, []byte(req.Msg.GetContent()), 0o644); err != nil { - return nil, connect.NewError(connect.CodeInternal, err) - } - return connect.NewResponse(&language_server_pb.WriteFileResponse{}), nil -} - -func (h *LSRPCHandler) ReadDir( - ctx context.Context, - req *connect.Request[language_server_pb.ReadDirRequest], -) (*connect.Response[language_server_pb.ReadDirResponse], error) { - path := req.Msg.GetPath() - entries, err := os.ReadDir(path) - if err != nil { - if os.IsNotExist(err) { - return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("directory not found: %s", path)) - } - return nil, connect.NewError(connect.CodeInternal, err) - } - - files := make([]*language_server_pb.FileInfo, 0, len(entries)) - for _, entry := range entries { - info, err := entry.Info() - if err != nil { - continue - } - files = append(files, fileInfoFromOS(entry.Name(), info)) - } - return connect.NewResponse(&language_server_pb.ReadDirResponse{ - Files: files, - }), nil -} - -func (h *LSRPCHandler) DeleteFileOrDirectory( - ctx context.Context, - req *connect.Request[language_server_pb.DeleteFileOrDirectoryRequest], -) (*connect.Response[language_server_pb.DeleteFileOrDirectoryResponse], error) { - path := req.Msg.GetPath() - if err := os.RemoveAll(path); err != nil { - return nil, connect.NewError(connect.CodeInternal, err) - } - return connect.NewResponse(&language_server_pb.DeleteFileOrDirectoryResponse{}), nil -} - -func (h *LSRPCHandler) StatUri( - ctx context.Context, - req *connect.Request[language_server_pb.StatUriRequest], -) (*connect.Response[language_server_pb.StatUriResponse], error) { - path := req.Msg.GetPath() - info, err := os.Stat(path) - if err != nil { - if os.IsNotExist(err) { - return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("path not found: %s", path)) - } - return nil, connect.NewError(connect.CodeInternal, err) - } - return connect.NewResponse(&language_server_pb.StatUriResponse{ - FileInfo: fileInfoFromOS(info.Name(), info), - }), nil -} - -func (h *LSRPCHandler) WatchDirectory( - ctx context.Context, - req *connect.Request[language_server_pb.WatchDirectoryRequest], - stream *connect.ServerStream[language_server_pb.WatchDirectoryResponse], -) error { - // Block until context is cancelled — real FS watching requires fsnotify which - // is not in the dependency graph yet. This satisfies the interface contract - // without crashing; the client will get an EOF when the connection closes. - <-ctx.Done() - return nil -} - -// ============================================================================ -// Health RPCs -// ============================================================================ - -func (h *LSRPCHandler) Heartbeat( - ctx context.Context, - req *connect.Request[language_server_pb.HeartbeatRequest], -) (*connect.Response[language_server_pb.HeartbeatResponse], error) { - return connect.NewResponse(&language_server_pb.HeartbeatResponse{ - Healthy: true, - Version: "sub2api", - }), nil -} - -func (h *LSRPCHandler) GetStatus( - ctx context.Context, - req *connect.Request[language_server_pb.GetStatusRequest], -) (*connect.Response[language_server_pb.GetStatusResponse], error) { - return connect.NewResponse(&language_server_pb.GetStatusResponse{ - Status: "running", - Version: antigravity.BaseURL, - }), nil -} - -// ============================================================================ -// Helpers -// ============================================================================ - -func fileInfoFromOS(name string, info fs.FileInfo) *language_server_pb.FileInfo { - t := language_server_pb.FileInfo_FILE - if info.IsDir() { - t = language_server_pb.FileInfo_DIRECTORY - } else if info.Mode()&os.ModeSymlink != 0 { - t = language_server_pb.FileInfo_SYMLINK - } - return &language_server_pb.FileInfo{ - Path: name, - Type: t, - Size: info.Size(), - ModifiedTime: timestamppb.New(info.ModTime()), - } -} diff --git a/backend/internal/service/openai_messages_dispatch_test.go b/backend/internal/service/openai_messages_dispatch_test.go index a625aadd..7e3350f4 100644 --- a/backend/internal/service/openai_messages_dispatch_test.go +++ b/backend/internal/service/openai_messages_dispatch_test.go @@ -1,8 +1,10 @@ package service -import "testing" +import ( + "testing" -import "github.com/stretchr/testify/require" + "github.com/stretchr/testify/require" +) func TestNormalizeOpenAIMessagesDispatchModelConfig(t *testing.T) { t.Parallel() diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index a7279e6a..193c0430 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -8,8 +8,6 @@ import ( "encoding/base64" "encoding/hex" "fmt" - infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" - "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "image" "image/color" stddraw "image/draw" @@ -24,6 +22,9 @@ import ( "sync" "time" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + xdraw "golang.org/x/image/draw" "golang.org/x/sync/singleflight" ) diff --git a/backend/internal/service/windsurf_chat_service_test.go b/backend/internal/service/windsurf_chat_service_test.go index 80d7536e..53984a75 100644 --- a/backend/internal/service/windsurf_chat_service_test.go +++ b/backend/internal/service/windsurf_chat_service_test.go @@ -80,10 +80,10 @@ func TestInjectModelIdentity(t *testing.T) { wantInjected bool }{ { - name: "anthropic model without system", - messages: []windsurf.ChatMessage{{Role: "user", Content: "hi"}}, - meta: &windsurf.ModelMeta{Name: "claude-sonnet-4.6", Provider: "anthropic"}, - modelKey: "claude-sonnet-4.6", + name: "anthropic model without system", + messages: []windsurf.ChatMessage{{Role: "user", Content: "hi"}}, + meta: &windsurf.ModelMeta{Name: "claude-sonnet-4.6", Provider: "anthropic"}, + modelKey: "claude-sonnet-4.6", wantInjected: true, }, { @@ -111,10 +111,10 @@ func TestInjectModelIdentity(t *testing.T) { wantInjected: false, }, { - name: "openai model without system", - messages: []windsurf.ChatMessage{{Role: "user", Content: "hi"}}, - meta: &windsurf.ModelMeta{Name: "gpt-4o", Provider: "openai"}, - modelKey: "gpt-4o", + name: "openai model without system", + messages: []windsurf.ChatMessage{{Role: "user", Content: "hi"}}, + meta: &windsurf.ModelMeta{Name: "gpt-4o", Provider: "openai"}, + modelKey: "gpt-4o", wantInjected: true, }, } diff --git a/backend/internal/service/windsurf_gateway_service.go b/backend/internal/service/windsurf_gateway_service.go index 3a6f2a5e..8626e42d 100644 --- a/backend/internal/service/windsurf_gateway_service.go +++ b/backend/internal/service/windsurf_gateway_service.go @@ -609,13 +609,13 @@ type windsurfRequestTool struct { // ---- Helper functions (prefixed to avoid collision with windsurf_gateway_handler.go) ---- type windsurfContentBlock struct { - Type string `json:"type"` - Text string `json:"text,omitempty"` - ID string `json:"id,omitempty"` - Name string `json:"name,omitempty"` - Input interface{} `json:"input,omitempty"` - ToolUseID string `json:"tool_use_id,omitempty"` - Content json.RawMessage `json:"content,omitempty"` + Type string `json:"type"` + Text string `json:"text,omitempty"` + ID string `json:"id,omitempty"` + Name string `json:"name,omitempty"` + Input interface{} `json:"input,omitempty"` + ToolUseID string `json:"tool_use_id,omitempty"` + Content json.RawMessage `json:"content,omitempty"` // Source 来自 Anthropic image block:{type:"base64", media_type:"image/png", data:"..."} Source *windsurfContentImageSource `json:"source,omitempty"` } diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index 8af5c693..ed9241e1 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -494,7 +494,6 @@ var ProviderSet = wire.NewSet( NewPaymentService, ProvidePaymentOrderExpiryService, ProvideBalanceNotifyService, - ProvideLanguageServerService, ProvideWindsurfAuthService, ProvideWindsurfLSService, ProvideWindsurfChatService, @@ -507,11 +506,6 @@ var ProviderSet = wire.NewSet( NewChannelMonitorRequestTemplateService, ) -// ProvideLanguageServerService creates LanguageServerService with injected dependencies -func ProvideLanguageServerService(httpUpstream HTTPUpstream, antigravitySvc *AntigravityGatewayService, accountRepo AccountRepository) *LanguageServerService { - return NewLanguageServerService(slog.Default(), httpUpstream, antigravitySvc, accountRepo) -} - // ProvideWindsurfAuthService creates WindsurfAuthService from the main config. func ProvideWindsurfAuthService(cfg *config.Config, accountRepo AccountRepository, proxyRepo ProxyRepository, adminSvc AdminService) *WindsurfAuthService { if !cfg.Windsurf.Enabled {