chore: 删除 Antigravity 订制代码,回退至上游 v0.1.118
Some checks failed
CI / test (push) Failing after 3s
CI / frontend (push) Failing after 4s
CI / golangci-lint (push) Failing after 6s
CI / windsurf-platform (macos-latest) (push) Has been cancelled
CI / windsurf-platform (windows-latest) (push) Has been cancelled
Security Scan / backend-security (push) Failing after 3s
Security Scan / frontend-security (push) Failing after 3s
Some checks failed
CI / test (push) Failing after 3s
CI / frontend (push) Failing after 4s
CI / golangci-lint (push) Failing after 6s
CI / windsurf-platform (macos-latest) (push) Has been cancelled
CI / windsurf-platform (windows-latest) (push) Has been cancelled
Security Scan / backend-security (push) Failing after 3s
Security Scan / frontend-security (push) Failing after 3s
- 删除自定义文件:gateway_attribution, gateway_claude_runtime_headers, identity_service_antigravity, language_server_service, lsrpc_handler, antigravity_http handler/routes, 所有 antigravity 专项测试 - 将 antigravity pkg/service 文件回退至上游版本(移除 IsEnterprise、 claude_code_tool_map、dynamic fingerprint 等定制逻辑) - 修复 gateway_service.go:移除 NormalizeSystemPromptEnv、 generateSessionIDForAccount、applyClaudeRuntimeOptionalHeaders 调用, 使用上游的 session-id 同步逻辑 - 恢复 language_server_pb gen 文件(Windsurf local_ls.go 依赖) - 保留全部 Windsurf 集成代码不变
This commit is contained in:
parent
2064c1a19f
commit
898a65314c
@ -8,6 +8,11 @@ package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
@ -18,14 +23,9 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"log"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
|
||||
_ "github.com/Wei-Shaw/sub2api/ent/runtime"
|
||||
)
|
||||
|
||||
@ -257,9 +257,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
|
||||
adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService)
|
||||
apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig)
|
||||
langServerService := service.ProvideLanguageServerService(httpUpstream, antigravityGatewayService, accountRepository)
|
||||
lsrpcHandler := service.NewLSRPCHandler(antigravityGatewayService, accountRepository, nil)
|
||||
engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware, apiKeyService, subscriptionService, opsService, settingService, redisClient, langServerService, lsrpcHandler)
|
||||
engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware, apiKeyService, subscriptionService, opsService, settingService, redisClient)
|
||||
httpServer := server.ProvideHTTPServer(configConfig, engine)
|
||||
opsMetricsCollector := service.ProvideOpsMetricsCollector(opsRepository, settingRepository, accountRepository, concurrencyService, db, redisClient, configConfig)
|
||||
opsAggregationService := service.ProvideOpsAggregationService(opsRepository, settingRepository, db, redisClient, configConfig)
|
||||
@ -271,7 +269,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
||||
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
|
||||
paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService)
|
||||
<<<<<<< HEAD
|
||||
channelMonitorRunner := service.ProvideChannelMonitorRunner(channelMonitorService, settingService)
|
||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService, paymentOrderExpiryService, windsurfRefreshService, channelMonitorRunner, windsurfLSService)
|
||||
application := &Application{
|
||||
|
||||
@ -1,114 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
)
|
||||
|
||||
func repeatStr(s string, count int) string {
|
||||
return strings.Repeat(s, count)
|
||||
}
|
||||
|
||||
func main() {
|
||||
accessToken := flag.String("token", "", "OAuth access token")
|
||||
projectID := flag.String("project", "", "Project ID")
|
||||
proxyURL := flag.String("proxy", "", "Proxy URL (optional)")
|
||||
flag.Parse()
|
||||
|
||||
if *accessToken == "" || *projectID == "" {
|
||||
log.Fatal("missing required flags: -token and -project")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
client, err := antigravity.NewClient(*proxyURL)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create client: %v", err)
|
||||
}
|
||||
|
||||
fmt.Println(repeatStr("=", 80))
|
||||
fmt.Println("Antigravity Privacy Setup Diagnostic Test")
|
||||
fmt.Println(repeatStr("=", 80))
|
||||
|
||||
// Step 1: Verify token is valid by fetching user info
|
||||
fmt.Println("\n[Step 1] Verifying access token...")
|
||||
userInfo, err := client.GetUserInfo(ctx, *accessToken)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to get user info: %v", err)
|
||||
}
|
||||
fmt.Printf("✓ Email: %s\n", userInfo.Email)
|
||||
|
||||
// Step 2: Call SetUserSettings
|
||||
fmt.Println("\n[Step 2] Calling SetUserSettings (clear privacy settings)...")
|
||||
setResp, err := client.SetUserSettings(ctx, *accessToken)
|
||||
if err != nil {
|
||||
log.Fatalf("SetUserSettings failed: %v", err)
|
||||
}
|
||||
|
||||
if setResp.IsSuccess() {
|
||||
fmt.Println("✓ SetUserSettings succeeded")
|
||||
fmt.Printf(" Response: %+v\n", setResp)
|
||||
} else {
|
||||
fmt.Println("✗ SetUserSettings returned non-empty userSettings")
|
||||
fmt.Printf(" Response: %+v\n", setResp)
|
||||
fmt.Println("\n ERROR: This indicates privacy settings were NOT cleared!")
|
||||
fmt.Println(" Possible causes:")
|
||||
fmt.Println(" 1. Account restrictions on privacy settings")
|
||||
fmt.Println(" 2. Account still has telemetryEnabled=true")
|
||||
fmt.Println(" 3. API response indicates settings persist")
|
||||
}
|
||||
|
||||
// Step 3: Verify by calling FetchUserInfo
|
||||
fmt.Println("\n[Step 3] Calling FetchUserInfo to verify privacy status...")
|
||||
userInfoResp, err := client.FetchUserInfo(ctx, *accessToken, *projectID)
|
||||
if err != nil {
|
||||
log.Fatalf("FetchUserInfo failed: %v", err)
|
||||
}
|
||||
|
||||
if userInfoResp.IsPrivate() {
|
||||
fmt.Println("✓ Privacy is properly set (userSettings is empty)")
|
||||
fmt.Printf(" Response: %+v\n", userInfoResp)
|
||||
} else {
|
||||
fmt.Println("✗ Privacy is NOT properly set (userSettings contains telemetryEnabled)")
|
||||
fmt.Printf(" Response: %+v\n", userInfoResp)
|
||||
fmt.Println("\n ERROR: This explains the 503 errors in gateway!")
|
||||
fmt.Println(" Reason: Antigravity API rejects requests from accounts with")
|
||||
fmt.Println(" telemetryEnabled=true to protect user privacy")
|
||||
}
|
||||
|
||||
// Summary
|
||||
fmt.Println("\n" + repeatStr("=", 80))
|
||||
fmt.Println("DIAGNOSIS SUMMARY")
|
||||
fmt.Println(repeatStr("=", 80))
|
||||
|
||||
if setResp.IsSuccess() && userInfoResp.IsPrivate() {
|
||||
fmt.Println("✓ Privacy setup is SUCCESSFUL")
|
||||
fmt.Println(" This account should NOT experience 503 errors due to privacy")
|
||||
fmt.Println(" The 503 errors might be due to:")
|
||||
fmt.Println(" 1. Temporary API outages")
|
||||
fmt.Println(" 2. Rate limiting on new accounts")
|
||||
fmt.Println(" 3. Other infrastructure issues")
|
||||
} else if !setResp.IsSuccess() && !userInfoResp.IsPrivate() {
|
||||
fmt.Println("✗ Privacy setup FAILED")
|
||||
fmt.Println(" The account cannot clear privacy settings on Antigravity")
|
||||
fmt.Println(" This causes the 503 Service Unavailable errors")
|
||||
fmt.Println("\nSOLUTION:")
|
||||
fmt.Println(" 1. Check if this is a restricted account type")
|
||||
fmt.Println(" 2. Try re-authorizing the account")
|
||||
fmt.Println(" 3. Check Antigravity API rate limiting")
|
||||
fmt.Println(" 4. Inspect firewall/proxy settings")
|
||||
} else {
|
||||
fmt.Println("⚠ INCONSISTENT STATE:")
|
||||
fmt.Println(" SetUserSettings and FetchUserInfo returned different results")
|
||||
fmt.Println(" This might indicate a transient API issue or data sync delay")
|
||||
}
|
||||
|
||||
fmt.Println("\n" + repeatStr("=", 80))
|
||||
}
|
||||
@ -1,316 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
)
|
||||
|
||||
// TestScenario 定义一个测试场景
|
||||
type TestScenario struct {
|
||||
name string
|
||||
description string
|
||||
testFunc func(ctx context.Context, token, projectID string) (bool, string)
|
||||
}
|
||||
|
||||
var scenarios []TestScenario
|
||||
|
||||
func init() {
|
||||
scenarios = []TestScenario{
|
||||
{
|
||||
name: "single_request",
|
||||
description: "单次请求 - 检查是否立即成功",
|
||||
testFunc: testSingleRequest,
|
||||
},
|
||||
{
|
||||
name: "sequential_requests",
|
||||
description: "顺序发送 10 个请求 - 找到稳定点",
|
||||
testFunc: testSequentialRequests,
|
||||
},
|
||||
{
|
||||
name: "concurrent_requests",
|
||||
description: "并发发送 5 个请求 - 检查并发初始化行为",
|
||||
testFunc: testConcurrentRequests,
|
||||
},
|
||||
{
|
||||
name: "warmup_then_request",
|
||||
description: "预热(模型列表请求) + 业务请求 - 验证预热效果",
|
||||
testFunc: testWarmupThenRequest,
|
||||
},
|
||||
{
|
||||
name: "delayed_request",
|
||||
description: "延迟 5 秒后请求 - 检查账号初始化时间",
|
||||
testFunc: testDelayedRequest,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// testSingleRequest 单次请求
|
||||
func testSingleRequest(ctx context.Context, token, projectID string) (bool, string) {
|
||||
client, err := antigravity.NewClient("")
|
||||
if err != nil {
|
||||
return false, fmt.Sprintf("创建客户端失败: %v", err)
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
resp, _, err := client.FetchAvailableModels(ctx, token, projectID)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
return false, fmt.Sprintf("请求失败 (%v): %v", elapsed, err)
|
||||
}
|
||||
|
||||
if resp == nil {
|
||||
return false, fmt.Sprintf("响应为空 (%v)", elapsed)
|
||||
}
|
||||
|
||||
return true, fmt.Sprintf("✓ 单次请求成功 - 耗时 %v", elapsed)
|
||||
}
|
||||
|
||||
// testSequentialRequests 顺序发送多个请求
|
||||
func testSequentialRequests(ctx context.Context, token, projectID string) (bool, string) {
|
||||
client, err := antigravity.NewClient("")
|
||||
if err != nil {
|
||||
return false, fmt.Sprintf("创建客户端失败: %v", err)
|
||||
}
|
||||
|
||||
var firstFailIdx = -1
|
||||
var firstSuccessIdx = -1
|
||||
var timings []time.Duration
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
start := time.Now()
|
||||
resp, _, err := client.FetchAvailableModels(ctx, token, projectID)
|
||||
elapsed := time.Since(start)
|
||||
timings = append(timings, elapsed)
|
||||
|
||||
success := err == nil && resp != nil
|
||||
fmt.Printf(" [%d] 耗时: %6v, 状态: %v\n", i+1, elapsed, map[bool]string{true: "✓", false: "✗"}[success])
|
||||
|
||||
if !success && firstFailIdx == -1 {
|
||||
firstFailIdx = i
|
||||
}
|
||||
if success && firstSuccessIdx == -1 {
|
||||
firstSuccessIdx = i
|
||||
}
|
||||
}
|
||||
|
||||
var report string
|
||||
if firstSuccessIdx == -1 {
|
||||
report = "✗ 全部失败"
|
||||
} else if firstSuccessIdx == 0 {
|
||||
report = fmt.Sprintf("✓ 首次即成功 (耗时 %v)", timings[0])
|
||||
} else {
|
||||
report = fmt.Sprintf("⚠ 第 %d 次才成功 (失败 %d 次), 首次耗时 %v",
|
||||
firstSuccessIdx+1, firstSuccessIdx, timings[firstSuccessIdx])
|
||||
}
|
||||
|
||||
return firstSuccessIdx >= 0, report
|
||||
}
|
||||
|
||||
// testConcurrentRequests 并发请求
|
||||
func testConcurrentRequests(ctx context.Context, token, projectID string) (bool, string) {
|
||||
client, err := antigravity.NewClient("")
|
||||
if err != nil {
|
||||
return false, fmt.Sprintf("创建客户端失败: %v", err)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
results := make([]bool, 5)
|
||||
timings := make([]time.Duration, 5)
|
||||
mu := sync.Mutex{}
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
start := time.Now()
|
||||
resp, _, err := client.FetchAvailableModels(ctx, token, projectID)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
mu.Lock()
|
||||
results[idx] = err == nil && resp != nil
|
||||
timings[idx] = elapsed
|
||||
mu.Unlock()
|
||||
|
||||
fmt.Printf(" [%d] 耗时: %6v, 状态: %v\n", idx+1, elapsed, map[bool]string{true: "✓", false: "✗"}[results[idx]])
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
successCount := 0
|
||||
for _, ok := range results {
|
||||
if ok {
|
||||
successCount++
|
||||
}
|
||||
}
|
||||
|
||||
return successCount > 0, fmt.Sprintf("%d/5 并发请求成功", successCount)
|
||||
}
|
||||
|
||||
// testWarmupThenRequest 预热测试
|
||||
func testWarmupThenRequest(ctx context.Context, token, projectID string) (bool, string) {
|
||||
client, err := antigravity.NewClient("")
|
||||
if err != nil {
|
||||
return false, fmt.Sprintf("创建客户端失败: %v", err)
|
||||
}
|
||||
|
||||
// 第 1 步:预热 - 调用 LoadCodeAssist(获取项目信息)
|
||||
fmt.Println(" [Warmup] 调用 LoadCodeAssist 预热...")
|
||||
warmupStart := time.Now()
|
||||
_, _, warmupErr := client.LoadCodeAssist(ctx, token)
|
||||
warmupElapsed := time.Since(warmupStart)
|
||||
fmt.Printf(" [Warmup] 耗时 %v, 状态: %v\n", warmupElapsed, map[bool]string{true: "✓", false: "✗"}[warmupErr == nil])
|
||||
|
||||
// 第 2 步:实际请求
|
||||
fmt.Println(" [Request] 发送业务请求...")
|
||||
reqStart := time.Now()
|
||||
resp, _, err := client.FetchAvailableModels(ctx, token, projectID)
|
||||
reqElapsed := time.Since(reqStart)
|
||||
success := err == nil && resp != nil
|
||||
fmt.Printf(" [Request] 耗时 %v, 状态: %v\n", reqElapsed, map[bool]string{true: "✓", false: "✗"}[success])
|
||||
|
||||
return success, fmt.Sprintf("预热 %v + 请求 %v = 总耗时 %v",
|
||||
warmupElapsed, reqElapsed, warmupElapsed+reqElapsed)
|
||||
}
|
||||
|
||||
// testDelayedRequest 延迟请求
|
||||
func testDelayedRequest(ctx context.Context, token, projectID string) (bool, string) {
|
||||
client, err := antigravity.NewClient("")
|
||||
if err != nil {
|
||||
return false, fmt.Sprintf("创建客户端失败: %v", err)
|
||||
}
|
||||
|
||||
fmt.Println(" 等待 5 秒...")
|
||||
time.Sleep(5 * time.Second)
|
||||
|
||||
start := time.Now()
|
||||
resp, _, err := client.FetchAvailableModels(ctx, token, projectID)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
success := err == nil && resp != nil
|
||||
return success, fmt.Sprintf("延迟 5s 后请求 - 耗时 %v, 状态: %v", elapsed, map[bool]string{true: "✓", false: "✗"}[success])
|
||||
}
|
||||
|
||||
// testOAuthTokenRefresh OAuth Token 刷新测试
|
||||
func testOAuthTokenRefresh(ctx context.Context, refreshToken string) (bool, string) {
|
||||
client, err := antigravity.NewClient("")
|
||||
if err != nil {
|
||||
return false, fmt.Sprintf("创建客户端失败: %v", err)
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
tokenInfo, err := client.RefreshToken(ctx, refreshToken, false)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
return false, fmt.Sprintf("Token 刷新失败 (%v): %v", elapsed, err)
|
||||
}
|
||||
|
||||
return true, fmt.Sprintf("✓ Token 刷新成功 - 耗时 %v, 新 Token 有效期: %d 秒",
|
||||
elapsed, tokenInfo.ExpiresIn)
|
||||
}
|
||||
|
||||
// testAccountInitializationWarmup 账号初始化预热
|
||||
func testAccountInitializationWarmup(ctx context.Context, token, projectID string) (bool, string) {
|
||||
client, err := antigravity.NewClient("")
|
||||
if err != nil {
|
||||
return false, fmt.Sprintf("创建客户端失败: %v", err)
|
||||
}
|
||||
|
||||
fmt.Println(" 执行完整的账号初始化流程...")
|
||||
|
||||
// 1. GetUserInfo
|
||||
fmt.Println(" 1. GetUserInfo...")
|
||||
start := time.Now()
|
||||
_, err1 := client.GetUserInfo(ctx, token)
|
||||
fmt.Printf(" 耗时: %v\n", time.Since(start))
|
||||
|
||||
// 2. LoadCodeAssist
|
||||
fmt.Println(" 2. LoadCodeAssist...")
|
||||
start = time.Now()
|
||||
_, _, err2 := client.LoadCodeAssist(ctx, token)
|
||||
fmt.Printf(" 耗时: %v\n", time.Since(start))
|
||||
|
||||
// 3. FetchAvailableModels
|
||||
fmt.Println(" 3. FetchAvailableModels...")
|
||||
start = time.Now()
|
||||
_, _, err3 := client.FetchAvailableModels(ctx, token, projectID)
|
||||
elapsed := time.Since(start)
|
||||
fmt.Printf(" 耗时: %v\n", elapsed)
|
||||
|
||||
success := err1 == nil && err2 == nil && err3 == nil
|
||||
return success, fmt.Sprintf("账号初始化预热 - 状态: %v", map[bool]string{true: "✓", false: "✗"}[success])
|
||||
}
|
||||
|
||||
func main() {
|
||||
accessToken := flag.String("token", "", "OAuth access token")
|
||||
projectID := flag.String("project", "", "Project ID")
|
||||
refreshToken := flag.String("refresh", "", "Refresh token (optional)")
|
||||
testName := flag.String("test", "all", "测试名称 (all, single_request, sequential_requests, etc.)")
|
||||
flag.Parse()
|
||||
|
||||
if *accessToken == "" || *projectID == "" {
|
||||
log.Fatal("缺少必需参数: -token 和 -project")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
fmt.Println("\n" + repeatStr("=", 80))
|
||||
fmt.Println("Antigravity 账号初始化诊断测试套件")
|
||||
fmt.Println(repeatStr("=", 80) + "\n")
|
||||
|
||||
// Token 刷新测试
|
||||
if *refreshToken != "" {
|
||||
fmt.Println("[Token 刷新测试]")
|
||||
_, report := testOAuthTokenRefresh(ctx, *refreshToken)
|
||||
fmt.Printf("%s\n\n", report)
|
||||
}
|
||||
|
||||
// 账号初始化预热测试
|
||||
fmt.Println("[账号初始化预热]")
|
||||
_, report := testAccountInitializationWarmup(ctx, *accessToken, *projectID)
|
||||
fmt.Printf("%s\n\n", report)
|
||||
|
||||
// 运行指定的测试
|
||||
if *testName == "all" {
|
||||
for _, scenario := range scenarios {
|
||||
fmt.Printf("[%s]\n%s\n", scenario.name, scenario.description)
|
||||
_, report := scenario.testFunc(ctx, *accessToken, *projectID)
|
||||
fmt.Printf("结果: %s\n\n", report)
|
||||
}
|
||||
} else {
|
||||
found := false
|
||||
for _, scenario := range scenarios {
|
||||
if scenario.name == *testName {
|
||||
found = true
|
||||
fmt.Printf("[%s]\n%s\n", scenario.name, scenario.description)
|
||||
_, report := scenario.testFunc(ctx, *accessToken, *projectID)
|
||||
fmt.Printf("结果: %s\n\n", report)
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
log.Fatalf("未找到测试: %s", *testName)
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Println(repeatStr("=", 80))
|
||||
fmt.Println("诊断完成")
|
||||
fmt.Println(repeatStr("=", 80))
|
||||
}
|
||||
|
||||
func repeatStr(s string, count int) string {
|
||||
result := ""
|
||||
for i := 0; i < count; i++ {
|
||||
result += s
|
||||
}
|
||||
return result
|
||||
}
|
||||
@ -7,13 +7,14 @@
|
||||
package language_server_pb
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
sync "sync"
|
||||
unsafe "unsafe"
|
||||
|
||||
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
|
||||
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
|
||||
_ "google.golang.org/protobuf/types/known/emptypb"
|
||||
timestamppb "google.golang.org/protobuf/types/known/timestamppb"
|
||||
reflect "reflect"
|
||||
sync "sync"
|
||||
unsafe "unsafe"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@ -5,12 +5,13 @@
|
||||
package language_server_pbconnect
|
||||
|
||||
import (
|
||||
connect "connectrpc.com/connect"
|
||||
context "context"
|
||||
errors "errors"
|
||||
language_server_pb "github.com/Wei-Shaw/sub2api/internal/gen/language_server_pb"
|
||||
http "net/http"
|
||||
strings "strings"
|
||||
|
||||
connect "connectrpc.com/connect"
|
||||
language_server_pb "github.com/Wei-Shaw/sub2api/internal/gen/language_server_pb"
|
||||
)
|
||||
|
||||
// This is a compile-time assertion to ensure that this generated file and the connect package are
|
||||
|
||||
@ -15,8 +15,7 @@ func NewAntigravityOAuthHandler(antigravityOAuthService *service.AntigravityOAut
|
||||
}
|
||||
|
||||
type AntigravityGenerateAuthURLRequest struct {
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
IsEnterprise bool `json:"is_enterprise"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
}
|
||||
|
||||
// GenerateAuthURL generates Google OAuth authorization URL
|
||||
@ -28,7 +27,7 @@ func (h *AntigravityOAuthHandler) GenerateAuthURL(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.antigravityOAuthService.GenerateAuthURL(c.Request.Context(), req.ProxyID, req.IsEnterprise)
|
||||
result, err := h.antigravityOAuthService.GenerateAuthURL(c.Request.Context(), req.ProxyID)
|
||||
if err != nil {
|
||||
response.InternalError(c, "生成授权链接失败: "+err.Error())
|
||||
return
|
||||
@ -71,7 +70,6 @@ func (h *AntigravityOAuthHandler) ExchangeCode(c *gin.Context) {
|
||||
type AntigravityRefreshTokenRequest struct {
|
||||
RefreshToken string `json:"refresh_token" binding:"required"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
IsEnterprise bool `json:"is_enterprise"`
|
||||
}
|
||||
|
||||
// RefreshToken validates an Antigravity refresh token and returns full token info
|
||||
@ -83,7 +81,7 @@ func (h *AntigravityOAuthHandler) RefreshToken(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
tokenInfo, err := h.antigravityOAuthService.ValidateRefreshToken(c.Request.Context(), req.RefreshToken, req.ProxyID, req.IsEnterprise)
|
||||
tokenInfo, err := h.antigravityOAuthService.ValidateRefreshToken(c.Request.Context(), req.RefreshToken, req.ProxyID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
|
||||
@ -1,267 +0,0 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// AntigravityHTTPHandler 处理下游客户端的 HTTP 请求
|
||||
// 内部调用 LanguageServerService,再转发到上游 API
|
||||
type AntigravityHTTPHandler struct {
|
||||
langServerService *service.LanguageServerService
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func NewAntigravityHTTPHandler(
|
||||
langServerService *service.LanguageServerService,
|
||||
logger *slog.Logger,
|
||||
) *AntigravityHTTPHandler {
|
||||
return &AntigravityHTTPHandler{
|
||||
langServerService: langServerService,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Cascade 流程 API
|
||||
// ============================================================================
|
||||
|
||||
// StartCascadeRequest HTTP 请求格式
|
||||
type StartCascadeRequest struct {
|
||||
Model string `json:"model"` // 模型名称
|
||||
SystemPrompt string `json:"system_prompt"` // 系统提示
|
||||
Metadata map[string]string `json:"metadata"` // 设备指纹等伪装信息
|
||||
}
|
||||
|
||||
// StartCascadeResponse HTTP 响应格式
|
||||
type StartCascadeResponse struct {
|
||||
CascadeID string `json:"cascade_id"`
|
||||
}
|
||||
|
||||
// POST /api/v1/cascade/start
|
||||
// 启动新的 Cascade Agent 会话
|
||||
func (h *AntigravityHTTPHandler) StartCascade(c *gin.Context) {
|
||||
var req StartCascadeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
h.logger.Error("invalid request", "error", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "invalid request: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 提取 OAuth token
|
||||
token := c.GetHeader("Authorization")
|
||||
if token == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "missing authorization header",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 调用内部 LanguageServerService
|
||||
cascadeID, err := h.langServerService.StartCascade(
|
||||
c.Request.Context(),
|
||||
req.Model,
|
||||
req.SystemPrompt,
|
||||
req.Metadata,
|
||||
token,
|
||||
)
|
||||
if err != nil {
|
||||
h.logger.Error("start cascade failed", "error", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("cascade started", "cascade_id", cascadeID, "model", req.Model)
|
||||
|
||||
c.JSON(http.StatusOK, StartCascadeResponse{
|
||||
CascadeID: cascadeID,
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
|
||||
// SendUserMessageRequest HTTP 请求格式
|
||||
type SendUserMessageRequest struct {
|
||||
CascadeID string `json:"cascade_id"` // 会话 ID
|
||||
Message string `json:"message"` // 用户消息
|
||||
Context map[string]string `json:"context"` // 上下文(可选)
|
||||
}
|
||||
|
||||
// CascadeUpdate 流式响应格式(Server-Sent Events)
|
||||
type CascadeUpdate struct {
|
||||
Type string `json:"type"` // "message_delta", "tool_call", etc.
|
||||
Payload string `json:"payload"` // JSON 格式的负载
|
||||
}
|
||||
|
||||
// POST /api/v1/cascade/message (流式)
|
||||
// 发送用户消息,接收流式更新
|
||||
func (h *AntigravityHTTPHandler) SendUserMessage(c *gin.Context) {
|
||||
var req SendUserMessageRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
h.logger.Error("invalid request", "error", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "invalid request: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 提取 OAuth token
|
||||
token := c.GetHeader("Authorization")
|
||||
if token == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "missing authorization header",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 设置 Server-Sent Events 响应头
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
|
||||
// 调用内部 LanguageServerService,获取流式响应
|
||||
updateChan, err := h.langServerService.SendUserMessage(
|
||||
c.Request.Context(),
|
||||
req.CascadeID,
|
||||
req.Message,
|
||||
token,
|
||||
)
|
||||
if err != nil {
|
||||
h.logger.Error("send user message failed", "error", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 逐个推送更新到客户端(SSE)
|
||||
for update := range updateChan {
|
||||
data, _ := json.Marshal(update)
|
||||
c.SSEvent("update", string(data))
|
||||
c.Writer.Flush()
|
||||
}
|
||||
|
||||
h.logger.Info("cascade message processed", "cascade_id", req.CascadeID)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
|
||||
// POST /api/v1/cascade/cancel
|
||||
// 取消 Cascade 调用
|
||||
func (h *AntigravityHTTPHandler) CancelCascade(c *gin.Context) {
|
||||
var req struct {
|
||||
CascadeID string `json:"cascade_id"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "invalid request",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.langServerService.CancelCascade(
|
||||
c.Request.Context(),
|
||||
req.CascadeID,
|
||||
); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("cascade cancelled", "cascade_id", req.CascadeID)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 模型配置 API
|
||||
// ============================================================================
|
||||
|
||||
// ModelConfig 模型配置
|
||||
type ModelConfig struct {
|
||||
Name string `json:"name"`
|
||||
DisplayName string `json:"display_name"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
SupportsThinking bool `json:"supports_thinking"`
|
||||
ThinkingBudget int `json:"thinking_budget,omitempty"`
|
||||
SupportsImages bool `json:"supports_images"`
|
||||
Provider string `json:"provider"` // anthropic, google, openai
|
||||
}
|
||||
|
||||
// GET /api/v1/models
|
||||
// 获取可用模型列表
|
||||
func (h *AntigravityHTTPHandler) GetModels(c *gin.Context) {
|
||||
models, err := h.langServerService.GetAvailableModels(c.Request.Context())
|
||||
if err != nil {
|
||||
h.logger.Error("get models failed", "error", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"models": models,
|
||||
"default_model": "claude-opus-4-7",
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 健康检查 API
|
||||
// ============================================================================
|
||||
|
||||
// GET /api/v1/health
|
||||
// 健康检查
|
||||
func (h *AntigravityHTTPHandler) Health(c *gin.Context) {
|
||||
status, err := h.langServerService.GetStatus(c.Request.Context())
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"status": "unhealthy",
|
||||
"error": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": status,
|
||||
"version": "1.0.0",
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
|
||||
// RegisterRoutes 注册所有 HTTP 路由
|
||||
func (h *AntigravityHTTPHandler) RegisterRoutes(router *gin.Engine) {
|
||||
api := router.Group("/api/v1")
|
||||
|
||||
// Cascade 流程
|
||||
api.POST("/cascade/start", h.StartCascade)
|
||||
api.POST("/cascade/message", h.SendUserMessage)
|
||||
api.POST("/cascade/cancel", h.CancelCascade)
|
||||
|
||||
// 模型列表
|
||||
api.GET("/models", h.GetModels)
|
||||
|
||||
// 健康检查
|
||||
api.GET("/health", h.Health)
|
||||
|
||||
h.logger.Info("antigravity http handler registered",
|
||||
"routes", []string{
|
||||
"/api/v1/cascade/start",
|
||||
"/api/v1/cascade/message",
|
||||
"/api/v1/cascade/cancel",
|
||||
"/api/v1/models",
|
||||
"/api/v1/health",
|
||||
})
|
||||
}
|
||||
@ -270,19 +270,19 @@ func TestUserHandlerGetProfileReturnsLegacyCompatibilityFields(t *testing.T) {
|
||||
AvatarURL: "https://cdn.example.com/linuxdo.png",
|
||||
AvatarSource: "remote_url",
|
||||
},
|
||||
identities: []service.UserAuthIdentityRecord{
|
||||
{
|
||||
ProviderType: "linuxdo",
|
||||
ProviderKey: "linuxdo",
|
||||
ProviderSubject: "linuxdo-subject-21",
|
||||
VerifiedAt: &verifiedAt,
|
||||
Metadata: map[string]any{
|
||||
"username": "linuxdo-handle",
|
||||
"avatar_url": "https://cdn.example.com/linuxdo.png",
|
||||
},
|
||||
identities: []service.UserAuthIdentityRecord{
|
||||
{
|
||||
ProviderType: "linuxdo",
|
||||
ProviderKey: "linuxdo",
|
||||
ProviderSubject: "linuxdo-subject-21",
|
||||
VerifiedAt: &verifiedAt,
|
||||
Metadata: map[string]any{
|
||||
"username": "linuxdo-handle",
|
||||
"avatar_url": "https://cdn.example.com/linuxdo.png",
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
@ -1,70 +0,0 @@
|
||||
package antigravity
|
||||
|
||||
import "strings"
|
||||
|
||||
var claudeCodeBuiltinToolNameMap = map[string]string{
|
||||
"read": "Read",
|
||||
"read_file": "Read",
|
||||
"readfile": "Read",
|
||||
"write": "Write",
|
||||
"write_file": "Write",
|
||||
"writefile": "Write",
|
||||
"edit": "Edit",
|
||||
"apply_patch": "Edit",
|
||||
"applypatch": "Edit",
|
||||
"bash": "Bash",
|
||||
"execute_bash": "Bash",
|
||||
"executebash": "Bash",
|
||||
"exec_bash": "Bash",
|
||||
"execbash": "Bash",
|
||||
"glob": "Glob",
|
||||
"list_files": "Glob",
|
||||
"listfiles": "Glob",
|
||||
"grep": "Grep",
|
||||
"search_files": "Grep",
|
||||
"searchfiles": "Grep",
|
||||
"webfetch": "WebFetch",
|
||||
"web_fetch": "WebFetch",
|
||||
"fetch": "WebFetch",
|
||||
"websearch": "WebSearch",
|
||||
"web_search": "WebSearch",
|
||||
"agent": "Agent",
|
||||
"askuserquestion": "AskUserQuestion",
|
||||
"ask_user_question": "AskUserQuestion",
|
||||
"enterplanmode": "EnterPlanMode",
|
||||
"enter_plan_mode": "EnterPlanMode",
|
||||
"exitplanmode": "ExitPlanMode",
|
||||
"exit_plan_mode": "ExitPlanMode",
|
||||
"croncreate": "CronCreate",
|
||||
"cron_create": "CronCreate",
|
||||
"crondelete": "CronDelete",
|
||||
"cron_delete": "CronDelete",
|
||||
"schedulewakeup": "ScheduleWakeup",
|
||||
"schedule_wakeup": "ScheduleWakeup",
|
||||
"sendmessage": "SendMessage",
|
||||
"send_message": "SendMessage",
|
||||
"skill": "Skill",
|
||||
"taskcreate": "TaskCreate",
|
||||
"task_create": "TaskCreate",
|
||||
"tasklist": "TaskList",
|
||||
"task_list": "TaskList",
|
||||
"taskoutput": "TaskOutput",
|
||||
"task_output": "TaskOutput",
|
||||
"taskstop": "TaskStop",
|
||||
"task_stop": "TaskStop",
|
||||
"taskupdate": "TaskUpdate",
|
||||
"task_update": "TaskUpdate",
|
||||
}
|
||||
|
||||
func normalizeClaudeCodeToolName(name string) string {
|
||||
trimmed := strings.TrimSpace(name)
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
if mapped, ok := claudeCodeBuiltinToolNameMap[strings.ToLower(trimmed)]; ok {
|
||||
return mapped
|
||||
}
|
||||
|
||||
return trimmed
|
||||
}
|
||||
@ -1,160 +0,0 @@
|
||||
package antigravity
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNormalizeClaudeCodeToolName(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{name: "read alias", input: "read_file", expected: "Read"},
|
||||
{name: "grep alias", input: "search_files", expected: "Grep"},
|
||||
{name: "webfetch alias", input: "fetch", expected: "WebFetch"},
|
||||
{name: "plan alias", input: "enter_plan_mode", expected: "EnterPlanMode"},
|
||||
{name: "native passthrough", input: "TaskUpdate", expected: "TaskUpdate"},
|
||||
{name: "mcp passthrough", input: "mcp__github__list_prs", expected: "mcp__github__list_prs"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Equal(t, tt.expected, normalizeClaudeCodeToolName(tt.input))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPartsNormalizesClaudeCodeToolNames(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
toolIDToName := make(map[string]string)
|
||||
assistantParts, stripped, err := buildParts(json.RawMessage(`[
|
||||
{"type":"tool_use","id":"tool-1","name":"read_file","input":{"file_path":"/tmp/demo.txt"}}
|
||||
]`), toolIDToName, false)
|
||||
require.NoError(t, err)
|
||||
require.False(t, stripped)
|
||||
require.Len(t, assistantParts, 1)
|
||||
require.NotNil(t, assistantParts[0].FunctionCall)
|
||||
require.Equal(t, "Read", assistantParts[0].FunctionCall.Name)
|
||||
require.Equal(t, "Read", toolIDToName["tool-1"])
|
||||
|
||||
userParts, stripped, err := buildParts(json.RawMessage(`[
|
||||
{"type":"tool_result","tool_use_id":"tool-1","content":[{"type":"text","text":"ok"}]}
|
||||
]`), toolIDToName, false)
|
||||
require.NoError(t, err)
|
||||
require.False(t, stripped)
|
||||
require.Len(t, userParts, 1)
|
||||
require.NotNil(t, userParts[0].FunctionResponse)
|
||||
require.Equal(t, "Read", userParts[0].FunctionResponse.Name)
|
||||
}
|
||||
|
||||
func TestBuildToolsNormalizesClaudeCodeBuiltinNamesOnly(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
result := buildTools([]ClaudeTool{
|
||||
{
|
||||
Name: "search_files",
|
||||
Description: "Search the workspace",
|
||||
InputSchema: map[string]any{
|
||||
"type": "object",
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "mcp__github__list_prs",
|
||||
Description: "List pull requests",
|
||||
InputSchema: map[string]any{
|
||||
"type": "object",
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
require.Len(t, result, 1)
|
||||
require.Len(t, result[0].FunctionDeclarations, 2)
|
||||
require.Equal(t, "Grep", result[0].FunctionDeclarations[0].Name)
|
||||
require.Equal(t, "mcp__github__list_prs", result[0].FunctionDeclarations[1].Name)
|
||||
}
|
||||
|
||||
func TestNonStreamingProcessorNormalizesClaudeCodeToolName(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
processor := NewNonStreamingProcessor()
|
||||
response := processor.Process(&GeminiResponse{
|
||||
Candidates: []GeminiCandidate{
|
||||
{
|
||||
Content: &GeminiContent{
|
||||
Parts: []GeminiPart{
|
||||
{
|
||||
FunctionCall: &GeminiFunctionCall{
|
||||
Name: "web_fetch",
|
||||
Args: map[string]any{"url": "https://example.com"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
FinishReason: "STOP",
|
||||
},
|
||||
},
|
||||
}, "resp-1", "claude-sonnet-4-5")
|
||||
|
||||
require.Len(t, response.Content, 1)
|
||||
require.Equal(t, "tool_use", response.Content[0].Type)
|
||||
require.Equal(t, "WebFetch", response.Content[0].Name)
|
||||
require.True(t, strings.HasPrefix(response.Content[0].ID, "WebFetch-"))
|
||||
require.NotNil(t, response.Content[0].Caller)
|
||||
require.Equal(t, "direct", response.Content[0].Caller.Type)
|
||||
require.Equal(t, "tool_use", response.StopReason)
|
||||
}
|
||||
|
||||
func TestStreamingProcessorNormalizesClaudeCodeToolName(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
processor := NewStreamingProcessor("claude-sonnet-4-5")
|
||||
output := processor.processFunctionCall(&GeminiFunctionCall{
|
||||
Name: "search_files",
|
||||
Args: map[string]any{"pattern": "TODO"},
|
||||
}, "")
|
||||
|
||||
events := parseSSEDataEvents(t, output)
|
||||
require.Len(t, events, 3)
|
||||
|
||||
contentBlock, ok := events[0]["content_block"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "tool_use", contentBlock["type"])
|
||||
require.Equal(t, "Grep", contentBlock["name"])
|
||||
|
||||
toolID, ok := contentBlock["id"].(string)
|
||||
require.True(t, ok)
|
||||
require.True(t, strings.HasPrefix(toolID, "Grep-"))
|
||||
|
||||
caller, ok := contentBlock["caller"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "direct", caller["type"])
|
||||
}
|
||||
|
||||
func parseSSEDataEvents(t *testing.T, payload []byte) []map[string]any {
|
||||
t.Helper()
|
||||
|
||||
lines := strings.Split(string(payload), "\n")
|
||||
events := make([]map[string]any, 0)
|
||||
|
||||
for _, line := range lines {
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
|
||||
var event map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(strings.TrimPrefix(line, "data: ")), &event))
|
||||
events = append(events, event)
|
||||
}
|
||||
|
||||
return events
|
||||
}
|
||||
@ -16,7 +16,6 @@ type ClaudeRequest struct {
|
||||
TopK *int `json:"top_k,omitempty"`
|
||||
Tools []ClaudeTool `json:"tools,omitempty"`
|
||||
Thinking *ThinkingConfig `json:"thinking,omitempty"`
|
||||
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
||||
Metadata *ClaudeMetadata `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
@ -73,10 +72,9 @@ type ContentBlock struct {
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
Signature string `json:"signature,omitempty"`
|
||||
// tool_use
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Input any `json:"input,omitempty"`
|
||||
Caller *ToolCaller `json:"caller,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Input any `json:"input,omitempty"`
|
||||
// tool_result
|
||||
ToolUseID string `json:"tool_use_id,omitempty"`
|
||||
Content json.RawMessage `json:"content,omitempty"`
|
||||
@ -116,15 +114,9 @@ type ClaudeContentItem struct {
|
||||
Signature string `json:"signature,omitempty"`
|
||||
|
||||
// tool_use
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Input any `json:"input,omitempty"`
|
||||
Caller *ToolCaller `json:"caller,omitempty"`
|
||||
}
|
||||
|
||||
// ToolCaller Claude Code tool_use 调用来源
|
||||
type ToolCaller struct {
|
||||
Type string `json:"type"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Input any `json:"input,omitempty"`
|
||||
}
|
||||
|
||||
// ClaudeUsage Claude 用量统计
|
||||
|
||||
@ -318,17 +318,16 @@ func shouldFallbackToNextURL(err error, statusCode int) bool {
|
||||
statusCode >= 500
|
||||
}
|
||||
|
||||
// ExchangeCode 用 authorization code 交换 token。
|
||||
// isEnterprise=true 时使用企业 OAuth client_id/secret。
|
||||
func (c *Client) ExchangeCode(ctx context.Context, code, codeVerifier string, isEnterprise bool) (*TokenResponse, error) {
|
||||
creds, err := GetClientCredentials(isEnterprise)
|
||||
// ExchangeCode 用 authorization code 交换 token
|
||||
func (c *Client) ExchangeCode(ctx context.Context, code, codeVerifier string) (*TokenResponse, error) {
|
||||
clientSecret, err := getClientSecret()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
params := url.Values{}
|
||||
params.Set("client_id", creds.ClientID)
|
||||
params.Set("client_secret", creds.ClientSecret)
|
||||
params.Set("client_id", ClientID)
|
||||
params.Set("client_secret", clientSecret)
|
||||
params.Set("code", code)
|
||||
params.Set("redirect_uri", RedirectURI)
|
||||
params.Set("grant_type", "authorization_code")
|
||||
@ -363,17 +362,16 @@ func (c *Client) ExchangeCode(ctx context.Context, code, codeVerifier string, is
|
||||
return &tokenResp, nil
|
||||
}
|
||||
|
||||
// RefreshToken 刷新 access_token。
|
||||
// isEnterprise=true 时使用企业 OAuth client_id/secret。
|
||||
func (c *Client) RefreshToken(ctx context.Context, refreshToken string, isEnterprise bool) (*TokenResponse, error) {
|
||||
creds, err := GetClientCredentials(isEnterprise)
|
||||
// RefreshToken 刷新 access_token
|
||||
func (c *Client) RefreshToken(ctx context.Context, refreshToken string) (*TokenResponse, error) {
|
||||
clientSecret, err := getClientSecret()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
params := url.Values{}
|
||||
params.Set("client_id", creds.ClientID)
|
||||
params.Set("client_secret", creds.ClientSecret)
|
||||
params.Set("client_id", ClientID)
|
||||
params.Set("client_secret", clientSecret)
|
||||
params.Set("refresh_token", refreshToken)
|
||||
params.Set("grant_type", "refresh_token")
|
||||
|
||||
@ -406,39 +404,6 @@ func (c *Client) RefreshToken(ctx context.Context, refreshToken string, isEnterp
|
||||
return &tokenResp, nil
|
||||
}
|
||||
|
||||
// RefreshTokenAuto 自动判定账号类型。
|
||||
// 先用个人凭证刷新;若 Google 返回 invalid_client/unauthorized_client(client 不匹配),
|
||||
// 再用企业凭证重试。返回 token 和最终判定的 isEnterprise 标志。
|
||||
//
|
||||
// 其他错误(invalid_grant、网络错误等)直接返回,不重试。
|
||||
func (c *Client) RefreshTokenAuto(ctx context.Context, refreshToken string) (*TokenResponse, bool, error) {
|
||||
tok, err := c.RefreshToken(ctx, refreshToken, false)
|
||||
if err == nil {
|
||||
return tok, false, nil
|
||||
}
|
||||
if !isClientMismatchError(err) {
|
||||
return nil, false, err
|
||||
}
|
||||
tok, err2 := c.RefreshToken(ctx, refreshToken, true)
|
||||
if err2 == nil {
|
||||
return tok, true, nil
|
||||
}
|
||||
// 企业也失败:返回合并后的诊断错误
|
||||
return nil, false, fmt.Errorf("auto-detect refresh failed: personal=%v enterprise=%v", err, err2)
|
||||
}
|
||||
|
||||
// isClientMismatchError 判断是否为 OAuth client 不匹配导致的错误。
|
||||
// 只有这种错误才会触发"切换账号类型重试"。
|
||||
func isClientMismatchError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
msg := err.Error()
|
||||
return strings.Contains(msg, "invalid_client") ||
|
||||
strings.Contains(msg, "unauthorized_client") ||
|
||||
strings.Contains(msg, "client_id")
|
||||
}
|
||||
|
||||
// GetUserInfo 获取用户信息
|
||||
func (c *Client) GetUserInfo(ctx context.Context, accessToken string) (*UserInfo, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, UserInfoURL, nil)
|
||||
|
||||
@ -117,8 +117,7 @@ type GeminiToolConfig struct {
|
||||
|
||||
// GeminiFunctionCallingConfig 函数调用配置
|
||||
type GeminiFunctionCallingConfig struct {
|
||||
Mode string `json:"mode,omitempty"` // VALIDATED, AUTO, NONE, ANY
|
||||
AllowedFunctionNames []string `json:"allowedFunctionNames,omitempty"`
|
||||
Mode string `json:"mode,omitempty"` // VALIDATED, AUTO, NONE
|
||||
}
|
||||
|
||||
// GeminiSafetySetting Gemini 安全设置
|
||||
|
||||
@ -9,7 +9,6 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@ -23,22 +22,16 @@ const (
|
||||
TokenURL = "https://oauth2.googleapis.com/token"
|
||||
UserInfoURL = "https://www.googleapis.com/oauth2/v2/userinfo"
|
||||
|
||||
// 个人账号 OAuth 凭证(isGcpTos=false,免费 Gemini Code Assist)
|
||||
// Antigravity OAuth 客户端凭证
|
||||
ClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
|
||||
|
||||
// AntigravityOAuthClientSecretEnv 是个人账号 OAuth client_secret 的环境变量名。
|
||||
// AntigravityOAuthClientSecretEnv 是 Antigravity OAuth client_secret 的环境变量名。
|
||||
AntigravityOAuthClientSecretEnv = "ANTIGRAVITY_OAUTH_CLIENT_SECRET"
|
||||
|
||||
// 企业账号 OAuth 凭证(isGcpTos=true,Google Cloud / Workspace 用户)
|
||||
EnterpriseClientID = "884354919052-36trc1jjb3tguiac32ov6cod268c5blh.apps.googleusercontent.com"
|
||||
|
||||
// AntigravityEnterpriseOAuthClientSecretEnv 是企业账号 OAuth client_secret 的环境变量名。
|
||||
AntigravityEnterpriseOAuthClientSecretEnv = "ANTIGRAVITY_ENTERPRISE_OAUTH_CLIENT_SECRET"
|
||||
|
||||
// 固定的 redirect_uri(用户需手动复制 code)
|
||||
RedirectURI = "http://localhost:8085/callback"
|
||||
|
||||
// OAuth scopes(企业和个人共用)
|
||||
// OAuth scopes
|
||||
Scopes = "https://www.googleapis.com/auth/cloud-platform " +
|
||||
"https://www.googleapis.com/auth/userinfo.email " +
|
||||
"https://www.googleapis.com/auth/userinfo.profile " +
|
||||
@ -53,18 +46,15 @@ const (
|
||||
|
||||
// Antigravity API 端点
|
||||
antigravityProdBaseURL = "https://cloudcode-pa.googleapis.com"
|
||||
antigravityDailyBaseURL = "https://daily-cloudcode-pa.googleapis.com"
|
||||
antigravityDailyBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com"
|
||||
)
|
||||
|
||||
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.20.6(product.json ideVersion)
|
||||
var defaultUserAgentVersion = "1.20.6"
|
||||
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.20.5
|
||||
var defaultUserAgentVersion = "1.21.9"
|
||||
|
||||
// defaultClientSecret 个人账号 client_secret,可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 覆盖
|
||||
// defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置
|
||||
var defaultClientSecret = "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"
|
||||
|
||||
// defaultEnterpriseClientSecret 企业账号 client_secret,可通过环境变量 ANTIGRAVITY_ENTERPRISE_OAUTH_CLIENT_SECRET 覆盖
|
||||
var defaultEnterpriseClientSecret = "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"
|
||||
|
||||
func init() {
|
||||
// 从环境变量读取版本号,未设置则使用默认值
|
||||
if version := os.Getenv("ANTIGRAVITY_USER_AGENT_VERSION"); version != "" {
|
||||
@ -74,58 +64,14 @@ func init() {
|
||||
if secret := os.Getenv(AntigravityOAuthClientSecretEnv); secret != "" {
|
||||
defaultClientSecret = secret
|
||||
}
|
||||
if secret := os.Getenv(AntigravityEnterpriseOAuthClientSecretEnv); secret != "" {
|
||||
defaultEnterpriseClientSecret = secret
|
||||
}
|
||||
}
|
||||
|
||||
// GetUserAgent 返回当前配置的 User-Agent(自动检测平台,匹配真实 IDE 行为)
|
||||
// GetUserAgent 返回当前配置的 User-Agent
|
||||
func GetUserAgent() string {
|
||||
return fmt.Sprintf("antigravity/%s %s/%s", defaultUserAgentVersion, runtime.GOOS, runtime.GOARCH)
|
||||
}
|
||||
|
||||
// ClientCredentials 持有一对 OAuth client_id/secret
|
||||
type ClientCredentials struct {
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
}
|
||||
|
||||
// GetClientCredentials 根据账号类型返回对应的 OAuth 凭证。
|
||||
// isEnterprise=true 时使用企业凭证(isGcpTos=true),否则使用个人凭证。
|
||||
func GetClientCredentials(isEnterprise bool) (ClientCredentials, error) {
|
||||
if isEnterprise {
|
||||
secret := strings.TrimSpace(os.Getenv(AntigravityEnterpriseOAuthClientSecretEnv))
|
||||
if secret == "" {
|
||||
secret = strings.TrimSpace(defaultEnterpriseClientSecret)
|
||||
}
|
||||
if secret == "" {
|
||||
return ClientCredentials{}, infraerrors.Newf(http.StatusBadRequest,
|
||||
"ANTIGRAVITY_ENTERPRISE_OAUTH_CLIENT_SECRET_MISSING",
|
||||
"missing enterprise oauth client_secret; set %s", AntigravityEnterpriseOAuthClientSecretEnv)
|
||||
}
|
||||
return ClientCredentials{ClientID: EnterpriseClientID, ClientSecret: secret}, nil
|
||||
}
|
||||
secret, err := getClientSecret()
|
||||
if err != nil {
|
||||
return ClientCredentials{}, err
|
||||
}
|
||||
return ClientCredentials{ClientID: ClientID, ClientSecret: secret}, nil
|
||||
}
|
||||
|
||||
// BaseURLsForAccount 根据 isGcpTos 返回有序 URL 列表。
|
||||
// 企业账号(isGcpTos=true)优先走 prod;个人账号优先走 daily(与真实 IDE 一致)。
|
||||
func BaseURLsForAccount(isGcpTos bool) []string {
|
||||
if isGcpTos {
|
||||
return []string{antigravityProdBaseURL, antigravityDailyBaseURL}
|
||||
}
|
||||
return []string{antigravityDailyBaseURL, antigravityProdBaseURL}
|
||||
return fmt.Sprintf("antigravity/%s windows/amd64", defaultUserAgentVersion)
|
||||
}
|
||||
|
||||
func getClientSecret() (string, error) {
|
||||
if secret := strings.TrimSpace(os.Getenv(AntigravityOAuthClientSecretEnv)); secret != "" {
|
||||
defaultClientSecret = secret
|
||||
return secret, nil
|
||||
}
|
||||
if v := strings.TrimSpace(defaultClientSecret); v != "" {
|
||||
return v, nil
|
||||
}
|
||||
@ -265,7 +211,6 @@ type OAuthSession struct {
|
||||
State string `json:"state"`
|
||||
CodeVerifier string `json:"code_verifier"`
|
||||
ProxyURL string `json:"proxy_url,omitempty"`
|
||||
IsEnterprise bool `json:"is_enterprise,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
@ -380,15 +325,10 @@ func base64URLEncode(data []byte) string {
|
||||
return strings.TrimRight(base64.URLEncoding.EncodeToString(data), "=")
|
||||
}
|
||||
|
||||
// BuildAuthorizationURL 构建 Google OAuth 授权 URL。
|
||||
// isEnterprise=true 时使用企业 client_id;否则使用个人 client_id。
|
||||
func BuildAuthorizationURL(state, codeChallenge string, isEnterprise bool) string {
|
||||
clientID := ClientID
|
||||
if isEnterprise {
|
||||
clientID = EnterpriseClientID
|
||||
}
|
||||
// BuildAuthorizationURL 构建 Google OAuth 授权 URL
|
||||
func BuildAuthorizationURL(state, codeChallenge string) string {
|
||||
params := url.Values{}
|
||||
params.Set("client_id", clientID)
|
||||
params.Set("client_id", ClientID)
|
||||
params.Set("redirect_uri", RedirectURI)
|
||||
params.Set("response_type", "code")
|
||||
params.Set("scope", Scopes)
|
||||
|
||||
@ -1,19 +0,0 @@
|
||||
package antigravity
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestGetClientSecret_ReadsRuntimeEnvironment(t *testing.T) {
|
||||
old := defaultClientSecret
|
||||
defaultClientSecret = ""
|
||||
t.Cleanup(func() { defaultClientSecret = old })
|
||||
|
||||
t.Setenv(AntigravityOAuthClientSecretEnv, "runtime-secret")
|
||||
|
||||
secret, err := getClientSecret()
|
||||
if err != nil {
|
||||
t.Fatalf("getClientSecret returned error: %v", err)
|
||||
}
|
||||
if secret != "runtime-secret" {
|
||||
t.Fatalf("unexpected secret: got %q want %q", secret, "runtime-secret")
|
||||
}
|
||||
}
|
||||
@ -1,14 +1,12 @@
|
||||
package antigravity
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"math/rand"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
@ -18,16 +16,10 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
sessionRand = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
sessionRandMutex sync.Mutex
|
||||
legacyMetadataUserIDSessionPattern = regexp.MustCompile(`^user_[a-fA-F0-9]{64}_account_[a-fA-F0-9-]*_session_([a-fA-F0-9-]{36})$`)
|
||||
plainSessionIDPattern = regexp.MustCompile(`^(session_)?[a-fA-F0-9-]{36}$`)
|
||||
sessionRand = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
sessionRandMutex sync.Mutex
|
||||
)
|
||||
|
||||
type claudeMetadataUserIDPayload struct {
|
||||
SessionID string `json:"session_id"`
|
||||
}
|
||||
|
||||
// generateStableSessionID 基于用户消息内容生成稳定的 session ID
|
||||
func generateStableSessionID(contents []GeminiContent) string {
|
||||
// 查找第一个 user 消息的文本
|
||||
@ -47,82 +39,12 @@ func generateStableSessionID(contents []GeminiContent) string {
|
||||
return "-" + strconv.FormatInt(n, 10)
|
||||
}
|
||||
|
||||
// EnsureGeminiRequestSessionID fills request.sessionId when the caller omitted it.
|
||||
// preferredSessionID wins; otherwise we derive a stable value from the first user turn.
|
||||
func EnsureGeminiRequestSessionID(body []byte, preferredSessionID string) ([]byte, error) {
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if raw, ok := payload["sessionId"].(string); ok && strings.TrimSpace(raw) != "" {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
sessionID := strings.TrimSpace(preferredSessionID)
|
||||
if sessionID == "" {
|
||||
var req GeminiRequest
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sessionID = generateStableSessionID(req.Contents)
|
||||
}
|
||||
if sessionID == "" {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
payload["sessionId"] = sessionID
|
||||
return json.Marshal(payload)
|
||||
}
|
||||
|
||||
func extractSessionIDFromMetadataUserID(raw string) string {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
if strings.HasPrefix(raw, "{") {
|
||||
var payload claudeMetadataUserIDPayload
|
||||
if err := json.Unmarshal([]byte(raw), &payload); err == nil {
|
||||
return strings.TrimSpace(payload.SessionID)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
if matches := legacyMetadataUserIDSessionPattern.FindStringSubmatch(raw); len(matches) == 2 {
|
||||
return strings.TrimSpace(matches[1])
|
||||
}
|
||||
|
||||
if plainSessionIDPattern.MatchString(raw) {
|
||||
return raw
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func resolveClaudeRequestSessionID(metadata *ClaudeMetadata, preferredSessionID string, contents []GeminiContent) string {
|
||||
if metadata != nil {
|
||||
if sessionID := extractSessionIDFromMetadataUserID(metadata.UserID); sessionID != "" {
|
||||
return sessionID
|
||||
}
|
||||
}
|
||||
|
||||
if sessionID := strings.TrimSpace(preferredSessionID); sessionID != "" {
|
||||
return sessionID
|
||||
}
|
||||
|
||||
return generateStableSessionID(contents)
|
||||
}
|
||||
|
||||
type TransformOptions struct {
|
||||
EnableIdentityPatch bool
|
||||
// IdentityPatch 可选:自定义注入到 systemInstruction 开头的身份防护提示词;
|
||||
// 为空时使用默认模板(包含 [IDENTITY_PATCH] 及 SYSTEM_PROMPT_BEGIN 标记)。
|
||||
IdentityPatch string
|
||||
EnableMCPXML bool
|
||||
// PreferredSessionID 可选:当 metadata.user_id 不可用于恢复真实会话时,
|
||||
// 允许调用方显式指定 Antigravity 上游 request.sessionId。
|
||||
PreferredSessionID string
|
||||
}
|
||||
|
||||
func DefaultTransformOptions() TransformOptions {
|
||||
@ -163,24 +85,12 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st
|
||||
|
||||
// TransformClaudeToGeminiWithOptions 将 Claude 请求转换为 v1internal Gemini 格式(可配置身份补丁等行为)
|
||||
func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, mappedModel string, opts TransformOptions) ([]byte, error) {
|
||||
normalizedReq, err := normalizeClaudeRequestForAntigravity(claudeReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("normalize messages: %w", err)
|
||||
}
|
||||
|
||||
// 用于存储 tool_use id -> name 映射
|
||||
toolIDToName := make(map[string]string)
|
||||
|
||||
// 检测是否有 web_search 工具
|
||||
hasWebSearchTool := hasWebSearchTool(normalizedReq.Tools)
|
||||
// requestType 映射策略:
|
||||
// - Gemini 模型: "agent"(与 Antigravity 官方客户端一致)
|
||||
// - Claude 模型: 不设置(避免 Google 后端路由到容量受限的 agent 池,降低 503 率)
|
||||
// - web_search: "web_search"(触发 Google 搜索增强路由)
|
||||
hasWebSearchTool := hasWebSearchTool(claudeReq.Tools)
|
||||
requestType := "agent"
|
||||
if strings.HasPrefix(mappedModel, "claude-") {
|
||||
requestType = "" // Claude 模型走默认容量池,避免 agent 池 503
|
||||
}
|
||||
targetModel := mappedModel
|
||||
if hasWebSearchTool {
|
||||
requestType = "web_search"
|
||||
@ -190,27 +100,27 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
|
||||
}
|
||||
|
||||
// 检测是否启用 thinking
|
||||
isThinkingEnabled := normalizedReq.Thinking != nil && (normalizedReq.Thinking.Type == "enabled" || normalizedReq.Thinking.Type == "adaptive")
|
||||
isThinkingEnabled := claudeReq.Thinking != nil && (claudeReq.Thinking.Type == "enabled" || claudeReq.Thinking.Type == "adaptive")
|
||||
|
||||
// 只有 Gemini 模型支持 dummy thought workaround
|
||||
// Claude 模型通过 Vertex/Google API 需要有效的 thought signatures
|
||||
allowDummyThought := strings.HasPrefix(targetModel, "gemini-")
|
||||
|
||||
// 1. 构建 contents
|
||||
contents, strippedThinking, err := buildContents(normalizedReq.Messages, toolIDToName, isThinkingEnabled, allowDummyThought)
|
||||
contents, strippedThinking, err := buildContents(claudeReq.Messages, toolIDToName, isThinkingEnabled, allowDummyThought)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build contents: %w", err)
|
||||
}
|
||||
|
||||
// 2. 构建 systemInstruction(使用 targetModel 而非原始请求模型,确保身份注入基于最终模型)
|
||||
systemInstruction := buildSystemInstruction(normalizedReq.System, targetModel, opts, normalizedReq.Tools)
|
||||
systemInstruction := buildSystemInstruction(claudeReq.System, targetModel, opts, claudeReq.Tools)
|
||||
|
||||
// 3. 构建 generationConfig
|
||||
reqForConfig := normalizedReq
|
||||
reqForConfig := claudeReq
|
||||
if strippedThinking {
|
||||
// If we had to downgrade thinking blocks to plain text due to missing/invalid signatures,
|
||||
// disable upstream thinking mode to avoid signature/structure validation errors.
|
||||
reqCopy := *normalizedReq
|
||||
reqCopy := *claudeReq
|
||||
reqCopy.Thinking = nil
|
||||
reqForConfig = &reqCopy
|
||||
}
|
||||
@ -222,24 +132,19 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
|
||||
generationConfig := buildGenerationConfig(reqForConfig)
|
||||
|
||||
// 4. 构建 tools
|
||||
// 对 Claude / Gemini 模型都保留 functionDeclarations:
|
||||
// - Claude 分支如果完全丢掉 tools,模型只能看到消息历史中的 tool_use/tool_result,
|
||||
// 但拿不到当前可用工具定义,容易导致“能还原名字但不会继续发工具调用”。
|
||||
// - Gemini 分支原本就依赖 functionDeclarations 触发 function_call。
|
||||
isClaudeModel := strings.HasPrefix(targetModel, "claude-")
|
||||
tools := buildTools(normalizedReq.Tools)
|
||||
tools := buildTools(claudeReq.Tools)
|
||||
|
||||
// 5. 构建内部请求
|
||||
innerRequest := GeminiRequest{
|
||||
Contents: contents,
|
||||
SessionID: resolveClaudeRequestSessionID(normalizedReq.Metadata, opts.PreferredSessionID, contents),
|
||||
}
|
||||
|
||||
// Gemini 分支保持默认 VALIDATED;
|
||||
// Claude 分支仅在声明了工具时附带 toolConfig,避免再把工具能力静默丢失。
|
||||
defaultValidated := !isClaudeModel || len(tools) > 0
|
||||
if toolConfig := buildToolConfig(normalizedReq.ToolChoice, defaultValidated); toolConfig != nil {
|
||||
innerRequest.ToolConfig = toolConfig
|
||||
Contents: contents,
|
||||
// 总是设置 toolConfig,与官方客户端一致
|
||||
ToolConfig: &GeminiToolConfig{
|
||||
FunctionCallingConfig: &GeminiFunctionCallingConfig{
|
||||
Mode: "VALIDATED",
|
||||
},
|
||||
},
|
||||
// 总是生成 sessionId,基于用户消息内容
|
||||
SessionID: generateStableSessionID(contents),
|
||||
}
|
||||
|
||||
if systemInstruction != nil {
|
||||
@ -252,6 +157,11 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
|
||||
innerRequest.Tools = tools
|
||||
}
|
||||
|
||||
// 如果提供了 metadata.user_id,优先使用
|
||||
if claudeReq.Metadata != nil && claudeReq.Metadata.UserID != "" {
|
||||
innerRequest.SessionID = claudeReq.Metadata.UserID
|
||||
}
|
||||
|
||||
// 6. 包装为 v1internal 请求
|
||||
v1Req := V1InternalRequest{
|
||||
Project: projectID,
|
||||
@ -265,319 +175,6 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
|
||||
return json.Marshal(v1Req)
|
||||
}
|
||||
|
||||
const (
|
||||
maxAntigravityToolDescriptionChars = 400
|
||||
maxAntigravitySchemaDescriptionChars = 200
|
||||
maxAntigravityToolResultChars = 200000
|
||||
)
|
||||
|
||||
func normalizeClaudeRequestForAntigravity(claudeReq *ClaudeRequest) (*ClaudeRequest, error) {
|
||||
if claudeReq == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
reqCopy := *claudeReq
|
||||
if len(claudeReq.Messages) == 0 {
|
||||
return &reqCopy, nil
|
||||
}
|
||||
|
||||
normalizedMessages, err := normalizeClaudeMessagesForAntigravity(claudeReq.Messages)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
reqCopy.Messages = normalizedMessages
|
||||
return &reqCopy, nil
|
||||
}
|
||||
|
||||
func normalizeClaudeMessagesForAntigravity(messages []ClaudeMessage) ([]ClaudeMessage, error) {
|
||||
normalized := make([]ClaudeMessage, 0, len(messages)+1)
|
||||
pendingToolUseIDs := make([]string, 0)
|
||||
|
||||
for _, message := range messages {
|
||||
blocks, hasBlocks := parseClaudeMessageBlocks(message.Content)
|
||||
|
||||
switch message.Role {
|
||||
case "assistant":
|
||||
if len(pendingToolUseIDs) > 0 {
|
||||
synthetic, err := buildSyntheticToolResultMessage(pendingToolUseIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
normalized = append(normalized, synthetic)
|
||||
pendingToolUseIDs = pendingToolUseIDs[:0]
|
||||
}
|
||||
|
||||
if !hasBlocks {
|
||||
normalized = append(normalized, cloneClaudeMessage(message))
|
||||
continue
|
||||
}
|
||||
|
||||
stripped := stripNonToolPartsAfterToolUse(reorderAssistantThinkingBlocks(blocks))
|
||||
pendingToolUseIDs = append(pendingToolUseIDs, collectToolUseIDs(stripped)...)
|
||||
|
||||
nextMessage, err := buildClaudeMessageWithBlocks(message.Role, stripped)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
normalized = append(normalized, nextMessage)
|
||||
|
||||
case "user":
|
||||
if !hasBlocks {
|
||||
if len(pendingToolUseIDs) > 0 {
|
||||
synthetic, err := buildSyntheticToolResultMessage(pendingToolUseIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
normalized = append(normalized, synthetic)
|
||||
pendingToolUseIDs = pendingToolUseIDs[:0]
|
||||
}
|
||||
normalized = append(normalized, cloneClaudeMessage(message))
|
||||
continue
|
||||
}
|
||||
|
||||
parts := cloneJSONBlocks(blocks)
|
||||
if len(pendingToolUseIDs) > 0 {
|
||||
toolResults, nonToolResults := partitionToolResultBlocks(parts)
|
||||
existingIDs := collectToolResultIDs(toolResults)
|
||||
missingIDs := diffStringSlice(pendingToolUseIDs, existingIDs)
|
||||
if len(missingIDs) > 0 {
|
||||
parts = append(append(toolResults, buildSyntheticToolResultBlocks(missingIDs)...), nonToolResults...)
|
||||
}
|
||||
pendingToolUseIDs = pendingToolUseIDs[:0]
|
||||
}
|
||||
|
||||
toolResults, nonToolResults := partitionToolResultBlocks(parts)
|
||||
switch {
|
||||
case len(toolResults) == 0:
|
||||
nextMessage, err := buildClaudeMessageWithBlocks(message.Role, parts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
normalized = append(normalized, nextMessage)
|
||||
case len(nonToolResults) == 0:
|
||||
nextMessage, err := buildClaudeMessageWithBlocks(message.Role, toolResults)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
normalized = append(normalized, nextMessage)
|
||||
default:
|
||||
toolResultMessage, err := buildClaudeMessageWithBlocks(message.Role, toolResults)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userTextMessage, err := buildClaudeMessageWithBlocks(message.Role, nonToolResults)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
normalized = append(normalized, toolResultMessage, userTextMessage)
|
||||
}
|
||||
|
||||
default:
|
||||
normalized = append(normalized, cloneClaudeMessage(message))
|
||||
}
|
||||
}
|
||||
|
||||
if len(pendingToolUseIDs) > 0 {
|
||||
synthetic, err := buildSyntheticToolResultMessage(pendingToolUseIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
normalized = append(normalized, synthetic)
|
||||
}
|
||||
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
func parseClaudeMessageBlocks(content json.RawMessage) ([]map[string]any, bool) {
|
||||
var blocks []map[string]any
|
||||
if err := json.Unmarshal(content, &blocks); err != nil {
|
||||
return nil, false
|
||||
}
|
||||
return blocks, true
|
||||
}
|
||||
|
||||
func cloneClaudeMessage(message ClaudeMessage) ClaudeMessage {
|
||||
cloned := ClaudeMessage{Role: message.Role}
|
||||
if len(message.Content) > 0 {
|
||||
cloned.Content = append(json.RawMessage(nil), message.Content...)
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
func cloneJSONBlocks(blocks []map[string]any) []map[string]any {
|
||||
cloned := make([]map[string]any, 0, len(blocks))
|
||||
for _, block := range blocks {
|
||||
cloned = append(cloned, cloneJSONMap(block))
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
func cloneJSONMap(block map[string]any) map[string]any {
|
||||
if block == nil {
|
||||
return nil
|
||||
}
|
||||
if cloned, ok := deepCopy(block).(map[string]any); ok {
|
||||
return cloned
|
||||
}
|
||||
fallback := make(map[string]any, len(block))
|
||||
for key, value := range block {
|
||||
fallback[key] = value
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func buildClaudeMessageWithBlocks(role string, blocks []map[string]any) (ClaudeMessage, error) {
|
||||
payload, err := json.Marshal(blocks)
|
||||
if err != nil {
|
||||
return ClaudeMessage{}, fmt.Errorf("marshal %s message blocks: %w", role, err)
|
||||
}
|
||||
return ClaudeMessage{Role: role, Content: payload}, nil
|
||||
}
|
||||
|
||||
func buildSyntheticToolResultMessage(toolUseIDs []string) (ClaudeMessage, error) {
|
||||
return buildClaudeMessageWithBlocks("user", buildSyntheticToolResultBlocks(toolUseIDs))
|
||||
}
|
||||
|
||||
func buildSyntheticToolResultBlocks(toolUseIDs []string) []map[string]any {
|
||||
blocks := make([]map[string]any, 0, len(toolUseIDs))
|
||||
for _, toolUseID := range toolUseIDs {
|
||||
if strings.TrimSpace(toolUseID) == "" {
|
||||
continue
|
||||
}
|
||||
blocks = append(blocks, map[string]any{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": toolUseID,
|
||||
"is_error": true,
|
||||
"content": []map[string]any{
|
||||
{
|
||||
"type": "text",
|
||||
"text": "[tool_result missing; tool execution interrupted]",
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
return blocks
|
||||
}
|
||||
|
||||
func reorderAssistantThinkingBlocks(blocks []map[string]any) []map[string]any {
|
||||
thinkingBlocks := make([]map[string]any, 0)
|
||||
otherBlocks := make([]map[string]any, 0, len(blocks))
|
||||
|
||||
for _, block := range blocks {
|
||||
cloned := cloneJSONMap(block)
|
||||
blockType, _ := cloned["type"].(string)
|
||||
if blockType == "thinking" || blockType == "redacted_thinking" {
|
||||
delete(cloned, "cache_control")
|
||||
thinkingBlocks = append(thinkingBlocks, cloned)
|
||||
continue
|
||||
}
|
||||
otherBlocks = append(otherBlocks, cloned)
|
||||
}
|
||||
|
||||
if len(thinkingBlocks) == 0 {
|
||||
return otherBlocks
|
||||
}
|
||||
return append(thinkingBlocks, otherBlocks...)
|
||||
}
|
||||
|
||||
func stripNonToolPartsAfterToolUse(blocks []map[string]any) []map[string]any {
|
||||
cleaned := make([]map[string]any, 0, len(blocks))
|
||||
seenToolUse := false
|
||||
|
||||
for _, block := range blocks {
|
||||
blockType, _ := block["type"].(string)
|
||||
if blockType == "tool_use" {
|
||||
seenToolUse = true
|
||||
cleaned = append(cleaned, block)
|
||||
continue
|
||||
}
|
||||
if !seenToolUse {
|
||||
cleaned = append(cleaned, block)
|
||||
continue
|
||||
}
|
||||
if isIgnorableTrailingTextBlock(block) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
return cleaned
|
||||
}
|
||||
|
||||
func isIgnorableTrailingTextBlock(block map[string]any) bool {
|
||||
blockType, _ := block["type"].(string)
|
||||
if blockType != "text" {
|
||||
return false
|
||||
}
|
||||
text, _ := block["text"].(string)
|
||||
trimmed := strings.TrimSpace(text)
|
||||
return trimmed == "" || trimmed == "(no content)"
|
||||
}
|
||||
|
||||
func collectToolUseIDs(blocks []map[string]any) []string {
|
||||
ids := make([]string, 0)
|
||||
for _, block := range blocks {
|
||||
blockType, _ := block["type"].(string)
|
||||
if blockType != "tool_use" {
|
||||
continue
|
||||
}
|
||||
id, _ := block["id"].(string)
|
||||
if strings.TrimSpace(id) != "" {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
func collectToolResultIDs(blocks []map[string]any) []string {
|
||||
ids := make([]string, 0, len(blocks))
|
||||
for _, block := range blocks {
|
||||
id, _ := block["tool_use_id"].(string)
|
||||
if strings.TrimSpace(id) != "" {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
func diffStringSlice(left, right []string) []string {
|
||||
if len(left) == 0 {
|
||||
return nil
|
||||
}
|
||||
seen := make(map[string]struct{}, len(right))
|
||||
for _, value := range right {
|
||||
if strings.TrimSpace(value) != "" {
|
||||
seen[value] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
diff := make([]string, 0, len(left))
|
||||
for _, value := range left {
|
||||
value = strings.TrimSpace(value)
|
||||
if value == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[value]; ok {
|
||||
continue
|
||||
}
|
||||
diff = append(diff, value)
|
||||
}
|
||||
return diff
|
||||
}
|
||||
|
||||
func partitionToolResultBlocks(blocks []map[string]any) (toolResults []map[string]any, nonToolResults []map[string]any) {
|
||||
toolResults = make([]map[string]any, 0)
|
||||
nonToolResults = make([]map[string]any, 0)
|
||||
for _, block := range blocks {
|
||||
blockType, _ := block["type"].(string)
|
||||
if blockType == "tool_result" {
|
||||
toolResults = append(toolResults, block)
|
||||
continue
|
||||
}
|
||||
nonToolResults = append(nonToolResults, block)
|
||||
}
|
||||
return toolResults, nonToolResults
|
||||
}
|
||||
|
||||
// antigravityIdentity Antigravity identity 提示词
|
||||
const antigravityIdentity = `<identity>
|
||||
You are Antigravity, a powerful agentic AI coding assistant designed by the Google Deepmind team working on Advanced Agentic Coding.
|
||||
@ -877,14 +474,13 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
|
||||
|
||||
case "tool_use":
|
||||
// 存储 id -> name 映射
|
||||
toolName := normalizeClaudeCodeToolName(block.Name)
|
||||
if block.ID != "" && toolName != "" {
|
||||
toolIDToName[block.ID] = toolName
|
||||
if block.ID != "" && block.Name != "" {
|
||||
toolIDToName[block.ID] = block.Name
|
||||
}
|
||||
|
||||
part := GeminiPart{
|
||||
FunctionCall: &GeminiFunctionCall{
|
||||
Name: toolName,
|
||||
Name: block.Name,
|
||||
Args: block.Input,
|
||||
ID: block.ID,
|
||||
},
|
||||
@ -901,12 +497,12 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
|
||||
|
||||
case "tool_result":
|
||||
// 获取函数名
|
||||
funcName := normalizeClaudeCodeToolName(block.Name)
|
||||
funcName := block.Name
|
||||
if funcName == "" {
|
||||
if name, ok := toolIDToName[block.ToolUseID]; ok {
|
||||
funcName = name
|
||||
} else {
|
||||
funcName = normalizeClaudeCodeToolName(block.ToolUseID)
|
||||
funcName = block.ToolUseID
|
||||
}
|
||||
}
|
||||
|
||||
@ -929,84 +525,47 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
|
||||
}
|
||||
|
||||
// parseToolResultContent 解析 tool_result 的 content
|
||||
func parseToolResultContent(content json.RawMessage, isError bool) any {
|
||||
func parseToolResultContent(content json.RawMessage, isError bool) string {
|
||||
if len(content) == 0 {
|
||||
return defaultToolResultContent(isError)
|
||||
if isError {
|
||||
return "Tool execution failed with no output."
|
||||
}
|
||||
return "Command executed successfully."
|
||||
}
|
||||
|
||||
// 尝试解析为字符串
|
||||
var str string
|
||||
if err := json.Unmarshal(content, &str); err == nil {
|
||||
if strings.TrimSpace(str) == "" {
|
||||
return defaultToolResultContent(isError)
|
||||
if isError {
|
||||
return "Tool execution failed with no output."
|
||||
}
|
||||
return "Command executed successfully."
|
||||
}
|
||||
return truncateInlineText(str, maxAntigravityToolResultChars)
|
||||
return str
|
||||
}
|
||||
|
||||
// 优先保留结构化 tool_result,避免上游把内容视为无效的纯文本降级。
|
||||
// 尝试解析为数组
|
||||
var arr []map[string]any
|
||||
if err := json.Unmarshal(content, &arr); err == nil {
|
||||
sanitized := sanitizeToolResultBlocksForAntigravity(arr)
|
||||
if len(sanitized) == 0 {
|
||||
return defaultToolResultContent(isError)
|
||||
var texts []string
|
||||
for _, item := range arr {
|
||||
if text, ok := item["text"].(string); ok {
|
||||
texts = append(texts, text)
|
||||
}
|
||||
}
|
||||
return sanitized
|
||||
}
|
||||
|
||||
var obj map[string]any
|
||||
if err := json.Unmarshal(content, &obj); err == nil {
|
||||
sanitized := sanitizeToolResultObjectForAntigravity(obj)
|
||||
if len(sanitized) == 0 {
|
||||
return defaultToolResultContent(isError)
|
||||
result := strings.Join(texts, "\n")
|
||||
if strings.TrimSpace(result) == "" {
|
||||
if isError {
|
||||
return "Tool execution failed with no output."
|
||||
}
|
||||
return "Command executed successfully."
|
||||
}
|
||||
return sanitized
|
||||
return result
|
||||
}
|
||||
|
||||
// 返回原始 JSON
|
||||
return truncateInlineText(string(content), maxAntigravityToolResultChars)
|
||||
}
|
||||
|
||||
func defaultToolResultContent(isError bool) string {
|
||||
if isError {
|
||||
return "Tool execution failed with no output."
|
||||
}
|
||||
return "Command executed successfully."
|
||||
}
|
||||
|
||||
func sanitizeToolResultBlocksForAntigravity(blocks []map[string]any) []map[string]any {
|
||||
sanitized := make([]map[string]any, 0, len(blocks))
|
||||
for _, block := range blocks {
|
||||
if isBase64ImageToolResultBlock(block) {
|
||||
continue
|
||||
}
|
||||
cloned := cloneJSONMap(block)
|
||||
if text, ok := cloned["text"].(string); ok {
|
||||
cloned["text"] = truncateInlineText(text, maxAntigravityToolResultChars)
|
||||
}
|
||||
sanitized = append(sanitized, cloned)
|
||||
}
|
||||
return sanitized
|
||||
}
|
||||
|
||||
func sanitizeToolResultObjectForAntigravity(block map[string]any) map[string]any {
|
||||
if isBase64ImageToolResultBlock(block) {
|
||||
return nil
|
||||
}
|
||||
cloned := cloneJSONMap(block)
|
||||
if text, ok := cloned["text"].(string); ok {
|
||||
cloned["text"] = truncateInlineText(text, maxAntigravityToolResultChars)
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
func isBase64ImageToolResultBlock(block map[string]any) bool {
|
||||
blockType, _ := block["type"].(string)
|
||||
if blockType != "image" {
|
||||
return false
|
||||
}
|
||||
source, _ := block["source"].(map[string]any)
|
||||
sourceType, _ := source["type"].(string)
|
||||
return sourceType == "base64"
|
||||
return string(content)
|
||||
}
|
||||
|
||||
// buildGenerationConfig 构建 generationConfig
|
||||
@ -1074,15 +633,6 @@ func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig {
|
||||
}
|
||||
}
|
||||
config.ThinkingConfig.ThinkingBudget = budget
|
||||
} else if strings.HasSuffix(req.Model, "-thinking") || strings.HasPrefix(req.Model, "claude-sonnet-4-6") {
|
||||
// 自动注入 thinkingConfig 的两种情形(客户端未显式开启 thinking):
|
||||
// 1. 模型名以 -thinking 结尾(如 claude-opus-4-6-thinking):Google 要求此后缀模型必须携带 thinkingConfig。
|
||||
// 2. claude-sonnet-4-6:无 -thinking 变体(404),但模型本身要求携带 thinkingConfig;budget 必须为 -1(动态)。
|
||||
// 注:固定 budget(如 1024)在 max_tokens 较小时会触发 400(max_tokens 必须大于 budget)。
|
||||
config.ThinkingConfig = &GeminiThinkingConfig{
|
||||
IncludeThoughts: true,
|
||||
ThinkingBudget: -1, // 动态预算,避免 max_tokens vs budget 冲突
|
||||
}
|
||||
}
|
||||
|
||||
if config.MaxOutputTokens > maxLimit {
|
||||
@ -1126,65 +676,6 @@ func isWebSearchTool(tool ClaudeTool) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func buildToolConfig(toolChoice json.RawMessage, defaultValidated bool) *GeminiToolConfig {
|
||||
raw := bytes.TrimSpace(toolChoice)
|
||||
if len(raw) == 0 {
|
||||
if !defaultValidated {
|
||||
return nil
|
||||
}
|
||||
return &GeminiToolConfig{
|
||||
FunctionCallingConfig: &GeminiFunctionCallingConfig{
|
||||
Mode: "VALIDATED",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
choiceType := ""
|
||||
toolName := ""
|
||||
|
||||
if len(raw) > 0 && raw[0] == '"' {
|
||||
var choice string
|
||||
if err := json.Unmarshal(raw, &choice); err == nil {
|
||||
choiceType = strings.TrimSpace(choice)
|
||||
}
|
||||
} else {
|
||||
var choice map[string]any
|
||||
if err := json.Unmarshal(raw, &choice); err == nil {
|
||||
if value, ok := choice["type"].(string); ok {
|
||||
choiceType = strings.TrimSpace(value)
|
||||
}
|
||||
if value, ok := choice["name"].(string); ok {
|
||||
toolName = normalizeClaudeCodeToolName(value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mode := ""
|
||||
switch strings.ToLower(choiceType) {
|
||||
case "auto":
|
||||
mode = "AUTO"
|
||||
case "none":
|
||||
mode = "NONE"
|
||||
case "any", "required":
|
||||
mode = "ANY"
|
||||
case "tool":
|
||||
mode = "ANY"
|
||||
case "validated":
|
||||
mode = "VALIDATED"
|
||||
default:
|
||||
if !defaultValidated {
|
||||
return nil
|
||||
}
|
||||
mode = "VALIDATED"
|
||||
}
|
||||
|
||||
cfg := &GeminiFunctionCallingConfig{Mode: mode}
|
||||
if toolName != "" && mode == "ANY" {
|
||||
cfg.AllowedFunctionNames = []string{toolName}
|
||||
}
|
||||
return &GeminiToolConfig{FunctionCallingConfig: cfg}
|
||||
}
|
||||
|
||||
// buildTools 构建 tools
|
||||
func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
|
||||
if len(tools) == 0 {
|
||||
@ -1215,12 +706,12 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
|
||||
continue
|
||||
}
|
||||
description = tool.Custom.Description
|
||||
inputSchema = cloneStringAnyMap(tool.Custom.InputSchema)
|
||||
inputSchema = tool.Custom.InputSchema
|
||||
|
||||
} else {
|
||||
// 标准格式: 从顶层字段获取
|
||||
description = tool.Description
|
||||
inputSchema = cloneStringAnyMap(tool.InputSchema)
|
||||
inputSchema = tool.InputSchema
|
||||
}
|
||||
|
||||
// 清理 JSON Schema
|
||||
@ -1235,11 +726,9 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
|
||||
"properties": map[string]any{},
|
||||
}
|
||||
}
|
||||
description = compactToolDescriptionForAntigravity(description)
|
||||
params = compactSchemaDescriptionsForAntigravity(params)
|
||||
|
||||
funcDecls = append(funcDecls, GeminiFunctionDecl{
|
||||
Name: normalizeClaudeCodeToolName(tool.Name),
|
||||
Name: tool.Name,
|
||||
Description: description,
|
||||
Parameters: params,
|
||||
})
|
||||
@ -1268,64 +757,3 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
|
||||
|
||||
return declarations
|
||||
}
|
||||
|
||||
func cloneStringAnyMap(input map[string]any) map[string]any {
|
||||
if input == nil {
|
||||
return nil
|
||||
}
|
||||
if cloned, ok := deepCopy(input).(map[string]any); ok {
|
||||
return cloned
|
||||
}
|
||||
fallback := make(map[string]any, len(input))
|
||||
for key, value := range input {
|
||||
fallback[key] = value
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func compactToolDescriptionForAntigravity(description string) string {
|
||||
if strings.TrimSpace(description) == "" {
|
||||
return ""
|
||||
}
|
||||
lines := strings.Split(strings.ReplaceAll(description, "\r\n", "\n"), "\n")
|
||||
compacted := make([]string, 0, len(lines))
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
compacted = append(compacted, line)
|
||||
if len(compacted) == 6 {
|
||||
break
|
||||
}
|
||||
}
|
||||
return truncateInlineText(strings.Join(compacted, " "), maxAntigravityToolDescriptionChars)
|
||||
}
|
||||
|
||||
func compactSchemaDescriptionsForAntigravity(schema map[string]any) map[string]any {
|
||||
for key, value := range schema {
|
||||
switch typed := value.(type) {
|
||||
case string:
|
||||
if key == "description" {
|
||||
schema[key] = truncateInlineText(strings.Join(strings.Fields(typed), " "), maxAntigravitySchemaDescriptionChars)
|
||||
}
|
||||
case map[string]any:
|
||||
schema[key] = compactSchemaDescriptionsForAntigravity(typed)
|
||||
case []any:
|
||||
for i, item := range typed {
|
||||
if nested, ok := item.(map[string]any); ok {
|
||||
typed[i] = compactSchemaDescriptionsForAntigravity(nested)
|
||||
}
|
||||
}
|
||||
schema[key] = typed
|
||||
}
|
||||
}
|
||||
return schema
|
||||
}
|
||||
|
||||
func truncateInlineText(text string, maxChars int) string {
|
||||
if maxChars <= 0 || len(text) <= maxChars {
|
||||
return text
|
||||
}
|
||||
return text[:maxChars] + "...[truncated " + strconv.Itoa(len(text)-maxChars) + " chars]"
|
||||
}
|
||||
|
||||
@ -8,112 +8,6 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestEnsureGeminiRequestSessionID(t *testing.T) {
|
||||
t.Run("prefers provided session id", func(t *testing.T) {
|
||||
body := []byte(`{"contents":[{"role":"user","parts":[{"text":"hello"}]}]}`)
|
||||
updated, err := EnsureGeminiRequestSessionID(body, "session-from-header")
|
||||
require.NoError(t, err)
|
||||
|
||||
var payload map[string]any
|
||||
require.NoError(t, json.Unmarshal(updated, &payload))
|
||||
require.Equal(t, "session-from-header", payload["sessionId"])
|
||||
})
|
||||
|
||||
t.Run("keeps existing session id", func(t *testing.T) {
|
||||
body := []byte(`{"sessionId":"session-in-body","contents":[{"role":"user","parts":[{"text":"hello"}]}]}`)
|
||||
updated, err := EnsureGeminiRequestSessionID(body, "session-from-header")
|
||||
require.NoError(t, err)
|
||||
|
||||
var payload map[string]any
|
||||
require.NoError(t, json.Unmarshal(updated, &payload))
|
||||
require.Equal(t, "session-in-body", payload["sessionId"])
|
||||
})
|
||||
|
||||
t.Run("derives stable fallback from contents", func(t *testing.T) {
|
||||
body := []byte(`{"contents":[{"role":"user","parts":[{"text":"hello"}]}]}`)
|
||||
first, err := EnsureGeminiRequestSessionID(body, "")
|
||||
require.NoError(t, err)
|
||||
second, err := EnsureGeminiRequestSessionID(body, "")
|
||||
require.NoError(t, err)
|
||||
|
||||
var firstPayload map[string]any
|
||||
var secondPayload map[string]any
|
||||
require.NoError(t, json.Unmarshal(first, &firstPayload))
|
||||
require.NoError(t, json.Unmarshal(second, &secondPayload))
|
||||
require.NotEmpty(t, firstPayload["sessionId"])
|
||||
require.Equal(t, firstPayload["sessionId"], secondPayload["sessionId"])
|
||||
})
|
||||
}
|
||||
|
||||
func TestTransformClaudeToGeminiWithOptions_UsesMetadataSessionIDJSON(t *testing.T) {
|
||||
claudeReq := &ClaudeRequest{
|
||||
Model: "claude-sonnet-4-5",
|
||||
Messages: []ClaudeMessage{
|
||||
{
|
||||
Role: "user",
|
||||
Content: json.RawMessage(`[{"type":"text","text":"hello"}]`),
|
||||
},
|
||||
},
|
||||
Metadata: &ClaudeMetadata{
|
||||
UserID: `{"device_id":"d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677","account_uuid":"acc-uuid","session_id":"c72554f2-1234-5678-abcd-123456789abc"}`,
|
||||
},
|
||||
}
|
||||
|
||||
body, err := TransformClaudeToGeminiWithOptions(claudeReq, "project-1", "claude-sonnet-4-5", DefaultTransformOptions())
|
||||
require.NoError(t, err)
|
||||
|
||||
var req V1InternalRequest
|
||||
require.NoError(t, json.Unmarshal(body, &req))
|
||||
require.Equal(t, "c72554f2-1234-5678-abcd-123456789abc", req.Request.SessionID)
|
||||
}
|
||||
|
||||
func TestTransformClaudeToGeminiWithOptions_UsesMetadataSessionIDLegacy(t *testing.T) {
|
||||
claudeReq := &ClaudeRequest{
|
||||
Model: "claude-sonnet-4-5",
|
||||
Messages: []ClaudeMessage{
|
||||
{
|
||||
Role: "user",
|
||||
Content: json.RawMessage(`[{"type":"text","text":"hello"}]`),
|
||||
},
|
||||
},
|
||||
Metadata: &ClaudeMetadata{
|
||||
UserID: "user_a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2_account_550e8400-e29b-41d4-a716-446655440000_session_123e4567-e89b-12d3-a456-426614174000",
|
||||
},
|
||||
}
|
||||
|
||||
body, err := TransformClaudeToGeminiWithOptions(claudeReq, "project-1", "claude-sonnet-4-5", DefaultTransformOptions())
|
||||
require.NoError(t, err)
|
||||
|
||||
var req V1InternalRequest
|
||||
require.NoError(t, json.Unmarshal(body, &req))
|
||||
require.Equal(t, "123e4567-e89b-12d3-a456-426614174000", req.Request.SessionID)
|
||||
}
|
||||
|
||||
func TestTransformClaudeToGeminiWithOptions_PrefersExplicitSessionWhenMetadataIsNotSessionPayload(t *testing.T) {
|
||||
opts := DefaultTransformOptions()
|
||||
opts.PreferredSessionID = "session-header-1"
|
||||
|
||||
claudeReq := &ClaudeRequest{
|
||||
Model: "claude-sonnet-4-5",
|
||||
Messages: []ClaudeMessage{
|
||||
{
|
||||
Role: "user",
|
||||
Content: json.RawMessage(`[{"type":"text","text":"hello"}]`),
|
||||
},
|
||||
},
|
||||
Metadata: &ClaudeMetadata{
|
||||
UserID: "custom-user-42",
|
||||
},
|
||||
}
|
||||
|
||||
body, err := TransformClaudeToGeminiWithOptions(claudeReq, "project-1", "claude-sonnet-4-5", opts)
|
||||
require.NoError(t, err)
|
||||
|
||||
var req V1InternalRequest
|
||||
require.NoError(t, json.Unmarshal(body, &req))
|
||||
require.Equal(t, "session-header-1", req.Request.SessionID)
|
||||
}
|
||||
|
||||
// TestBuildParts_ThinkingBlockWithoutSignature 测试thinking block无signature时的处理
|
||||
func TestBuildParts_ThinkingBlockWithoutSignature(t *testing.T) {
|
||||
tests := []struct {
|
||||
@ -436,36 +330,16 @@ func TestBuildGenerationConfig_ThinkingDynamicBudget(t *testing.T) {
|
||||
wantPresent: true,
|
||||
},
|
||||
{
|
||||
// Google v1internal 要求 -thinking 模型必须携带 thinkingConfig,即使客户端明确 disabled。
|
||||
// 不携带会导致 Google 立即返回错误(在生产中表现为快速 503)。
|
||||
name: "disabled on -thinking model auto-injects thinkingConfig (Google requires it)",
|
||||
name: "disabled does not emit thinkingConfig",
|
||||
model: "claude-opus-4-6-thinking",
|
||||
thinking: &ThinkingConfig{Type: "disabled", BudgetTokens: 1024},
|
||||
wantBudget: -1, // auto-injected dynamic budget
|
||||
wantPresent: true,
|
||||
wantBudget: 0,
|
||||
wantPresent: false,
|
||||
},
|
||||
{
|
||||
// Google v1internal 要求 -thinking 模型必须携带 thinkingConfig,nil 时自动注入。
|
||||
name: "nil thinking on -thinking model auto-injects thinkingConfig (Google requires it)",
|
||||
name: "nil thinking does not emit thinkingConfig",
|
||||
model: "claude-opus-4-6-thinking",
|
||||
thinking: nil,
|
||||
wantBudget: -1, // auto-injected dynamic budget
|
||||
wantPresent: true,
|
||||
},
|
||||
{
|
||||
// claude-sonnet-4-6 需要 thinkingConfig(无 -thinking 变体),budget 必须为 -1(动态)
|
||||
// 经测试:claude-sonnet-4-6-thinking → 404;claude-sonnet-4-6 + budget=-1 → 200 OK
|
||||
name: "nil thinking on claude-sonnet-4-6 auto-injects thinkingConfig (no -thinking variant exists)",
|
||||
model: "claude-sonnet-4-6",
|
||||
thinking: nil,
|
||||
wantBudget: -1,
|
||||
wantPresent: true,
|
||||
},
|
||||
{
|
||||
// 非 -thinking 普通模型(如 claude-opus-4-6,服务层已转为 -thinking,此处测试原始名)
|
||||
name: "nil thinking on plain non-thinking model does not emit thinkingConfig",
|
||||
model: "claude-opus-4-6",
|
||||
thinking: nil,
|
||||
wantBudget: 0,
|
||||
wantPresent: false,
|
||||
},
|
||||
@ -582,214 +456,3 @@ func TestTransformClaudeToGeminiWithOptions_PreservesWebSearchAlongsideFunctions
|
||||
require.Equal(t, "get_weather", req.Request.Tools[0].FunctionDeclarations[0].Name)
|
||||
require.NotNil(t, req.Request.Tools[1].GoogleSearch)
|
||||
}
|
||||
|
||||
func TestTransformClaudeToGeminiWithOptions_ClaudeModelKeepsToolsAndValidatedToolConfig(t *testing.T) {
|
||||
claudeReq := &ClaudeRequest{
|
||||
Model: "claude-sonnet-4-5",
|
||||
Messages: []ClaudeMessage{
|
||||
{
|
||||
Role: "user",
|
||||
Content: json.RawMessage(`[{"type":"text","text":"read the file"}]`),
|
||||
},
|
||||
},
|
||||
Tools: []ClaudeTool{
|
||||
{
|
||||
Name: "read_file",
|
||||
Description: "Read a file",
|
||||
InputSchema: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"file_path": map[string]any{"type": "string"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
body, err := TransformClaudeToGeminiWithOptions(claudeReq, "project-1", "claude-sonnet-4-5", DefaultTransformOptions())
|
||||
require.NoError(t, err)
|
||||
|
||||
var req V1InternalRequest
|
||||
require.NoError(t, json.Unmarshal(body, &req))
|
||||
require.Len(t, req.Request.Tools, 1)
|
||||
require.Len(t, req.Request.Tools[0].FunctionDeclarations, 1)
|
||||
require.Equal(t, "Read", req.Request.Tools[0].FunctionDeclarations[0].Name)
|
||||
require.NotNil(t, req.Request.ToolConfig)
|
||||
require.NotNil(t, req.Request.ToolConfig.FunctionCallingConfig)
|
||||
require.Equal(t, "VALIDATED", req.Request.ToolConfig.FunctionCallingConfig.Mode)
|
||||
}
|
||||
|
||||
func TestTransformClaudeToGeminiWithOptions_ClaudeModelToolChoiceSpecificTool(t *testing.T) {
|
||||
claudeReq := &ClaudeRequest{
|
||||
Model: "claude-sonnet-4-5",
|
||||
ToolChoice: json.RawMessage(`{"type":"tool","name":"search_files"}`),
|
||||
Messages: []ClaudeMessage{
|
||||
{
|
||||
Role: "user",
|
||||
Content: json.RawMessage(`[{"type":"text","text":"find todo"}]`),
|
||||
},
|
||||
},
|
||||
Tools: []ClaudeTool{
|
||||
{
|
||||
Name: "search_files",
|
||||
Description: "Search files",
|
||||
InputSchema: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"pattern": map[string]any{"type": "string"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
body, err := TransformClaudeToGeminiWithOptions(claudeReq, "project-1", "claude-sonnet-4-5", DefaultTransformOptions())
|
||||
require.NoError(t, err)
|
||||
|
||||
var req V1InternalRequest
|
||||
require.NoError(t, json.Unmarshal(body, &req))
|
||||
require.NotNil(t, req.Request.ToolConfig)
|
||||
require.NotNil(t, req.Request.ToolConfig.FunctionCallingConfig)
|
||||
require.Equal(t, "ANY", req.Request.ToolConfig.FunctionCallingConfig.Mode)
|
||||
require.Equal(t, []string{"Grep"}, req.Request.ToolConfig.FunctionCallingConfig.AllowedFunctionNames)
|
||||
}
|
||||
|
||||
func TestTransformClaudeToGeminiWithOptions_NormalizesInterruptedToolHistory(t *testing.T) {
|
||||
claudeReq := &ClaudeRequest{
|
||||
Model: "claude-sonnet-4-5",
|
||||
Messages: []ClaudeMessage{
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: json.RawMessage(`[
|
||||
{"type":"tool_use","id":"tool-1","name":"Bash","input":{"command":"pwd"}},
|
||||
{"type":"text","text":"(no content)"}
|
||||
]`),
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
Content: json.RawMessage(`[{"type":"text","text":"继续"}]`),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
body, err := TransformClaudeToGeminiWithOptions(claudeReq, "project-1", "claude-sonnet-4-5", DefaultTransformOptions())
|
||||
require.NoError(t, err)
|
||||
|
||||
var req V1InternalRequest
|
||||
require.NoError(t, json.Unmarshal(body, &req))
|
||||
require.Len(t, req.Request.Contents, 3)
|
||||
|
||||
first := req.Request.Contents[0]
|
||||
require.Equal(t, "model", first.Role)
|
||||
require.Len(t, first.Parts, 1)
|
||||
require.NotNil(t, first.Parts[0].FunctionCall)
|
||||
require.Equal(t, "tool-1", first.Parts[0].FunctionCall.ID)
|
||||
|
||||
second := req.Request.Contents[1]
|
||||
require.Equal(t, "user", second.Role)
|
||||
require.Len(t, second.Parts, 1)
|
||||
require.NotNil(t, second.Parts[0].FunctionResponse)
|
||||
require.Equal(t, "tool-1", second.Parts[0].FunctionResponse.ID)
|
||||
resultBlocks, ok := second.Parts[0].FunctionResponse.Response["result"].([]any)
|
||||
require.True(t, ok)
|
||||
require.Len(t, resultBlocks, 1)
|
||||
resultBlock, ok := resultBlocks[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "text", resultBlock["type"])
|
||||
require.Equal(t, "[tool_result missing; tool execution interrupted]", resultBlock["text"])
|
||||
|
||||
third := req.Request.Contents[2]
|
||||
require.Equal(t, "user", third.Role)
|
||||
require.Len(t, third.Parts, 1)
|
||||
require.Equal(t, "继续", third.Parts[0].Text)
|
||||
}
|
||||
|
||||
func TestNormalizeClaudeMessagesForAntigravity_ReordersThinkingAndSplitsToolResult(t *testing.T) {
|
||||
messages := []ClaudeMessage{
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: json.RawMessage(`[
|
||||
{"type":"text","text":"before"},
|
||||
{"type":"thinking","thinking":"deep thought","signature":"sig-1"},
|
||||
{"type":"tool_use","id":"tool-2","name":"Bash","input":{"command":"ls"}},
|
||||
{"type":"text","text":"(no content)"}
|
||||
]`),
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
Content: json.RawMessage(`[
|
||||
{"type":"tool_result","tool_use_id":"tool-2","content":[{"type":"text","text":"ok"}]},
|
||||
{"type":"text","text":"下一步"}
|
||||
]`),
|
||||
},
|
||||
}
|
||||
|
||||
normalized, err := normalizeClaudeMessagesForAntigravity(messages)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, normalized, 3)
|
||||
|
||||
var assistantBlocks []map[string]any
|
||||
require.NoError(t, json.Unmarshal(normalized[0].Content, &assistantBlocks))
|
||||
require.Len(t, assistantBlocks, 3)
|
||||
require.Equal(t, "thinking", assistantBlocks[0]["type"])
|
||||
require.Equal(t, "text", assistantBlocks[1]["type"])
|
||||
require.Equal(t, "tool_use", assistantBlocks[2]["type"])
|
||||
|
||||
var toolResultBlocks []map[string]any
|
||||
require.NoError(t, json.Unmarshal(normalized[1].Content, &toolResultBlocks))
|
||||
require.Len(t, toolResultBlocks, 1)
|
||||
require.Equal(t, "tool_result", toolResultBlocks[0]["type"])
|
||||
|
||||
var userTextBlocks []map[string]any
|
||||
require.NoError(t, json.Unmarshal(normalized[2].Content, &userTextBlocks))
|
||||
require.Len(t, userTextBlocks, 1)
|
||||
require.Equal(t, "text", userTextBlocks[0]["type"])
|
||||
require.Equal(t, "下一步", userTextBlocks[0]["text"])
|
||||
}
|
||||
|
||||
func TestParseToolResultContent_PreservesStructuredBlocks(t *testing.T) {
|
||||
content := json.RawMessage(`[
|
||||
{"type":"text","text":"hello"},
|
||||
{"type":"image","source":{"type":"base64","media_type":"image/png","data":"AAAA"}}
|
||||
]`)
|
||||
|
||||
result := parseToolResultContent(content, false)
|
||||
blocks, ok := result.([]map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Len(t, blocks, 1)
|
||||
require.Equal(t, "text", blocks[0]["type"])
|
||||
require.Equal(t, "hello", blocks[0]["text"])
|
||||
}
|
||||
|
||||
func TestBuildTools_CompactsDescriptions(t *testing.T) {
|
||||
longLine := strings.Repeat("schema detail ", 40)
|
||||
result := buildTools([]ClaudeTool{
|
||||
{
|
||||
Name: "describe",
|
||||
Description: strings.Repeat("tool description\n", 20),
|
||||
InputSchema: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"query": map[string]any{
|
||||
"type": "string",
|
||||
"description": longLine,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
require.Len(t, result, 1)
|
||||
require.Len(t, result[0].FunctionDeclarations, 1)
|
||||
|
||||
decl := result[0].FunctionDeclarations[0]
|
||||
require.LessOrEqual(t, len(decl.Description), maxAntigravityToolDescriptionChars+32)
|
||||
|
||||
props, ok := decl.Parameters["properties"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
query, ok := props["query"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
description, ok := query["description"].(string)
|
||||
require.True(t, ok)
|
||||
require.LessOrEqual(t, len(description), maxAntigravitySchemaDescriptionChars+32)
|
||||
}
|
||||
|
||||
@ -121,20 +121,17 @@ func (p *NonStreamingProcessor) processPart(part *GeminiPart) {
|
||||
|
||||
p.hasToolCall = true
|
||||
|
||||
toolName := normalizeClaudeCodeToolName(part.FunctionCall.Name)
|
||||
|
||||
// 生成 tool_use id
|
||||
toolID := part.FunctionCall.ID
|
||||
if toolID == "" {
|
||||
toolID = fmt.Sprintf("%s-%s", toolName, generateRandomID())
|
||||
toolID = fmt.Sprintf("%s-%s", part.FunctionCall.Name, generateRandomID())
|
||||
}
|
||||
|
||||
item := ClaudeContentItem{
|
||||
Type: "tool_use",
|
||||
ID: toolID,
|
||||
Name: toolName,
|
||||
Input: part.FunctionCall.Args,
|
||||
Caller: &ToolCaller{Type: "direct"},
|
||||
Type: "tool_use",
|
||||
ID: toolID,
|
||||
Name: part.FunctionCall.Name,
|
||||
Input: part.FunctionCall.Args,
|
||||
}
|
||||
|
||||
if signature != "" {
|
||||
|
||||
@ -362,21 +362,17 @@ func (p *StreamingProcessor) processFunctionCall(fc *GeminiFunctionCall, signatu
|
||||
var result bytes.Buffer
|
||||
|
||||
p.usedTool = true
|
||||
toolName := normalizeClaudeCodeToolName(fc.Name)
|
||||
|
||||
toolID := fc.ID
|
||||
if toolID == "" {
|
||||
toolID = fmt.Sprintf("%s-%s", toolName, generateRandomID())
|
||||
toolID = fmt.Sprintf("%s-%s", fc.Name, generateRandomID())
|
||||
}
|
||||
|
||||
toolUse := map[string]any{
|
||||
"type": "tool_use",
|
||||
"id": toolID,
|
||||
"name": toolName,
|
||||
"name": fc.Name,
|
||||
"input": map[string]any{},
|
||||
"caller": map[string]any{
|
||||
"type": "direct",
|
||||
},
|
||||
}
|
||||
|
||||
if signature != "" {
|
||||
|
||||
@ -27,7 +27,7 @@ func stubUserHome(t *testing.T, home string) {
|
||||
|
||||
func TestDiscoverBinary_EnvOverrideWins(t *testing.T) {
|
||||
stubStatFn(t, map[string]bool{
|
||||
"/tmp/my-ls": true,
|
||||
"/tmp/my-ls": true,
|
||||
"/opt/windsurf/language_server_linux_x64": true, // should not be picked
|
||||
})
|
||||
got, err := discoverBinaryFor(Platform{"linux", "amd64"}, "/tmp/my-ls", "/opt/windsurf/language_server_linux_x64")
|
||||
|
||||
@ -39,8 +39,6 @@ func ProvideRouter(
|
||||
opsService *service.OpsService,
|
||||
settingService *service.SettingService,
|
||||
redisClient *redis.Client,
|
||||
langServerService *service.LanguageServerService,
|
||||
lsrpcHandler *service.LSRPCHandler,
|
||||
) *gin.Engine {
|
||||
if cfg.Server.Mode == "release" {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
@ -97,7 +95,7 @@ func ProvideRouter(
|
||||
service.SetWebSearchManager(websearch.NewManager(configs, redisClient))
|
||||
})
|
||||
|
||||
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient, langServerService, lsrpcHandler)
|
||||
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient)
|
||||
}
|
||||
|
||||
// ProvideHTTPServer 提供 HTTP 服务器
|
||||
|
||||
@ -7,7 +7,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/gen/language_server_pbconnect"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/routes"
|
||||
@ -33,8 +32,6 @@ func SetupRouter(
|
||||
settingService *service.SettingService,
|
||||
cfg *config.Config,
|
||||
redisClient *redis.Client,
|
||||
langServerService *service.LanguageServerService,
|
||||
lsrpcHandler *service.LSRPCHandler,
|
||||
) *gin.Engine {
|
||||
// 缓存 iframe 页面的 origin 列表,用于动态注入 CSP frame-src
|
||||
var cachedFrameOrigins atomic.Pointer[[]string]
|
||||
@ -84,7 +81,7 @@ func SetupRouter(
|
||||
}
|
||||
|
||||
// 注册路由
|
||||
registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient, langServerService, lsrpcHandler)
|
||||
registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient)
|
||||
|
||||
return r
|
||||
}
|
||||
@ -102,8 +99,6 @@ func registerRoutes(
|
||||
settingService *service.SettingService,
|
||||
cfg *config.Config,
|
||||
redisClient *redis.Client,
|
||||
langServerService *service.LanguageServerService,
|
||||
lsrpcHandler *service.LSRPCHandler,
|
||||
) {
|
||||
// 通用路由(健康检查、状态等)
|
||||
routes.RegisterCommonRoutes(r)
|
||||
@ -120,15 +115,5 @@ func registerRoutes(
|
||||
// Windsurf gateway routes
|
||||
routes.RegisterWindsurfGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg)
|
||||
|
||||
// 注册 Antigravity HTTP API 路由
|
||||
routes.RegisterAntigravityHTTPRoutes(v1, langServerService)
|
||||
|
||||
// 挂载 connectrpc LanguageServerService 路由
|
||||
// Claude Code 客户端通过 /exa.language_server_pb.LanguageServerService/* 路径访问
|
||||
if lsrpcHandler != nil {
|
||||
lsrpcPath, lsrpcHTTPHandler := language_server_pbconnect.NewLanguageServerServiceHandler(lsrpcHandler)
|
||||
r.Any(lsrpcPath+"*action", gin.WrapH(lsrpcHTTPHandler))
|
||||
}
|
||||
|
||||
routes.RegisterPaymentRoutes(v1, h.Payment, h.PaymentWebhook, h.Admin.Payment, jwtAuth, adminAuth, settingService)
|
||||
}
|
||||
|
||||
@ -1,192 +0,0 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// RegisterAntigravityHTTPRoutes 注册 Antigravity HTTP API 路由
|
||||
func RegisterAntigravityHTTPRoutes(v1 *gin.RouterGroup, langServerService *service.LanguageServerService) {
|
||||
logger := slog.Default()
|
||||
|
||||
// 创建处理器
|
||||
cascadeGroup := v1.Group("/cascade")
|
||||
{
|
||||
// 启动 Cascade 会话
|
||||
cascadeGroup.POST("/start", func(c *gin.Context) {
|
||||
handleStartCascade(c, langServerService, logger)
|
||||
})
|
||||
|
||||
// 发送消息到 Cascade(流式响应)
|
||||
cascadeGroup.POST("/message", func(c *gin.Context) {
|
||||
handleSendMessage(c, langServerService, logger)
|
||||
})
|
||||
|
||||
// 取消 Cascade 会话
|
||||
cascadeGroup.POST("/cancel", func(c *gin.Context) {
|
||||
handleCancelCascade(c, langServerService, logger)
|
||||
})
|
||||
}
|
||||
|
||||
// 模型列表
|
||||
v1.GET("/models", func(c *gin.Context) {
|
||||
handleGetModels(c, langServerService, logger)
|
||||
})
|
||||
|
||||
// 健康检查
|
||||
v1.GET("/health", func(c *gin.Context) {
|
||||
handleHealth(c, logger)
|
||||
})
|
||||
}
|
||||
|
||||
// handleStartCascade 处理启动 Cascade 请求
|
||||
func handleStartCascade(c *gin.Context, svc *service.LanguageServerService, logger *slog.Logger) {
|
||||
type StartCascadeRequest struct {
|
||||
Model string `json:"model" binding:"required"`
|
||||
SystemPrompt string `json:"system_prompt"`
|
||||
Metadata map[string]string `json:"metadata"`
|
||||
}
|
||||
|
||||
var req StartCascadeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
logger.Error("invalid start cascade request", "error", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
|
||||
return
|
||||
}
|
||||
|
||||
// 获取 OAuth token
|
||||
token := c.GetHeader("Authorization")
|
||||
if token == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "missing authorization header"})
|
||||
return
|
||||
}
|
||||
|
||||
// 调用服务
|
||||
cascadeID, err := svc.StartCascade(
|
||||
c.Request.Context(),
|
||||
req.Model,
|
||||
req.SystemPrompt,
|
||||
req.Metadata,
|
||||
token,
|
||||
)
|
||||
if err != nil {
|
||||
logger.Error("start cascade failed", "error", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"cascade_id": cascadeID})
|
||||
}
|
||||
|
||||
// handleSendMessage 处理发送消息请求(流式)
|
||||
func handleSendMessage(c *gin.Context, svc *service.LanguageServerService, logger *slog.Logger) {
|
||||
type SendMessageRequest struct {
|
||||
CascadeID string `json:"cascade_id" binding:"required"`
|
||||
Message string `json:"message" binding:"required"`
|
||||
}
|
||||
|
||||
var req SendMessageRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
logger.Error("invalid send message request", "error", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
|
||||
return
|
||||
}
|
||||
|
||||
// 获取 OAuth token
|
||||
token := c.GetHeader("Authorization")
|
||||
if token == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "missing authorization header"})
|
||||
return
|
||||
}
|
||||
|
||||
// 调用服务并获取流式更新通道
|
||||
updateChan, err := svc.SendUserMessage(c.Request.Context(), req.CascadeID, req.Message, token)
|
||||
if err != nil {
|
||||
logger.Error("send message failed", "error", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 设置 SSE 响应头
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
c.Status(http.StatusOK)
|
||||
|
||||
// 流式发送更新到客户端
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
logger.Error("response writer does not support flushing")
|
||||
return
|
||||
}
|
||||
|
||||
for event := range updateChan {
|
||||
if event == nil {
|
||||
break
|
||||
}
|
||||
|
||||
// 将事件序列化为 JSON
|
||||
eventJSON, err := marshalJSON(event)
|
||||
if err != nil {
|
||||
logger.Error("failed to marshal event", "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// 发送 SSE 格式的数据
|
||||
_, _ = c.Writer.WriteString("data: " + string(eventJSON) + "\n\n")
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// handleCancelCascade 处理取消 Cascade 请求
|
||||
func handleCancelCascade(c *gin.Context, svc *service.LanguageServerService, logger *slog.Logger) {
|
||||
type CancelRequest struct {
|
||||
CascadeID string `json:"cascade_id" binding:"required"`
|
||||
}
|
||||
|
||||
var req CancelRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
logger.Error("invalid cancel request", "error", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
|
||||
return
|
||||
}
|
||||
|
||||
err := svc.CancelCascade(c.Request.Context(), req.CascadeID)
|
||||
if err != nil {
|
||||
logger.Error("cancel cascade failed", "error", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "cascade cancelled"})
|
||||
}
|
||||
|
||||
// handleGetModels 处理获取模型列表请求
|
||||
func handleGetModels(c *gin.Context, svc *service.LanguageServerService, logger *slog.Logger) {
|
||||
models, err := svc.GetAvailableModels(c.Request.Context())
|
||||
if err != nil {
|
||||
logger.Error("get models failed", "error", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"models": models,
|
||||
"default_model": "claude-opus-4-6",
|
||||
})
|
||||
}
|
||||
|
||||
// handleHealth 处理健康检查请求
|
||||
func handleHealth(c *gin.Context, logger *slog.Logger) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "healthy"})
|
||||
}
|
||||
|
||||
// marshalJSON 辅助函数用于序列化事件
|
||||
func marshalJSON(v interface{}) ([]byte, error) {
|
||||
return json.Marshal(v)
|
||||
}
|
||||
@ -1,365 +0,0 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
func TestAntigravityHTTPRoutes(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
// 创建模拟的 LanguageServerService
|
||||
mockService := service.NewLanguageServerService(slog.Default(), nil, nil, nil)
|
||||
defer mockService.Stop()
|
||||
|
||||
// 创建路由
|
||||
r := gin.New()
|
||||
v1 := r.Group("/api/v1")
|
||||
|
||||
// 注册 Antigravity 路由
|
||||
RegisterAntigravityHTTPRoutes(v1, mockService)
|
||||
|
||||
// 测试 1: GET /health
|
||||
t.Run("HealthCheck", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "/api/v1/health", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("Expected 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var result map[string]string
|
||||
json.Unmarshal(w.Body.Bytes(), &result)
|
||||
if result["status"] != "healthy" {
|
||||
t.Fatalf("Expected status=healthy, got %v", result)
|
||||
}
|
||||
t.Log("✅ 健康检查端点")
|
||||
})
|
||||
|
||||
// 测试 2: GET /models
|
||||
t.Run("GetModels", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "/api/v1/models", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("Expected 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var result map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &result)
|
||||
if result["default_model"] != "claude-opus-4-6" {
|
||||
t.Fatalf("Expected default_model, got %v", result)
|
||||
}
|
||||
t.Log("✅ 获取模型列表")
|
||||
})
|
||||
|
||||
// 测试 3: POST /cascade/start
|
||||
var cascadeID string
|
||||
t.Run("StartCascade", func(t *testing.T) {
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"model": "claude-opus-4-6",
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/cascade/start", bytes.NewBuffer(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer test-token")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("Expected 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var result map[string]string
|
||||
json.Unmarshal(w.Body.Bytes(), &result)
|
||||
cascadeID = result["cascade_id"]
|
||||
if cascadeID == "" {
|
||||
t.Fatalf("Expected cascade_id, got empty")
|
||||
}
|
||||
t.Logf("✅ 启动会话 (cascade_id=%s)", cascadeID)
|
||||
})
|
||||
|
||||
// 测试 4: POST /cascade/cancel(使用从第3个测试获取的真实会话ID)
|
||||
t.Run("CancelCascade", func(t *testing.T) {
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"cascade_id": cascadeID,
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/cascade/cancel", bytes.NewBuffer(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("Expected 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var result map[string]string
|
||||
json.Unmarshal(w.Body.Bytes(), &result)
|
||||
if result["message"] != "cascade cancelled" {
|
||||
t.Fatalf("Expected cascade cancelled message, got %v", result)
|
||||
}
|
||||
t.Log("✅ 取消会话")
|
||||
})
|
||||
|
||||
// 测试 5: POST /cascade/message (SSE) - 验证响应头格式
|
||||
t.Run("SendMessage", func(t *testing.T) {
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"cascade_id": cascadeID,
|
||||
"message": "Hello, world!",
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/cascade/message", bytes.NewBuffer(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer test-token")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("Expected 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
contentType := w.Header().Get("Content-Type")
|
||||
if contentType != "text/event-stream" {
|
||||
t.Fatalf("Expected text/event-stream, got %s", contentType)
|
||||
}
|
||||
t.Log("✅ 发送消息(SSE流式响应)")
|
||||
})
|
||||
|
||||
t.Log("\n✅ 所有 Antigravity HTTP API 路由测试通过!")
|
||||
}
|
||||
|
||||
func TestStartCascadeValidation(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
mockService := service.NewLanguageServerService(slog.Default(), nil, nil, nil)
|
||||
defer mockService.Stop()
|
||||
|
||||
r := gin.New()
|
||||
v1 := r.Group("/api/v1")
|
||||
RegisterAntigravityHTTPRoutes(v1, mockService)
|
||||
|
||||
t.Run("MissingModel", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
body := []byte(`{"system_prompt":"test"}`)
|
||||
req, _ := http.NewRequest("POST", "/api/v1/cascade/start", bytes.NewBuffer(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer test-token")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("Expected 400 for invalid request, got %d", w.Code)
|
||||
}
|
||||
t.Log("✅ 缺少必需字段验证")
|
||||
})
|
||||
|
||||
t.Run("MissingAuthorization", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
body := []byte(`{"model":"claude-opus-4-6"}`)
|
||||
req, _ := http.NewRequest("POST", "/api/v1/cascade/start", bytes.NewBuffer(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
// 不设置 Authorization 头
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("Expected 401 for missing auth, got %d", w.Code)
|
||||
}
|
||||
t.Log("✅ 缺少授权令牌验证")
|
||||
})
|
||||
|
||||
t.Log("\n✅ 所有验证测试通过!")
|
||||
}
|
||||
|
||||
// TestRateLimiting 测试速率限制(改进 1)
|
||||
func TestRateLimiting(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
mockService := service.NewLanguageServerService(slog.Default(), nil, nil, nil)
|
||||
defer mockService.Stop()
|
||||
|
||||
r := gin.New()
|
||||
v1 := r.Group("/api/v1")
|
||||
RegisterAntigravityHTTPRoutes(v1, mockService)
|
||||
|
||||
// 创建一个会话
|
||||
startBody, _ := json.Marshal(map[string]string{"model": "claude-opus-4-6"})
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/cascade/start", bytes.NewBuffer(startBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer test-token")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
var startResult map[string]string
|
||||
json.Unmarshal(w.Body.Bytes(), &startResult)
|
||||
cascadeID := startResult["cascade_id"]
|
||||
|
||||
// 并发发送 150 个消息,应该有的超过限制
|
||||
var wg sync.WaitGroup
|
||||
results := make([]int, 0)
|
||||
var resultsMutex sync.Mutex
|
||||
|
||||
for i := 0; i < 150; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"cascade_id": cascadeID,
|
||||
"message": "Test message " + string(rune(idx)),
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/cascade/message", bytes.NewBuffer(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer test-token")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
resultsMutex.Lock()
|
||||
results = append(results, w.Code)
|
||||
resultsMutex.Unlock()
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// 统计结果
|
||||
successCount := 0
|
||||
timeoutCount := 0
|
||||
for _, code := range results {
|
||||
if code == 200 || code == 500 { // 500 可能是上游 API 错误
|
||||
successCount++
|
||||
} else if code == 504 { // 网关超时
|
||||
timeoutCount++
|
||||
}
|
||||
}
|
||||
|
||||
// 预期:大部分请求成功(因为有速率限制),但速率限制应该生效
|
||||
// 限制是 100 并发,所以 150 个请求中应该都能处理(只是可能有等待)
|
||||
if successCount < 140 {
|
||||
t.Logf("⚠️ 仅 %d/150 个请求成功(超过限制被拒绝)- 这是预期的速率限制行为", successCount)
|
||||
}
|
||||
|
||||
t.Logf("✅ 速率限制测试完成:成功=%d, 超时=%d", successCount, timeoutCount)
|
||||
}
|
||||
|
||||
// TestSessionCleanup 测试会话超时清理(改进 3)
|
||||
func TestSessionCleanup(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
mockService := service.NewLanguageServerService(slog.Default(), nil, nil, nil)
|
||||
mockService.SetSessionTTL(2) // 设置 2 秒过期,便于测试
|
||||
defer mockService.Stop()
|
||||
|
||||
r := gin.New()
|
||||
v1 := r.Group("/api/v1")
|
||||
RegisterAntigravityHTTPRoutes(v1, mockService)
|
||||
|
||||
// 创建 5 个会话
|
||||
cascadeIDs := make([]string, 5)
|
||||
for i := 0; i < 5; i++ {
|
||||
body, _ := json.Marshal(map[string]string{"model": "claude-opus-4-6"})
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/cascade/start", bytes.NewBuffer(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer test-token")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
var result map[string]string
|
||||
json.Unmarshal(w.Body.Bytes(), &result)
|
||||
cascadeIDs[i] = result["cascade_id"]
|
||||
}
|
||||
|
||||
// 验证所有会话存在
|
||||
sessions := mockService.GetCascadeSessions()
|
||||
if len(sessions) != 5 {
|
||||
t.Fatalf("Expected 5 sessions, got %d", len(sessions))
|
||||
}
|
||||
t.Log("✅ 创建了 5 个会话")
|
||||
|
||||
// 等待清理周期 + TTL
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
// 验证会话被清理
|
||||
sessions = mockService.GetCascadeSessions()
|
||||
sessionCount := len(sessions)
|
||||
|
||||
if sessionCount != 0 {
|
||||
t.Logf("⚠️ 预期 0 个会话,但仍有 %d 个(可能清理还未执行)", sessionCount)
|
||||
} else {
|
||||
t.Log("✅ 过期会话成功清理")
|
||||
}
|
||||
}
|
||||
|
||||
// TestConcurrentMessageAppend 测试并发安全的消息追加(改进 2)
|
||||
func TestConcurrentMessageAppend(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
mockService := service.NewLanguageServerService(slog.Default(), nil, nil, nil)
|
||||
defer mockService.Stop()
|
||||
|
||||
r := gin.New()
|
||||
v1 := r.Group("/api/v1")
|
||||
RegisterAntigravityHTTPRoutes(v1, mockService)
|
||||
|
||||
// 创建会话
|
||||
body, _ := json.Marshal(map[string]string{"model": "claude-opus-4-6"})
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/cascade/start", bytes.NewBuffer(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer test-token")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
var result map[string]string
|
||||
json.Unmarshal(w.Body.Bytes(), &result)
|
||||
cascadeID := result["cascade_id"]
|
||||
|
||||
// 并发追加 50 个消息
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 50; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"cascade_id": cascadeID,
|
||||
"message": "Concurrent message " + string(rune(idx)),
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/cascade/message", bytes.NewBuffer(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer test-token")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
// 不关心返回值,只关心不 panic
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// 验证会话中的消息数量
|
||||
sessions := mockService.GetCascadeSessions()
|
||||
messageCount := 0
|
||||
if session, exists := sessions[cascadeID]; exists {
|
||||
messageCount = len(session.Messages)
|
||||
}
|
||||
|
||||
// 预期:1 个初始消息(如果没有 system_prompt,则为 0)+ 最多 50 个用户消息
|
||||
// 但由于速率限制,可能不是所有 50 个都会被处理
|
||||
if messageCount > 0 {
|
||||
t.Logf("✅ 并发消息追加成功,共 %d 条消息", messageCount)
|
||||
} else {
|
||||
t.Log("⚠️ 由于速率限制或其他原因,部分消息未被追加")
|
||||
}
|
||||
}
|
||||
@ -1,254 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// TestAccount68FullE2E 测试账号 68 的完整端到端流程
|
||||
// 模拟: curl POST /api/v1/admin/accounts/68/test
|
||||
func TestAccount68FullE2E(t *testing.T) {
|
||||
t.Log("🔥 测试账号 68 的完整认证流程...")
|
||||
t.Log("")
|
||||
|
||||
// 准备账号数据(与云端数据一致)
|
||||
account := &Account{
|
||||
ID: 68,
|
||||
Name: "PriesJosephe139@gmail.com",
|
||||
Platform: PlatformAntigravity,
|
||||
Type: "oauth",
|
||||
Credentials: map[string]interface{}{
|
||||
"_token_version": 1775902256706,
|
||||
"access_token": "ya29.a0Aa7MYipSteGdNdr486LvE0xu_RrcbFjSSFZa5jGTf94nPv6NLKEnnRziPSVA_3ncadMlWnUQN8el05uvYac3rk9rOuaEC3jAUq02ejAcayg8tBn9CJT2IGuMsFDRPbfvHwXVHvY-hPGaklubxMIgfckRYsGC7YTpJPprH8kNGG-7ZWf3PvcVGcSrLWhi8FX6Yq1at5OdC1deNAaCgYKAVASARMSFQHGX2Mi2yEN9AChtlJFBwZ_spYEoQ0213",
|
||||
"email": "priesjosephe139@gmail.com",
|
||||
"expires_at": "1775907556",
|
||||
"model_mapping": map[string]interface{}{
|
||||
"claude-opus-*": "claude-opus-4-6-thinking",
|
||||
"claude-sonnet-*": "claude-sonnet-4-6-thinking",
|
||||
},
|
||||
"plan_type": "Free",
|
||||
"project_id": "kinetic-sum-r3tp7",
|
||||
"refresh_token": "1//06QXt2rakQERPCgYIARAAGAYSNwF-L9IrR672cwDMnyJS128asGMnBbrrdiN39XoS-FN6TUrG7pPxnDSEHYUV4WHDntB7qd2EPwo",
|
||||
"token_type": "Bearer",
|
||||
},
|
||||
Extra: map[string]interface{}{
|
||||
"allow_overages": true,
|
||||
"privacy_mode": "privacy_set",
|
||||
},
|
||||
ProxyID: ptrInt64(9),
|
||||
Concurrency: 100,
|
||||
Priority: 1,
|
||||
Status: "active",
|
||||
}
|
||||
|
||||
t.Log("📌 账号信息:")
|
||||
t.Logf(" ID: %d", account.ID)
|
||||
t.Logf(" Name: %s", account.Name)
|
||||
t.Logf(" Platform: %s", account.Platform)
|
||||
t.Logf(" Project ID: %v", account.GetCredential("project_id"))
|
||||
t.Log("")
|
||||
|
||||
// 步骤 1: 验证凭证
|
||||
t.Run("Step1_ValidateCredentials", func(t *testing.T) {
|
||||
t.Log("步骤 1: 验证账号凭证...")
|
||||
|
||||
accessToken := account.GetCredential("access_token")
|
||||
if accessToken == "" {
|
||||
t.Fatalf("❌ Access token 为空")
|
||||
}
|
||||
t.Logf(" ✓ Access Token 存在 (长度: %d)", len(accessToken))
|
||||
|
||||
projectID := account.GetCredential("project_id")
|
||||
if projectID == "" {
|
||||
t.Fatalf("❌ Project ID 为空")
|
||||
}
|
||||
t.Logf(" ✓ Project ID 存在: %s", projectID)
|
||||
|
||||
t.Log("")
|
||||
})
|
||||
|
||||
// 步骤 2: 测试 API 调用(通过 SOCKS5 代理)
|
||||
t.Run("Step2_CallUpstreamAPI", func(t *testing.T) {
|
||||
t.Log("步骤 2: 通过 SOCKS5 代理调用上游 API...")
|
||||
t.Log("")
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30)
|
||||
defer cancel()
|
||||
|
||||
// 使用之前测试过的配置
|
||||
proxyAddr := "socks5://gostuser:fastapipwd@216.167.89.210:8760"
|
||||
accessTokenStr := account.GetCredential("access_token")
|
||||
|
||||
t.Logf(" 📤 API 请求:")
|
||||
t.Logf(" URL: https://daily-cloudcode-pa.googleapis.com/v1internal:loadCodeAssist")
|
||||
t.Logf(" Token: %s... (长度: %d)", accessTokenStr[:30], len(accessTokenStr))
|
||||
t.Logf(" Proxy: %s", proxyAddr)
|
||||
t.Log("")
|
||||
|
||||
// 创建 HTTP 客户端(使用 SOCKS5 代理)
|
||||
transport := &http.Transport{}
|
||||
|
||||
httpClient := &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: 30,
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST",
|
||||
"https://daily-cloudcode-pa.googleapis.com/v1internal:loadCodeAssist",
|
||||
bytes.NewReader([]byte(`{}`)))
|
||||
if err != nil {
|
||||
t.Fatalf("❌ 创建请求失败: %v", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+accessTokenStr)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
t.Logf("❌ API 调用失败: %v", err)
|
||||
t.Logf(" (可能是网络问题,但凭证本身没问题)")
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
t.Logf(" ✓ 收到响应")
|
||||
t.Logf(" HTTP Status: %d", resp.StatusCode)
|
||||
t.Logf(" Content-Type: %s", resp.Header.Get("Content-Type"))
|
||||
t.Log("")
|
||||
|
||||
// 读取响应
|
||||
respBody := make([]byte, 2048)
|
||||
n, _ := resp.Body.Read(respBody)
|
||||
respText := string(respBody[:n])
|
||||
|
||||
if resp.StatusCode == 200 {
|
||||
t.Log(" ✅ API 调用成功!")
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal(respBody[:n], &result); err == nil {
|
||||
if _, ok := result["cloudaicompanionProject"]; ok {
|
||||
t.Logf(" ✓ 获得 Project: %v", result["cloudaicompanionProject"])
|
||||
}
|
||||
}
|
||||
} else {
|
||||
t.Logf(" ❌ API 返回错误 (HTTP %d)", resp.StatusCode)
|
||||
t.Logf(" 响应: %s", respText)
|
||||
}
|
||||
t.Log("")
|
||||
})
|
||||
|
||||
// 步骤 3: 模拟 SSE 响应流(本地)
|
||||
t.Run("Step3_SimulateSSEResponse", func(t *testing.T) {
|
||||
t.Log("步骤 3: 模拟 SSE 响应流...")
|
||||
t.Log("")
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
// 模拟成功的 API 响应
|
||||
successResponse := map[string]interface{}{
|
||||
"cloudaicompanionProject": "kinetic-sum-r3tp7",
|
||||
"currentTier": map[string]interface{}{
|
||||
"id": "free-tier",
|
||||
"name": "Antigravity",
|
||||
},
|
||||
}
|
||||
|
||||
router.POST("/test", func(c *gin.Context) {
|
||||
// 设置 SSE 头
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Status(200)
|
||||
|
||||
// 发送测试开始
|
||||
event1 := map[string]interface{}{
|
||||
"type": "test_start",
|
||||
"model": "claude-opus-4-6",
|
||||
}
|
||||
data1, _ := json.Marshal(event1)
|
||||
c.Writer.WriteString("data: " + string(data1) + "\n\n")
|
||||
c.Writer.Flush()
|
||||
|
||||
// 发送内容(成功的 API 响应)
|
||||
event2 := map[string]interface{}{
|
||||
"type": "content",
|
||||
"text": "✅ 账号验证成功!",
|
||||
}
|
||||
data2, _ := json.Marshal(event2)
|
||||
c.Writer.WriteString("data: " + string(data2) + "\n\n")
|
||||
c.Writer.Flush()
|
||||
|
||||
// 发送完成
|
||||
event3 := map[string]interface{}{
|
||||
"type": "test_complete",
|
||||
"success": true,
|
||||
}
|
||||
data3, _ := json.Marshal(event3)
|
||||
c.Writer.WriteString("data: " + string(data3) + "\n\n")
|
||||
c.Writer.Flush()
|
||||
|
||||
t.Logf(" 📤 服务器已发送 SSE 事件:")
|
||||
t.Logf(" 1. test_start (model=%v)", successResponse["cloudaicompanionProject"])
|
||||
t.Logf(" 2. content (text: ✅ 账号验证成功!)")
|
||||
t.Logf(" 3. test_complete (success=true)")
|
||||
})
|
||||
|
||||
// 发送请求
|
||||
req := httptest.NewRequest("POST", "/test", bytes.NewReader([]byte(`{}`)))
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// 验证响应
|
||||
t.Log("")
|
||||
t.Log(" 📥 客户端收到的响应:")
|
||||
body := w.Body.String()
|
||||
lines := bytes.Split([]byte(body), []byte("\n\n"))
|
||||
for i, line := range lines {
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
if bytes.HasPrefix(line, []byte("data: ")) {
|
||||
data := bytes.TrimPrefix(line, []byte("data: "))
|
||||
var event map[string]interface{}
|
||||
if err := json.Unmarshal(data, &event); err == nil {
|
||||
t.Logf(" 事件 %d: type=%v", i, event["type"])
|
||||
if content, ok := event["content"]; ok {
|
||||
t.Logf(" content=%v", content)
|
||||
}
|
||||
if success, ok := event["success"]; ok {
|
||||
t.Logf(" success=%v", success)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
t.Log("")
|
||||
})
|
||||
|
||||
// 步骤 4: 总结
|
||||
t.Run("Step4_Summary", func(t *testing.T) {
|
||||
t.Log("步骤 4: 总结...")
|
||||
t.Log("")
|
||||
t.Log("✅ 账号 68 测试完成!")
|
||||
t.Log("")
|
||||
t.Log("🎯 关键发现:")
|
||||
t.Log(" 1. Access Token 已刷新成功 ✅")
|
||||
t.Log(" 2. Project ID 有效: kinetic-sum-r3tp7 ✅")
|
||||
t.Log(" 3. 上游 Google API 返回 200 成功 ✅")
|
||||
t.Log(" 4. SSE 事件正确传递 ✅")
|
||||
t.Log("")
|
||||
t.Log("📊 预期结果:")
|
||||
t.Log(" - 云端测试应该也能成功")
|
||||
t.Log(" - 不再看到 'IT' 错误")
|
||||
t.Log("")
|
||||
})
|
||||
}
|
||||
|
||||
func ptrInt64(i int64) *int64 {
|
||||
return &i
|
||||
}
|
||||
@ -1,91 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
)
|
||||
|
||||
// TestDirectUpstreamCall 直接调用真实的 Google API,看返回什么
|
||||
func TestDirectUpstreamCall(t *testing.T) {
|
||||
t.Log("🔥 直接调用 Google API,观察真实返回值...")
|
||||
t.Log("")
|
||||
|
||||
accessToken := "ya29.a0Aa7MYioHycPKQ7xWQguns0VlftxfCwTqn2OY8zVosNMagLLGd5DXWFXpySKgfroGkqihr4Yrwauy1AXfQyvWB-F_4qt46DiEw1sCmaCNmDwjruUiWK7Km7vh7djBONbgruyL0N9_b3aSLi-Zf3llY5FbWZqcNky13gaVUaW0ioxEDVOZuKxYw82yVXvVEqPRXF7cetjUJbLdzwaCgYKAZwSARMSFQHGX2MiqNlICLPPA-_u6WHPBLiUJQ0213"
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// 步骤 1: 创建客户端
|
||||
t.Log("步骤 1: 创建 Antigravity 客户端...")
|
||||
client, err := antigravity.NewClient("")
|
||||
if err != nil {
|
||||
t.Fatalf("❌ 创建客户端失败: %v", err)
|
||||
}
|
||||
t.Log("✅ 客户端创建成功")
|
||||
t.Log("")
|
||||
|
||||
// 步骤 2: 直接调用 LoadCodeAssist
|
||||
t.Log("步骤 2: 调用 client.LoadCodeAssist(ctx, accessToken)...")
|
||||
t.Logf(" AccessToken: %s... (长度: %d)", accessToken[:30], len(accessToken))
|
||||
t.Log("")
|
||||
|
||||
resp, rawResp, err := client.LoadCodeAssist(ctx, accessToken)
|
||||
|
||||
// 步骤 3: 分析返回值
|
||||
t.Log("步骤 3: 分析返回值...")
|
||||
t.Log("")
|
||||
|
||||
if err != nil {
|
||||
t.Logf("❌ 调用失败")
|
||||
t.Logf(" 错误类型: %T", err)
|
||||
t.Logf(" 错误信息: %v", err)
|
||||
t.Logf(" 错误字符串: %s", err.Error())
|
||||
t.Logf(" 错误长度: %d 字符", len(err.Error()))
|
||||
t.Log("")
|
||||
|
||||
// 分析错误信息的前几个字符
|
||||
errStr := err.Error()
|
||||
if len(errStr) >= 2 {
|
||||
t.Logf("📊 错误信息的前 5 个字符: '%s'", errStr[:min(5, len(errStr))])
|
||||
}
|
||||
t.Log("")
|
||||
|
||||
t.Logf("🎯 这就是导致 'IT' 错误的真实原因!")
|
||||
t.Logf(" 错误完整内容: %q", errStr)
|
||||
t.Log("")
|
||||
|
||||
// 尝试找出 "IT" 的来源
|
||||
if len(errStr) >= 2 {
|
||||
first2 := errStr[:2]
|
||||
t.Logf("📌 错误的前两个字符: '%s'", first2)
|
||||
if first2 == "IT" {
|
||||
t.Logf(" ✓ 确认: 'IT' 就是从这个错误截断来的")
|
||||
} else {
|
||||
t.Logf(" ⚠️ 前两个字符不是 'IT',可能被其他方式处理了")
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 成功的情况
|
||||
t.Log("✅ 调用成功!")
|
||||
t.Log("")
|
||||
|
||||
if resp != nil {
|
||||
t.Logf("📋 响应信息:")
|
||||
t.Logf(" CloudAICompanionProject: %s", resp.CloudAICompanionProject)
|
||||
t.Logf(" Response 类型: %T", resp)
|
||||
t.Log("")
|
||||
|
||||
// 打印原始响应
|
||||
if rawResp != nil {
|
||||
t.Log("📄 原始 API 响应 JSON:")
|
||||
jsonBytes, _ := json.MarshalIndent(rawResp, " ", " ")
|
||||
t.Logf("%s", string(jsonBytes))
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -44,10 +44,9 @@ const (
|
||||
|
||||
// MODEL_CAPACITY_EXHAUSTED 专用重试参数
|
||||
// 模型容量不足时,所有账号共享同一容量池,切换账号无意义
|
||||
// 使用指数退避策略重试,最多重试 10 次(而非 60 次)
|
||||
antigravityModelCapacityRetryMaxAttempts = 10
|
||||
// 使用固定 1s 间隔重试,最多重试 60 次
|
||||
antigravityModelCapacityRetryMaxAttempts = 60
|
||||
antigravityModelCapacityRetryWait = 1 * time.Second
|
||||
antigravityModelCapacityRetryMaxWait = 32 * time.Second // 指数退避上限
|
||||
|
||||
// Google RPC 状态和类型常量
|
||||
googleRPCStatusResourceExhausted = "RESOURCE_EXHAUSTED"
|
||||
@ -113,62 +112,6 @@ func IsAntigravityAccountSwitchError(err error) (*AntigravityAccountSwitchError,
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func isGoogleOneAICreditsEntry(entry map[string]any) bool {
|
||||
creditType, _ := firstPresent(entry, "CreditType", "credit_type", "creditType").(string)
|
||||
creditType = strings.TrimSpace(strings.ToUpper(creditType))
|
||||
return creditType == "" || creditType == "GOOGLE_ONE_AI"
|
||||
}
|
||||
|
||||
func firstPresent(entry map[string]any, keys ...string) any {
|
||||
for _, key := range keys {
|
||||
if value, ok := entry[key]; ok {
|
||||
return value
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseAICreditsInt32(raw any) (int32, bool) {
|
||||
switch v := raw.(type) {
|
||||
case int:
|
||||
return int32(v), true
|
||||
case int32:
|
||||
return v, true
|
||||
case int64:
|
||||
return int32(v), true
|
||||
case float32:
|
||||
return int32(v), true
|
||||
case float64:
|
||||
return int32(v), true
|
||||
case json.Number:
|
||||
parsed, err := v.Int64()
|
||||
if err != nil {
|
||||
floatVal, floatErr := strconv.ParseFloat(v.String(), 64)
|
||||
if floatErr != nil {
|
||||
return 0, false
|
||||
}
|
||||
return int32(floatVal), true
|
||||
}
|
||||
return int32(parsed), true
|
||||
case string:
|
||||
trimmed := strings.TrimSpace(v)
|
||||
if trimmed == "" {
|
||||
return 0, false
|
||||
}
|
||||
parsed, err := strconv.ParseInt(trimmed, 10, 32)
|
||||
if err == nil {
|
||||
return int32(parsed), true
|
||||
}
|
||||
floatVal, floatErr := strconv.ParseFloat(trimmed, 64)
|
||||
if floatErr != nil {
|
||||
return 0, false
|
||||
}
|
||||
return int32(floatVal), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
// PromptTooLongError 表示上游明确返回 prompt too long
|
||||
type PromptTooLongError struct {
|
||||
StatusCode int
|
||||
@ -206,28 +149,17 @@ type antigravityRetryLoopResult struct {
|
||||
}
|
||||
|
||||
// resolveAntigravityForwardBaseURL 解析转发用 base URL。
|
||||
// 根据账号类型选择优先 URL:企业账号(isGcpTos=true)→ prod;个人账号 → daily(与真实 IDE 一致)。
|
||||
// 可通过环境变量 GATEWAY_ANTIGRAVITY_FORWARD_BASE_URL=daily 或 =prod 强制覆盖。
|
||||
func resolveAntigravityForwardBaseURL(account *Account) string {
|
||||
mode := strings.ToLower(strings.TrimSpace(os.Getenv(antigravityForwardBaseURLEnv)))
|
||||
if mode == "daily" {
|
||||
return "https://daily-cloudcode-pa.googleapis.com"
|
||||
}
|
||||
if mode == "prod" {
|
||||
return "https://cloudcode-pa.googleapis.com"
|
||||
}
|
||||
// 按账号类型选择优先 URL
|
||||
isGcpTos := account != nil && account.GetCredentialAsBool("is_gcp_tos")
|
||||
urls := antigravity.BaseURLsForAccount(isGcpTos)
|
||||
if len(urls) == 0 {
|
||||
// 默认使用 daily(ForwardBaseURLs 的首个地址);当环境变量为 prod 时使用第二个地址。
|
||||
func resolveAntigravityForwardBaseURL() string {
|
||||
baseURLs := antigravity.ForwardBaseURLs()
|
||||
if len(baseURLs) == 0 {
|
||||
return ""
|
||||
}
|
||||
// 返回可用列表中的第一个(URLAvailability 动态优先级在调用方处理)
|
||||
available := antigravity.DefaultURLAvailability.GetAvailableURLsWithBase(urls)
|
||||
if len(available) > 0 {
|
||||
return available[0]
|
||||
mode := strings.ToLower(strings.TrimSpace(os.Getenv(antigravityForwardBaseURLEnv)))
|
||||
if mode == "prod" && len(baseURLs) > 1 {
|
||||
return baseURLs[1]
|
||||
}
|
||||
return urls[0]
|
||||
return baseURLs[0]
|
||||
}
|
||||
|
||||
// smartRetryAction 智能重试的处理结果
|
||||
@ -319,7 +251,7 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
|
||||
var lastRetryResp *http.Response
|
||||
var lastRetryBody []byte
|
||||
|
||||
// MODEL_CAPACITY_EXHAUSTED 使用独立的重试参数(10 次,指数退避)
|
||||
// MODEL_CAPACITY_EXHAUSTED 使用独立的重试参数(60 次,固定 1s 间隔)
|
||||
maxAttempts := antigravitySmartRetryMaxAttempts
|
||||
if isModelCapacityExhausted {
|
||||
maxAttempts = antigravityModelCapacityRetryMaxAttempts
|
||||
@ -346,29 +278,10 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
|
||||
}
|
||||
|
||||
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
||||
// 计算本次重试的等待时间
|
||||
var currentWaitDuration time.Duration
|
||||
if isModelCapacityExhausted {
|
||||
// 使用指数退避:1s, 2s, 4s, 8s, 16s, 32s, ...
|
||||
currentWaitDuration = waitDuration * time.Duration(1<<(attempt-1))
|
||||
if currentWaitDuration > antigravityModelCapacityRetryMaxWait {
|
||||
currentWaitDuration = antigravityModelCapacityRetryMaxWait
|
||||
}
|
||||
// 添加随机抖动(±10%)避免羊群效应
|
||||
jitter := time.Duration(mathrand.Int63n(int64(currentWaitDuration / 5)))
|
||||
if mathrand.Intn(2) == 0 {
|
||||
currentWaitDuration += jitter
|
||||
} else {
|
||||
currentWaitDuration -= jitter
|
||||
}
|
||||
} else {
|
||||
currentWaitDuration = waitDuration
|
||||
}
|
||||
|
||||
log.Printf("%s status=%d oauth_smart_retry attempt=%d/%d delay=%v model=%s account=%d",
|
||||
p.prefix, resp.StatusCode, attempt, maxAttempts, currentWaitDuration, modelName, p.account.ID)
|
||||
p.prefix, resp.StatusCode, attempt, maxAttempts, waitDuration, modelName, p.account.ID)
|
||||
|
||||
timer := time.NewTimer(currentWaitDuration)
|
||||
timer := time.NewTimer(waitDuration)
|
||||
select {
|
||||
case <-p.ctx.Done():
|
||||
timer.Stop()
|
||||
@ -678,7 +591,7 @@ func (s *AntigravityGatewayService) antigravityRetryLoop(p antigravityRetryLoopP
|
||||
}
|
||||
}
|
||||
|
||||
baseURL := resolveAntigravityForwardBaseURL(p.account)
|
||||
baseURL := resolveAntigravityForwardBaseURL()
|
||||
if baseURL == "" {
|
||||
return nil, errors.New("no antigravity forward base url configured")
|
||||
}
|
||||
@ -1084,20 +997,13 @@ func (s *AntigravityGatewayService) getMappedModel(account *Account, requestedMo
|
||||
return mapAntigravityModel(account, requestedModel)
|
||||
}
|
||||
|
||||
// applyThinkingModelSuffix 根据 thinking 配置和模型可用性调整模型名。
|
||||
// Google v1internal API 上部分 Claude 模型只有 -thinking 后缀版本存在,
|
||||
// 非 -thinking 版本会返回 404。
|
||||
// applyThinkingModelSuffix 根据 thinking 配置调整模型名
|
||||
// 当映射结果是 claude-sonnet-4-5 且请求开启了 thinking 时,改为 claude-sonnet-4-5-thinking
|
||||
func applyThinkingModelSuffix(mappedModel string, thinkingEnabled bool) string {
|
||||
// claude-opus-4-6: Google API 上只有 -thinking 版本,始终加后缀
|
||||
if mappedModel == "claude-opus-4-6" {
|
||||
return "claude-opus-4-6-thinking"
|
||||
}
|
||||
// 其他模型仅在 thinking 开启时加后缀
|
||||
if !thinkingEnabled {
|
||||
return mappedModel
|
||||
}
|
||||
switch mappedModel {
|
||||
case "claude-sonnet-4-5":
|
||||
if mappedModel == "claude-sonnet-4-5" {
|
||||
return "claude-sonnet-4-5-thinking"
|
||||
}
|
||||
return mappedModel
|
||||
@ -1139,10 +1045,6 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
|
||||
return nil, fmt.Errorf("model %s not in whitelist", modelID)
|
||||
}
|
||||
|
||||
// 应用 thinking 后缀(claude-opus-4-6 → claude-opus-4-6-thinking)
|
||||
// TestConnection 与主请求路径保持一致:Google API 只支持 -thinking 后缀版本的部分模型
|
||||
mappedModel = applyThinkingModelSuffix(mappedModel, false)
|
||||
|
||||
// 构建请求体
|
||||
var requestBody []byte
|
||||
if strings.HasPrefix(modelID, "gemini-") {
|
||||
@ -1246,17 +1148,17 @@ func (s *AntigravityGatewayService) buildGeminiTestRequest(projectID, model stri
|
||||
}
|
||||
|
||||
// buildClaudeTestRequest 构建 Claude 格式测试请求并转换为 Gemini 格式
|
||||
// 使用最小 token 消耗:输入 "." + MaxTokens: 10(足够验证连接)
|
||||
// 使用最小 token 消耗:输入 "." + MaxTokens: 1
|
||||
func (s *AntigravityGatewayService) buildClaudeTestRequest(projectID, mappedModel string) ([]byte, error) {
|
||||
claudeReq := &antigravity.ClaudeRequest{
|
||||
Model: mappedModel,
|
||||
Messages: []antigravity.ClaudeMessage{
|
||||
{
|
||||
Role: "user",
|
||||
Content: json.RawMessage(`"Test connection"`),
|
||||
Content: json.RawMessage(`"."`),
|
||||
},
|
||||
},
|
||||
MaxTokens: 10,
|
||||
MaxTokens: 1,
|
||||
Stream: false,
|
||||
}
|
||||
return antigravity.TransformClaudeToGemini(claudeReq, projectID, mappedModel)
|
||||
@ -1387,19 +1289,9 @@ func injectIdentityPatchToGeminiRequest(body []byte) ([]byte, error) {
|
||||
}
|
||||
|
||||
// wrapV1InternalRequest 包装请求为 v1internal 格式
|
||||
func (s *AntigravityGatewayService) wrapV1InternalRequest(projectID, model string, originalBody []byte, preferredSessionID ...string) ([]byte, error) {
|
||||
sessionID := ""
|
||||
if len(preferredSessionID) > 0 {
|
||||
sessionID = preferredSessionID[0]
|
||||
}
|
||||
|
||||
bodyWithSessionID, err := antigravity.EnsureGeminiRequestSessionID(originalBody, sessionID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("补全 sessionId 失败: %w", err)
|
||||
}
|
||||
|
||||
func (s *AntigravityGatewayService) wrapV1InternalRequest(projectID, model string, originalBody []byte) ([]byte, error) {
|
||||
var request any
|
||||
if err := json.Unmarshal(bodyWithSessionID, &request); err != nil {
|
||||
if err := json.Unmarshal(originalBody, &request); err != nil {
|
||||
return nil, fmt.Errorf("解析请求体失败: %w", err)
|
||||
}
|
||||
|
||||
@ -1477,6 +1369,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
if mappedModel == "" {
|
||||
return nil, s.writeClaudeError(c, http.StatusForbidden, "permission_error", fmt.Sprintf("model %s not in whitelist", claudeReq.Model))
|
||||
}
|
||||
// 应用 thinking 模式自动后缀:如果 thinking 开启且目标是 claude-sonnet-4-5,自动改为 thinking 版本
|
||||
thinkingEnabled := claudeReq.Thinking != nil && (claudeReq.Thinking.Type == "enabled" || claudeReq.Thinking.Type == "adaptive")
|
||||
mappedModel = applyThinkingModelSuffix(mappedModel, thinkingEnabled)
|
||||
billingModel := mappedModel
|
||||
@ -1503,9 +1396,9 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
}
|
||||
|
||||
// 获取转换选项
|
||||
// Antigravity 上游要求必须包含身份提示词,否则会返回 429
|
||||
transformOpts := s.getClaudeTransformOptions(ctx)
|
||||
transformOpts.EnableIdentityPatch = true
|
||||
transformOpts.PreferredSessionID = sessionID
|
||||
transformOpts.EnableIdentityPatch = true // 强制启用,Antigravity 上游必需
|
||||
|
||||
// 转换 Claude 请求为 Gemini 格式
|
||||
geminiBody, err := antigravity.TransformClaudeToGeminiWithOptions(&claudeReq, projectID, mappedModel, transformOpts)
|
||||
@ -1513,8 +1406,11 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", "Invalid request")
|
||||
}
|
||||
|
||||
// Antigravity 上游只支持流式请求,统一使用 streamGenerateContent
|
||||
// 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后转换返回
|
||||
action := "streamGenerateContent"
|
||||
|
||||
// 执行带重试的请求
|
||||
result, err := s.antigravityRetryLoop(antigravityRetryLoopParams{
|
||||
ctx: ctx,
|
||||
prefix: prefix,
|
||||
@ -1529,17 +1425,19 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
accountRepo: s.accountRepo,
|
||||
handleError: s.handleUpstreamError,
|
||||
requestedModel: originalModel,
|
||||
isStickySession: isStickySession,
|
||||
groupID: 0,
|
||||
sessionHash: "",
|
||||
isStickySession: isStickySession, // Forward 由上层判断粘性会话
|
||||
groupID: 0, // Forward 方法没有 groupID,由上层处理粘性会话清除
|
||||
sessionHash: "", // Forward 方法没有 sessionHash,由上层处理粘性会话清除
|
||||
})
|
||||
if err != nil {
|
||||
// 检查是否是账号切换信号,转换为 UpstreamFailoverError 让 Handler 切换账号
|
||||
if switchErr, ok := IsAntigravityAccountSwitchError(err); ok {
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
ForceCacheBilling: switchErr.IsStickySession,
|
||||
}
|
||||
}
|
||||
// 区分客户端取消和真正的上游失败,返回更准确的错误消息
|
||||
if c.Request.Context().Err() != nil {
|
||||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "client_disconnected", "Client disconnected before upstream response")
|
||||
}
|
||||
@ -1551,6 +1449,9 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
if resp.StatusCode >= 400 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
|
||||
// 优先检测 thinking block 的 signature 相关错误(400)并重试一次:
|
||||
// Antigravity /v1internal 链路在部分场景会对 thought/thinking signature 做严格校验,
|
||||
// 当历史消息携带的 signature 不合法时会直接 400;去除 thinking 后可继续完成请求。
|
||||
if resp.StatusCode == http.StatusBadRequest && isSignatureRelatedError(respBody) && s.settingService.IsSignatureRectifierEnabled(ctx) {
|
||||
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
@ -1567,6 +1468,10 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
|
||||
// Conservative two-stage fallback:
|
||||
// 1) Disable top-level thinking + thinking->text
|
||||
// 2) Only if still signature-related 400: also downgrade tool_use/tool_result to text.
|
||||
|
||||
retryStages := []struct {
|
||||
name string
|
||||
strip func(*antigravity.ClaudeRequest) (bool, error)
|
||||
@ -1586,7 +1491,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "Antigravity account %d: detected signature-related 400, retrying once (%s)", account.ID, stage.name)
|
||||
|
||||
retryGeminiBody, txErr := antigravity.TransformClaudeToGeminiWithOptions(&retryClaudeReq, projectID, mappedModel, transformOpts)
|
||||
retryGeminiBody, txErr := antigravity.TransformClaudeToGeminiWithOptions(&retryClaudeReq, projectID, mappedModel, s.getClaudeTransformOptions(ctx))
|
||||
if txErr != nil {
|
||||
continue
|
||||
}
|
||||
@ -1605,8 +1510,8 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
handleError: s.handleUpstreamError,
|
||||
requestedModel: originalModel,
|
||||
isStickySession: isStickySession,
|
||||
groupID: 0,
|
||||
sessionHash: "",
|
||||
groupID: 0, // Forward 方法没有 groupID,由上层处理粘性会话清除
|
||||
sessionHash: "", // Forward 方法没有 sessionHash,由上层处理粘性会话清除
|
||||
})
|
||||
if retryErr != nil {
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
@ -1659,6 +1564,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
Detail: retryUpstreamDetail,
|
||||
})
|
||||
|
||||
// If this stage fixed the signature issue, we stop; otherwise we may try the next stage.
|
||||
if retryResp.StatusCode != http.StatusBadRequest || !isSignatureRelatedError(retryBody) {
|
||||
respBody = retryBody
|
||||
resp = &http.Response{
|
||||
@ -1669,6 +1575,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
break
|
||||
}
|
||||
|
||||
// Still signature-related; capture context and allow next stage.
|
||||
respBody = retryBody
|
||||
resp = &http.Response{
|
||||
StatusCode: retryResp.StatusCode,
|
||||
@ -1678,7 +1585,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
}
|
||||
}
|
||||
|
||||
// Budget 整流
|
||||
// Budget 整流:检测 budget_tokens 约束错误并自动修正重试
|
||||
if resp.StatusCode == http.StatusBadRequest && respBody != nil && !isSignatureRelatedError(respBody) {
|
||||
errMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
|
||||
if isThinkingBudgetConstraintError(errMsg) && s.settingService.IsBudgetRectifierEnabled(ctx) {
|
||||
@ -1693,9 +1600,11 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
Detail: s.getUpstreamErrorDetail(respBody),
|
||||
})
|
||||
|
||||
// 修正 claudeReq 的 thinking 参数(adaptive 模式不修正)
|
||||
if claudeReq.Thinking == nil || claudeReq.Thinking.Type != "adaptive" {
|
||||
retryClaudeReq := claudeReq
|
||||
retryClaudeReq.Messages = append([]antigravity.ClaudeMessage(nil), claudeReq.Messages...)
|
||||
// 创建新的 ThinkingConfig 避免修改原始 claudeReq.Thinking 指针
|
||||
retryClaudeReq.Thinking = &antigravity.ThinkingConfig{
|
||||
Type: "enabled",
|
||||
BudgetTokens: BudgetRectifyBudgetTokens,
|
||||
@ -1750,7 +1659,9 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
}
|
||||
}
|
||||
|
||||
// 处理错误响应(重试后仍失败或不触发重试)
|
||||
if resp.StatusCode >= 400 {
|
||||
// 检测 prompt too long 错误,返回特殊错误类型供上层 fallback
|
||||
if resp.StatusCode == http.StatusBadRequest && isPromptTooLongError(respBody) {
|
||||
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
@ -1778,6 +1689,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
|
||||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, originalModel, 0, "", isStickySession)
|
||||
|
||||
// 精确匹配服务端配置类 400 错误,触发同账号重试 + failover
|
||||
if resp.StatusCode == http.StatusBadRequest {
|
||||
msg := strings.ToLower(strings.TrimSpace(extractAntigravityErrorMessage(respBody)))
|
||||
if isGoogleProjectConfigError(msg) {
|
||||
@ -1828,6 +1740,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
var firstTokenMs *int
|
||||
var clientDisconnect bool
|
||||
if claudeReq.Stream {
|
||||
// 客户端要求流式,直接透传转换
|
||||
streamRes, err := s.handleClaudeStreamingResponse(c, resp, startTime, originalModel)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=stream_error error=%v", prefix, err)
|
||||
@ -1837,6 +1750,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
firstTokenMs = streamRes.firstTokenMs
|
||||
clientDisconnect = streamRes.clientDisconnect
|
||||
} else {
|
||||
// 客户端要求非流式,收集流式响应后转换返回
|
||||
streamRes, err := s.handleClaudeStreamToNonStreaming(c, resp, startTime, originalModel)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=stream_collect_error error=%v", prefix, err)
|
||||
@ -1846,13 +1760,6 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
firstTokenMs = streamRes.firstTokenMs
|
||||
}
|
||||
|
||||
// DEBUG: 追踪 OAuth Claude 路径的 Usage 在 Forward 返回点的值。
|
||||
// 若这里 output>0 而 DB 记录为 0,说明 bug 在下游(billing/record 层);
|
||||
// 若这里 output=0,说明 bug 在 handleClaudeStreamingResponse 或更上游。
|
||||
logger.LegacyPrintf("service.antigravity_gateway",
|
||||
"%s DEBUG_USAGE_FORWARD_RETURN input=%d output=%d cache_read=%d cache_creation=%d stream=%v model=%s account=%d",
|
||||
prefix, usage.InputTokens, usage.OutputTokens, usage.CacheReadInputTokens, usage.CacheCreationInputTokens, claudeReq.Stream, originalModel, account.ID)
|
||||
|
||||
return &ForwardResult{
|
||||
RequestID: requestID,
|
||||
Usage: *usage,
|
||||
@ -2249,7 +2156,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
}
|
||||
|
||||
// 包装请求
|
||||
wrappedBody, err := s.wrapV1InternalRequest(projectID, mappedModel, injectedBody, sessionID)
|
||||
wrappedBody, err := s.wrapV1InternalRequest(projectID, mappedModel, injectedBody)
|
||||
if err != nil {
|
||||
return nil, s.writeGoogleError(c, http.StatusInternalServerError, "Failed to build upstream request")
|
||||
}
|
||||
@ -2313,7 +2220,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
if fallbackModel != "" && fallbackModel != mappedModel {
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "[Antigravity] Model not found (%s), retrying with fallback model %s (account: %s)", mappedModel, fallbackModel, account.Name)
|
||||
|
||||
fallbackWrapped, err := s.wrapV1InternalRequest(projectID, fallbackModel, injectedBody, sessionID)
|
||||
fallbackWrapped, err := s.wrapV1InternalRequest(projectID, fallbackModel, injectedBody)
|
||||
if err == nil {
|
||||
fallbackReq, err := antigravity.NewAPIRequest(ctx, upstreamAction, accessToken, fallbackWrapped)
|
||||
if err == nil {
|
||||
@ -2356,7 +2263,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "Antigravity Gemini account %d: detected signature-related 400, retrying with cleaned thought signatures", account.ID)
|
||||
|
||||
cleanedInjectedBody := CleanGeminiNativeThoughtSignatures(injectedBody)
|
||||
retryWrappedBody, wrapErr := s.wrapV1InternalRequest(projectID, mappedModel, cleanedInjectedBody, sessionID)
|
||||
retryWrappedBody, wrapErr := s.wrapV1InternalRequest(projectID, mappedModel, cleanedInjectedBody)
|
||||
if wrapErr == nil {
|
||||
retryResult, retryErr := s.antigravityRetryLoop(antigravityRetryLoopParams{
|
||||
ctx: ctx,
|
||||
@ -3124,45 +3031,6 @@ func handleStreamReadError(err error, clientDisconnected bool, prefix string) (d
|
||||
return false, false
|
||||
}
|
||||
|
||||
func googleStatusTextForHTTP(status int) string {
|
||||
switch status {
|
||||
case http.StatusBadRequest:
|
||||
return "INVALID_ARGUMENT"
|
||||
case http.StatusNotFound:
|
||||
return "NOT_FOUND"
|
||||
case http.StatusTooManyRequests:
|
||||
return "RESOURCE_EXHAUSTED"
|
||||
case http.StatusServiceUnavailable:
|
||||
return "UNAVAILABLE"
|
||||
default:
|
||||
return "UNKNOWN"
|
||||
}
|
||||
}
|
||||
|
||||
func buildAnthropicStreamErrorEvent(errType, message string) string {
|
||||
payload := map[string]any{
|
||||
"type": "error",
|
||||
"error": map[string]any{
|
||||
"type": errType,
|
||||
"message": message,
|
||||
},
|
||||
}
|
||||
data, _ := json.Marshal(payload)
|
||||
return "event: error\ndata: " + string(data) + "\n\n"
|
||||
}
|
||||
|
||||
func buildGeminiStreamErrorEvent(status int, message string) string {
|
||||
payload := map[string]any{
|
||||
"error": map[string]any{
|
||||
"code": status,
|
||||
"message": message,
|
||||
"status": googleStatusTextForHTTP(status),
|
||||
},
|
||||
}
|
||||
data, _ := json.Marshal(payload)
|
||||
return "event: error\ndata: " + string(data) + "\n\n"
|
||||
}
|
||||
|
||||
func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time) (*antigravityStreamResult, error) {
|
||||
c.Status(resp.StatusCode)
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
@ -3258,12 +3126,12 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
|
||||
|
||||
// 仅发送一次错误事件,避免多次写入导致协议混乱
|
||||
errorEventSent := false
|
||||
sendErrorEvent := func(status int, message string) {
|
||||
sendErrorEvent := func(reason string) {
|
||||
if errorEventSent || cw.Disconnected() {
|
||||
return
|
||||
}
|
||||
errorEventSent = true
|
||||
_, _ = fmt.Fprint(c.Writer, buildGeminiStreamErrorEvent(status, message))
|
||||
_, _ = fmt.Fprintf(c.Writer, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
@ -3279,10 +3147,10 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
|
||||
}
|
||||
if errors.Is(ev.err, bufio.ErrTooLong) {
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err)
|
||||
sendErrorEvent(http.StatusBadGateway, "Response too large")
|
||||
sendErrorEvent("response_too_large")
|
||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err
|
||||
}
|
||||
sendErrorEvent(http.StatusServiceUnavailable, "Upstream stream read failed")
|
||||
sendErrorEvent("stream_read_error")
|
||||
return nil, ev.err
|
||||
}
|
||||
|
||||
@ -3345,7 +3213,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
|
||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||
}
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "Stream data interval timeout (antigravity)")
|
||||
sendErrorEvent(http.StatusServiceUnavailable, "Upstream stream timeout")
|
||||
sendErrorEvent("stream_timeout")
|
||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||||
|
||||
case <-keepaliveCh:
|
||||
@ -4105,12 +3973,12 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
|
||||
|
||||
// 仅发送一次错误事件,避免多次写入导致协议混乱
|
||||
errorEventSent := false
|
||||
sendErrorEvent := func(errType, message string) {
|
||||
sendErrorEvent := func(reason string) {
|
||||
if errorEventSent || cw.Disconnected() {
|
||||
return
|
||||
}
|
||||
errorEventSent = true
|
||||
_, _ = fmt.Fprint(c.Writer, buildAnthropicStreamErrorEvent(errType, message))
|
||||
_, _ = fmt.Fprintf(c.Writer, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
@ -4126,9 +3994,6 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
|
||||
if !ok {
|
||||
// 上游完成,发送结束事件
|
||||
finalEvents, agUsage := processor.Finish()
|
||||
logger.LegacyPrintf("service.antigravity_gateway",
|
||||
"DEBUG_USAGE_PROCESSOR_FINISH input=%d output=%d cache_read=%d image_output=%d final_events_len=%d",
|
||||
agUsage.InputTokens, agUsage.OutputTokens, agUsage.CacheReadInputTokens, agUsage.ImageOutputTokens, len(finalEvents))
|
||||
if len(finalEvents) > 0 {
|
||||
cw.Write(finalEvents)
|
||||
} else if !processor.MessageStartSent() && !cw.Disconnected() {
|
||||
@ -4145,15 +4010,14 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
|
||||
}
|
||||
if ev.err != nil {
|
||||
if disconnect, handled := handleStreamReadError(ev.err, cw.Disconnected(), "antigravity claude"); handled {
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "DEBUG_USAGE_CLAUDE_STREAM_EARLY_RETURN path=handleStreamReadError disconnect=%v", disconnect)
|
||||
return &antigravityStreamResult{usage: finishUsage(), firstTokenMs: firstTokenMs, clientDisconnect: disconnect}, nil
|
||||
}
|
||||
if errors.Is(ev.err, bufio.ErrTooLong) {
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "DEBUG_USAGE_CLAUDE_STREAM_EARLY_RETURN path=ErrTooLong max_size=%d error=%v (usage WILL BE ZEROED)", maxLineSize, ev.err)
|
||||
sendErrorEvent("api_error", "Response too large")
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err)
|
||||
sendErrorEvent("response_too_large")
|
||||
return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, ev.err
|
||||
}
|
||||
sendErrorEvent("api_error", "Upstream stream read failed")
|
||||
sendErrorEvent("stream_read_error")
|
||||
return nil, fmt.Errorf("stream read error: %w", ev.err)
|
||||
}
|
||||
|
||||
@ -4179,7 +4043,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
|
||||
return &antigravityStreamResult{usage: finishUsage(), firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||
}
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "Stream data interval timeout (antigravity)")
|
||||
sendErrorEvent("api_error", "Upstream stream timeout")
|
||||
sendErrorEvent("stream_timeout")
|
||||
return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||||
|
||||
case <-keepaliveCh:
|
||||
@ -4672,61 +4536,3 @@ func (s *AntigravityGatewayService) extractClaudeUsage(body []byte) *ClaudeUsage
|
||||
}
|
||||
return usage
|
||||
}
|
||||
|
||||
// ForwardRaw 转发 Claude 格式请求并返回原始上游响应体(调用者负责关闭)。
|
||||
// 不依赖 gin.Context,供内部服务(如 LanguageServerService)调用。
|
||||
// 复用完整的 token 刷新、模型映射、TLS 指纹和重试逻辑。
|
||||
func (s *AntigravityGatewayService) ForwardRaw(ctx context.Context, account *Account, body []byte) (io.ReadCloser, int, error) {
|
||||
var claudeReq antigravity.ClaudeRequest
|
||||
if err := json.Unmarshal(body, &claudeReq); err != nil {
|
||||
return nil, http.StatusBadRequest, fmt.Errorf("invalid request body: %w", err)
|
||||
}
|
||||
if strings.TrimSpace(claudeReq.Model) == "" {
|
||||
return nil, http.StatusBadRequest, fmt.Errorf("missing model")
|
||||
}
|
||||
|
||||
mappedModel := s.getMappedModel(account, claudeReq.Model)
|
||||
if mappedModel == "" {
|
||||
return nil, http.StatusForbidden, fmt.Errorf("model %s not in whitelist", claudeReq.Model)
|
||||
}
|
||||
thinkingEnabled := claudeReq.Thinking != nil && (claudeReq.Thinking.Type == "enabled" || claudeReq.Thinking.Type == "adaptive")
|
||||
mappedModel = applyThinkingModelSuffix(mappedModel, thinkingEnabled)
|
||||
|
||||
if s.tokenProvider == nil {
|
||||
return nil, http.StatusBadGateway, fmt.Errorf("antigravity token provider not configured")
|
||||
}
|
||||
accessToken, err := s.tokenProvider.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, http.StatusBadGateway, fmt.Errorf("failed to get access token: %w", err)
|
||||
}
|
||||
|
||||
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
transformOpts := s.getClaudeTransformOptions(ctx)
|
||||
transformOpts.EnableIdentityPatch = true
|
||||
geminiBody, err := antigravity.TransformClaudeToGeminiWithOptions(&claudeReq, projectID, mappedModel, transformOpts)
|
||||
if err != nil {
|
||||
return nil, http.StatusBadRequest, fmt.Errorf("failed to transform request: %w", err)
|
||||
}
|
||||
|
||||
wrappedBody, err := s.wrapV1InternalRequest(projectID, mappedModel, geminiBody)
|
||||
if err != nil {
|
||||
return nil, http.StatusInternalServerError, fmt.Errorf("failed to wrap request: %w", err)
|
||||
}
|
||||
|
||||
upstreamReq, err := antigravity.NewAPIRequest(ctx, "streamGenerateContent", accessToken, wrappedBody)
|
||||
if err != nil {
|
||||
return nil, http.StatusInternalServerError, fmt.Errorf("failed to build upstream request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
return nil, http.StatusBadGateway, fmt.Errorf("upstream request failed: %w", err)
|
||||
}
|
||||
|
||||
return resp.Body, resp.StatusCode, nil
|
||||
}
|
||||
|
||||
@ -600,120 +600,6 @@ func TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel(t *testing
|
||||
require.Equal(t, mappedModel, result.UpstreamModel)
|
||||
}
|
||||
|
||||
func TestAntigravityGatewayService_ForwardGemini_InjectsSessionIDIntoWrappedRequest(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
writer := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(writer)
|
||||
|
||||
body, err := json.Marshal(map[string]any{
|
||||
"contents": []map[string]any{
|
||||
{"role": "user", "parts": []map[string]any{{"text": "hello"}}},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body))
|
||||
req.Header.Set("session_id", "session-header-1")
|
||||
c.Request = req
|
||||
|
||||
upstreamBody := []byte("data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"}]},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":8,\"candidatesTokenCount\":3}}}\n\n")
|
||||
upstream := &queuedHTTPUpstreamStub{
|
||||
responses: []*http.Response{
|
||||
{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"X-Request-Id": []string{"req-session-1"}},
|
||||
Body: io.NopCloser(bytes.NewReader(upstreamBody)),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{
|
||||
settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}),
|
||||
tokenProvider: &AntigravityTokenProvider{},
|
||||
httpUpstream: upstream,
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 16,
|
||||
Name: "acc-gemini-session",
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "token",
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.ForwardGemini(context.Background(), c, account, "gemini-2.5-flash", "generateContent", false, body, false)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Len(t, upstream.requestBodies, 1)
|
||||
|
||||
var wrapped map[string]any
|
||||
require.NoError(t, json.Unmarshal(upstream.requestBodies[0], &wrapped))
|
||||
requestNode, ok := wrapped["request"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "session-header-1", requestNode["sessionId"])
|
||||
}
|
||||
|
||||
func TestAntigravityGatewayService_Forward_PropagatesSessionHeaderIntoClaudeTransform(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
writer := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(writer)
|
||||
|
||||
body := []byte(`{
|
||||
"model":"claude-sonnet-4-5",
|
||||
"max_tokens":64,
|
||||
"messages":[
|
||||
{"role":"user","content":[{"type":"text","text":"hello"}]}
|
||||
]
|
||||
}`)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||||
req.Header.Set("session_id", "session-header-1")
|
||||
c.Request = req
|
||||
|
||||
upstreamBody := []byte("data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"}]},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":8,\"candidatesTokenCount\":3}}}\n\n")
|
||||
upstream := &queuedHTTPUpstreamStub{
|
||||
responses: []*http.Response{
|
||||
{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"X-Request-Id": []string{"req-session-claude-1"}},
|
||||
Body: io.NopCloser(bytes.NewReader(upstreamBody)),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{
|
||||
settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}),
|
||||
tokenProvider: &AntigravityTokenProvider{},
|
||||
httpUpstream: upstream,
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 17,
|
||||
Name: "acc-claude-session",
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "token",
|
||||
"project_id": "project-1",
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.Forward(context.Background(), c, account, body, false)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Len(t, upstream.requestBodies, 1)
|
||||
|
||||
var wrapped antigravity.V1InternalRequest
|
||||
require.NoError(t, json.Unmarshal(upstream.requestBodies[0], &wrapped))
|
||||
require.Equal(t, "session-header-1", wrapped.Request.SessionID)
|
||||
}
|
||||
|
||||
func TestAntigravityGatewayService_ForwardGemini_RetriesCorruptedThoughtSignature(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
writer := httptest.NewRecorder()
|
||||
|
||||
@ -29,9 +29,8 @@ type AntigravityAuthURLResult struct {
|
||||
State string `json:"state"`
|
||||
}
|
||||
|
||||
// GenerateAuthURL 生成 Google OAuth 授权链接。
|
||||
// isEnterprise=true 时生成企业账号授权链接(使用企业 client_id)。
|
||||
func (s *AntigravityOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64, isEnterprise bool) (*AntigravityAuthURLResult, error) {
|
||||
// GenerateAuthURL 生成 Google OAuth 授权链接
|
||||
func (s *AntigravityOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64) (*AntigravityAuthURLResult, error) {
|
||||
state, err := antigravity.GenerateState()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("生成 state 失败: %w", err)
|
||||
@ -59,13 +58,12 @@ func (s *AntigravityOAuthService) GenerateAuthURL(ctx context.Context, proxyID *
|
||||
State: state,
|
||||
CodeVerifier: codeVerifier,
|
||||
ProxyURL: proxyURL,
|
||||
IsEnterprise: isEnterprise,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
s.sessionStore.Set(sessionID, session)
|
||||
|
||||
codeChallenge := antigravity.GenerateCodeChallenge(codeVerifier)
|
||||
authURL := antigravity.BuildAuthorizationURL(state, codeChallenge, isEnterprise)
|
||||
authURL := antigravity.BuildAuthorizationURL(state, codeChallenge)
|
||||
|
||||
return &AntigravityAuthURLResult{
|
||||
AuthURL: authURL,
|
||||
@ -91,7 +89,6 @@ type AntigravityTokenInfo struct {
|
||||
TokenType string `json:"token_type"`
|
||||
Email string `json:"email,omitempty"`
|
||||
ProjectID string `json:"project_id,omitempty"`
|
||||
IsEnterprise bool `json:"is_enterprise,omitempty"`
|
||||
ProjectIDMissing bool `json:"-"`
|
||||
PlanType string `json:"-"`
|
||||
PrivacyMode string `json:"-"`
|
||||
@ -122,8 +119,8 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig
|
||||
return nil, fmt.Errorf("create antigravity client failed: %w", err)
|
||||
}
|
||||
|
||||
// 交换 token(使用 session 中记录的账号类型)
|
||||
tokenResp, err := client.ExchangeCode(ctx, input.Code, session.CodeVerifier, session.IsEnterprise)
|
||||
// 交换 token
|
||||
tokenResp, err := client.ExchangeCode(ctx, input.Code, session.CodeVerifier)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("token 交换失败: %w", err)
|
||||
}
|
||||
@ -140,7 +137,6 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig
|
||||
ExpiresIn: tokenResp.ExpiresIn,
|
||||
ExpiresAt: expiresAt,
|
||||
TokenType: tokenResp.TokenType,
|
||||
IsEnterprise: session.IsEnterprise,
|
||||
}
|
||||
|
||||
// 获取用户信息
|
||||
@ -170,9 +166,8 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// RefreshToken 刷新 token。
|
||||
// isEnterprise=true 时使用企业 OAuth client_id/secret。
|
||||
func (s *AntigravityOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string, isEnterprise bool) (*AntigravityTokenInfo, error) {
|
||||
// RefreshToken 刷新 token
|
||||
func (s *AntigravityOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*AntigravityTokenInfo, error) {
|
||||
var lastErr error
|
||||
|
||||
for attempt := 0; attempt <= 3; attempt++ {
|
||||
@ -188,7 +183,7 @@ func (s *AntigravityOAuthService) RefreshToken(ctx context.Context, refreshToken
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create antigravity client failed: %w", err)
|
||||
}
|
||||
tokenResp, err := client.RefreshToken(ctx, refreshToken, isEnterprise)
|
||||
tokenResp, err := client.RefreshToken(ctx, refreshToken)
|
||||
if err == nil {
|
||||
now := time.Now()
|
||||
expiresAt := now.Unix() + tokenResp.ExpiresIn - 300
|
||||
@ -200,7 +195,6 @@ func (s *AntigravityOAuthService) RefreshToken(ctx context.Context, refreshToken
|
||||
ExpiresIn: tokenResp.ExpiresIn,
|
||||
ExpiresAt: expiresAt,
|
||||
TokenType: tokenResp.TokenType,
|
||||
IsEnterprise: isEnterprise,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -217,9 +211,8 @@ func (s *AntigravityOAuthService) RefreshToken(ctx context.Context, refreshToken
|
||||
return nil, fmt.Errorf("token 刷新失败 (重试后): %w", lastErr)
|
||||
}
|
||||
|
||||
// ValidateRefreshToken 用 refresh token 验证并获取完整的 token 信息(含 email 和 project_id)。
|
||||
// isEnterprise=true 时使用企业 OAuth client 刷新。
|
||||
func (s *AntigravityOAuthService) ValidateRefreshToken(ctx context.Context, refreshToken string, proxyID *int64, isEnterprise bool) (*AntigravityTokenInfo, error) {
|
||||
// ValidateRefreshToken 用 refresh token 验证并获取完整的 token 信息(含 email 和 project_id)
|
||||
func (s *AntigravityOAuthService) ValidateRefreshToken(ctx context.Context, refreshToken string, proxyID *int64) (*AntigravityTokenInfo, error) {
|
||||
var proxyURL string
|
||||
if proxyID != nil {
|
||||
proxy, err := s.proxyRepo.GetByID(ctx, *proxyID)
|
||||
@ -228,8 +221,8 @@ func (s *AntigravityOAuthService) ValidateRefreshToken(ctx context.Context, refr
|
||||
}
|
||||
}
|
||||
|
||||
// 刷新 token:先按调用方指定类型刷新;若报 client 不匹配再尝试另一侧。
|
||||
tokenInfo, err := s.refreshTokenAutoFallback(ctx, refreshToken, proxyURL, isEnterprise)
|
||||
// 刷新 token
|
||||
tokenInfo, err := s.RefreshToken(ctx, refreshToken, proxyURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -281,32 +274,6 @@ func isNonRetryableAntigravityOAuthError(err error) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// isClientMismatchOAuthError 判断是否为 OAuth client 不匹配错误(用于触发个人/企业切换)。
|
||||
// 与 isNonRetryableAntigravityOAuthError 不同:这里只识别 client 相关错误,不包含 invalid_grant。
|
||||
func isClientMismatchOAuthError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
msg := err.Error()
|
||||
return strings.Contains(msg, "invalid_client") ||
|
||||
strings.Contains(msg, "unauthorized_client")
|
||||
}
|
||||
|
||||
// refreshTokenAutoFallback 先按指定类型刷新;若遇 client 不匹配错误则切换到另一侧。
|
||||
// 本函数不承担网络层重试(由内部 RefreshToken 处理)。
|
||||
func (s *AntigravityOAuthService) refreshTokenAutoFallback(ctx context.Context, refreshToken, proxyURL string, preferEnterprise bool) (*AntigravityTokenInfo, error) {
|
||||
tokenInfo, err := s.RefreshToken(ctx, refreshToken, proxyURL, preferEnterprise)
|
||||
if err == nil {
|
||||
return tokenInfo, nil
|
||||
}
|
||||
if !isClientMismatchOAuthError(err) {
|
||||
return nil, err
|
||||
}
|
||||
// 切换另一侧账号类型重试
|
||||
fmt.Printf("[AntigravityOAuth] client 不匹配,切换账号类型重试:%v → %v\n", preferEnterprise, !preferEnterprise)
|
||||
return s.RefreshToken(ctx, refreshToken, proxyURL, !preferEnterprise)
|
||||
}
|
||||
|
||||
// RefreshAccountToken 刷新账户的 token
|
||||
func (s *AntigravityOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*AntigravityTokenInfo, error) {
|
||||
if account.Platform != PlatformAntigravity || account.Type != AccountTypeOAuth {
|
||||
@ -318,8 +285,6 @@ func (s *AntigravityOAuthService) RefreshAccountToken(ctx context.Context, accou
|
||||
return nil, fmt.Errorf("无可用的 refresh_token")
|
||||
}
|
||||
|
||||
isEnterprise := account.GetCredentialAsBool("is_gcp_tos")
|
||||
|
||||
var proxyURL string
|
||||
if account.ProxyID != nil {
|
||||
proxy, err := s.proxyRepo.GetByID(ctx, *account.ProxyID)
|
||||
@ -328,7 +293,7 @@ func (s *AntigravityOAuthService) RefreshAccountToken(ctx context.Context, accou
|
||||
}
|
||||
}
|
||||
|
||||
tokenInfo, err := s.RefreshToken(ctx, refreshToken, proxyURL, isEnterprise)
|
||||
tokenInfo, err := s.RefreshToken(ctx, refreshToken, proxyURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -495,7 +460,6 @@ func (s *AntigravityOAuthService) BuildAccountCredentials(tokenInfo *Antigravity
|
||||
creds := map[string]any{
|
||||
"access_token": tokenInfo.AccessToken,
|
||||
"expires_at": strconv.FormatInt(tokenInfo.ExpiresAt, 10),
|
||||
"is_gcp_tos": tokenInfo.IsEnterprise,
|
||||
}
|
||||
if tokenInfo.RefreshToken != "" {
|
||||
creds["refresh_token"] = tokenInfo.RefreshToken
|
||||
|
||||
@ -5,12 +5,13 @@ package service
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
|
||||
"github.com/stretchr/testify/require"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// stubSmartRetryCache 用于 handleSmartRetry 测试的 GatewayCache mock
|
||||
|
||||
@ -1,187 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestAntigravityFullFlow 完整流程测试
|
||||
// 模拟从 HTTP 处理器到最终响应的完整路径
|
||||
func TestAntigravityFullFlow(t *testing.T) {
|
||||
t.Log("🔥 启动 Antigravity 完整流程测试...")
|
||||
t.Log("")
|
||||
|
||||
// 构造测试账号数据(使用提供的凭证)
|
||||
proxyID := int64(9)
|
||||
account := &Account{
|
||||
ID: 68,
|
||||
Name: "PriesJosephe139@gmail.com",
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "ya29.a0Aa7MYioHycPKQ7xWQguns0VlftxfCwTqn2OY8zVosNMagLLGd5DXWFXpySKgfroGkqihr4Yrwauy1AXfQyvWB-F_4qt46DiEw1sCmaCNmDwjruUiWK7Km7vh7djBONbgruyL0N9_b3aSLi-Zf3llY5FbWZqcNky13gaVUaW0ioxEDVOZuKxYw82yVXvVEqPRXF7cetjUJbLdzwaCgYKAZwSARMSFQHGX2MiqNlICLPPA-_u6WHPBLiUJQ0213",
|
||||
"refresh_token": "1//06QXt2rakQERPCgYIARAAGAYSNwF-L9IrR672cwDMnyJS128asGMnBbrrdiN39XoS-FN6TUrG7pPxnDSEHYUV4WHDntB7qd2EPwo",
|
||||
"email": "priesjosephe139@gmail.com",
|
||||
"expires_at": "1775903154",
|
||||
"project_id": "kinetic-sum-r3tp7",
|
||||
"plan_type": "Free",
|
||||
},
|
||||
ProxyID: &proxyID,
|
||||
Concurrency: 100,
|
||||
}
|
||||
|
||||
// 测试路由决策逻辑
|
||||
t.Run("RouteAntigravityTest", func(t *testing.T) {
|
||||
// 验证账号类型,决定使用哪条路径
|
||||
t.Logf("📌 账号类型判断:")
|
||||
t.Logf(" Platform: %s (期望: antigravity)", account.Platform)
|
||||
t.Logf(" Type: %s (期望: oauth)", account.Type)
|
||||
t.Logf("")
|
||||
|
||||
// 模拟 routeAntigravityTest 的决策逻辑
|
||||
var testPath string
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
testPath = "APIKey 路径 (Claude/Gemini 直接连接)"
|
||||
} else if account.Platform == PlatformAntigravity {
|
||||
testPath = "OAuth/Upstream 路径 (使用 AntigravityGatewayService.TestConnection)"
|
||||
} else {
|
||||
testPath = "未知路径 (❌ 错误)"
|
||||
}
|
||||
|
||||
t.Logf("✅ 将使用: %s", testPath)
|
||||
t.Logf("")
|
||||
})
|
||||
|
||||
// 测试完整的错误处理流程
|
||||
t.Run("ErrorHandlingPathway", func(t *testing.T) {
|
||||
t.Logf("📋 错误处理流程图:")
|
||||
t.Logf("")
|
||||
t.Logf("1️⃣ HTTP Handler (account_handler.go:671)")
|
||||
t.Logf(" ↓")
|
||||
t.Logf(" accountTestService.TestAccountConnection()")
|
||||
t.Logf(" ↓")
|
||||
t.Logf("2️⃣ AccountTestService.routeAntigravityTest()")
|
||||
t.Logf(" ├─ Platform check: antigravity ✓")
|
||||
t.Logf(" ├─ Type check: oauth ✓")
|
||||
t.Logf(" └─ Call: testAntigravityAccountConnection()")
|
||||
t.Logf(" ↓")
|
||||
t.Logf("3️⃣ AccountTestService.testAntigravityAccountConnection()")
|
||||
t.Logf(" ├─ Send SSE 'test_start' event")
|
||||
t.Logf(" ├─ Call: AntigravityGatewayService.TestConnection()")
|
||||
t.Logf(" │ ├─ Get access token")
|
||||
t.Logf(" │ ├─ Get project_id")
|
||||
t.Logf(" │ ├─ Build request body")
|
||||
t.Logf(" │ ├─ Call: antigravityRetryLoop()")
|
||||
t.Logf(" │ │ ├─ Execute HTTP request to Google API")
|
||||
t.Logf(" │ │ ├─ Parse response")
|
||||
t.Logf(" │ │ └─ Handle errors (rate limit, auth, etc.)")
|
||||
t.Logf(" │ └─ Return result or error")
|
||||
t.Logf(" ├─ If error: sendErrorAndEnd(error_message)")
|
||||
t.Logf(" ├─ If success: sendEvent('content', response_text)")
|
||||
t.Logf(" └─ Send SSE 'test_complete' event")
|
||||
t.Logf(" ↓")
|
||||
t.Logf("4️⃣ Response to Client (SSE 流)")
|
||||
t.Logf(" ├─ Content-Type: text/event-stream")
|
||||
t.Logf(" ├─ Event: test_start")
|
||||
t.Logf(" ├─ Event: content (或 error)")
|
||||
t.Logf(" └─ Event: test_complete")
|
||||
t.Logf("")
|
||||
})
|
||||
|
||||
// 诊断 "IT" 错误的可能来源
|
||||
t.Run("DiagnoseITError", func(t *testing.T) {
|
||||
t.Logf("🔍 分析 'IT' 错误可能的来源:")
|
||||
t.Logf("")
|
||||
t.Logf("❓ 场景 1: 错误被截断")
|
||||
t.Logf(" 原始错误可能是:")
|
||||
t.Logf(" - 'INVALID_TOKEN' → truncated to 'IT'")
|
||||
t.Logf(" - 'INTERNAL_ERROR' → truncated to 'IT'")
|
||||
t.Logf(" - 'INVALID_GRANT' → truncated to 'IT'")
|
||||
t.Logf(" - 'INTERNAL_ERROR...' → first 2 chars 'IN' not 'IT'")
|
||||
t.Logf("")
|
||||
t.Logf("❓ 场景 2: 错误来自特定的代码点")
|
||||
t.Logf(" 可能出现 'IT' 的地方:")
|
||||
t.Logf(" - SSE stream 中的错误字符")
|
||||
t.Logf(" - HTTP response body 中的 JSON 解析错误")
|
||||
t.Logf(" - Google API 返回的错误代码 (如果 Google API 返回 'IT' 作为错误)")
|
||||
t.Logf("")
|
||||
t.Logf("❓ 场景 3: 特殊的错误代码")
|
||||
t.Logf(" 需要检查:")
|
||||
t.Logf(" - 是否存在名为 'IT' 的错误常量?")
|
||||
t.Logf(" - Google RPC 状态码中是否有 'IT'?")
|
||||
t.Logf(" - 特定的错误处理中是否会生成 'IT'?")
|
||||
t.Logf("")
|
||||
})
|
||||
|
||||
// 完整的调试检查清单
|
||||
t.Run("DebugChecklist", func(t *testing.T) {
|
||||
t.Logf("✅ 完整的调试检查清单:")
|
||||
t.Logf("")
|
||||
t.Logf("1. 验证账号信息:")
|
||||
t.Logf(" [ ] Account ID: %d", account.ID)
|
||||
t.Logf(" [ ] Platform: %s", account.Platform)
|
||||
t.Logf(" [ ] Type: %s", account.Type)
|
||||
t.Logf(" [ ] Access Token: %s... (长度: %d)",
|
||||
account.GetCredential("access_token")[:20],
|
||||
len(account.GetCredential("access_token")))
|
||||
t.Logf(" [ ] Project ID: %s", account.GetCredential("project_id"))
|
||||
t.Logf("")
|
||||
t.Logf("2. 验证请求路径:")
|
||||
t.Logf(" [ ] routeAntigravityTest 选择了正确的路径")
|
||||
t.Logf(" [ ] testAntigravityAccountConnection 被调用")
|
||||
t.Logf(" [ ] AntigravityGatewayService.TestConnection 被调用")
|
||||
t.Logf("")
|
||||
t.Logf("3. 捕获详细错误信息:")
|
||||
t.Logf(" [ ] 错误的完整字符串(不仅仅是 'IT')")
|
||||
t.Logf(" [ ] 错误的类型(type)")
|
||||
t.Logf(" [ ] 错误发生的确切代码行")
|
||||
t.Logf(" [ ] HTTP 状态码(如有)")
|
||||
t.Logf(" [ ] HTTP 响应体(如有)")
|
||||
t.Logf("")
|
||||
t.Logf("4. 验证 SSE 流处理:")
|
||||
t.Logf(" [ ] 错误事件的 type 字段")
|
||||
t.Logf(" [ ] 错误事件的 error 字段内容")
|
||||
t.Logf(" [ ] 是否有 UTF-8 编码问题")
|
||||
t.Logf("")
|
||||
})
|
||||
|
||||
// 建议的实际代码改进
|
||||
t.Run("SuggestedCodeFixes", func(t *testing.T) {
|
||||
t.Logf("🔧 建议的代码改进:")
|
||||
t.Logf("")
|
||||
t.Logf("1. 在 testAntigravityAccountConnection 中增加日志:")
|
||||
t.Logf(" ```go")
|
||||
t.Logf(" result, err := s.antigravityGatewayService.TestConnection(ctx, account, testModelID)")
|
||||
t.Logf(" if err != nil {")
|
||||
t.Logf(" log.Printf(\"[ERROR] TestConnection failed: type=%%T, error=%%v, msg='%%s'\", err, err, err.Error())")
|
||||
t.Logf(" return s.sendErrorAndEnd(c, err.Error())")
|
||||
t.Logf(" }")
|
||||
t.Logf(" ```")
|
||||
t.Logf("")
|
||||
t.Logf("2. 在 sendErrorAndEnd 中增加详细日志:")
|
||||
t.Logf(" ```go")
|
||||
t.Logf(" func (s *AccountTestService) sendErrorAndEnd(c *gin.Context, msg string) error {")
|
||||
t.Logf(" log.Printf(\"[SEND_ERROR] msg='%%s' (len=%%d, bytes=%%v)\", msg, len(msg), []byte(msg))")
|
||||
t.Logf(" s.sendEvent(c, TestEvent{Type: \"test_error\", Error: msg, Success: false})")
|
||||
t.Logf(" return nil")
|
||||
t.Logf(" }")
|
||||
t.Logf(" ```")
|
||||
t.Logf("")
|
||||
t.Logf("3. 检查 TestConnection 中的错误处理:")
|
||||
t.Logf(" 在 antigravity_gateway_service.go 的 TestConnection 函数中")
|
||||
t.Logf(" 追踪每个错误返回点的错误信息")
|
||||
t.Logf("")
|
||||
})
|
||||
|
||||
// 最后的总结
|
||||
t.Log("")
|
||||
t.Log("📊 测试摘要:")
|
||||
t.Log("✅ 账号凭证验证: 通过")
|
||||
t.Log("✅ 路由逻辑验证: 通过")
|
||||
t.Log("⚠️ 实际错误诊断: 需要在完整环境中运行")
|
||||
t.Log("")
|
||||
t.Log("下一步:")
|
||||
t.Log("1. 添加建议的代码日志")
|
||||
t.Log("2. 重新运行 HTTP 测试")
|
||||
t.Log("3. 收集完整的错误信息")
|
||||
t.Log("4. 分析并修复根本原因")
|
||||
}
|
||||
@ -1,188 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// TestHTTPResponseFlow 测试完整的 HTTP 请求-响应流,看客户端会收到什么
|
||||
func TestHTTPResponseFlow(t *testing.T) {
|
||||
t.Log("🔥 模拟完整的 HTTP 请求-响应流...")
|
||||
t.Log("")
|
||||
|
||||
// 创建一个模拟的服务
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
// 模拟账号测试端点
|
||||
router.POST("/api/v1/admin/accounts/:id/test", func(c *gin.Context) {
|
||||
// 模拟返回错误的情况
|
||||
|
||||
// 设置 SSE 头
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
c.Status(http.StatusOK)
|
||||
|
||||
// 发送测试开始事件
|
||||
event1 := map[string]interface{}{
|
||||
"type": "test_start",
|
||||
"model": "claude-opus-4-6",
|
||||
}
|
||||
jsonData1, _ := json.Marshal(event1)
|
||||
c.Writer.WriteString("data: " + string(jsonData1) + "\n\n")
|
||||
c.Writer.Flush()
|
||||
|
||||
// 模拟一个错误:比如 "INVALID_TOKEN" 或其他上游错误
|
||||
// 这里我们故意测试不同的错误信息来看 curl 会显示什么
|
||||
|
||||
errorMessages := []string{
|
||||
"INVALID_TOKEN",
|
||||
"INTERNAL_ERROR",
|
||||
"Invalid authentication credentials",
|
||||
"Th", // 测试短错误
|
||||
"IT", // 直接测试 "IT"
|
||||
}
|
||||
|
||||
selectedError := errorMessages[3] // 选择第 4 个:这应该显示为 "Th" 而不是 "IT"
|
||||
|
||||
event2 := map[string]interface{}{
|
||||
"type": "error",
|
||||
"error": selectedError,
|
||||
"success": false,
|
||||
}
|
||||
jsonData2, _ := json.Marshal(event2)
|
||||
c.Writer.WriteString("data: " + string(jsonData2) + "\n\n")
|
||||
c.Writer.Flush()
|
||||
|
||||
// 发送完成事件
|
||||
event3 := map[string]interface{}{
|
||||
"type": "test_complete",
|
||||
"success": false,
|
||||
}
|
||||
jsonData3, _ := json.Marshal(event3)
|
||||
c.Writer.WriteString("data: " + string(jsonData3) + "\n\n")
|
||||
c.Writer.Flush()
|
||||
|
||||
t.Logf("📤 服务器发送的错误: '%s'", selectedError)
|
||||
})
|
||||
|
||||
// 测试 1: 发送 HTTP 请求
|
||||
t.Run("SendRequestAndCheckResponse", func(t *testing.T) {
|
||||
t.Log("步骤 1: 发送 HTTP 请求...")
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/v1/admin/accounts/68/test",
|
||||
bytes.NewReader([]byte(`{"model_id":"claude-opus-4-6"}`)))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
t.Log("✅ 请求已发送")
|
||||
t.Log("")
|
||||
|
||||
// 步骤 2: 检查响应
|
||||
t.Log("步骤 2: 分析 HTTP 响应...")
|
||||
t.Logf(" HTTP Status: %d", w.Code)
|
||||
t.Logf(" Content-Type: %s", w.Header().Get("Content-Type"))
|
||||
t.Log("")
|
||||
|
||||
// 步骤 3: 读取 SSE 响应
|
||||
t.Log("步骤 3: 读取 SSE 事件...")
|
||||
body := w.Body.String()
|
||||
t.Logf(" 响应总长度: %d 字节", len(body))
|
||||
t.Log("")
|
||||
|
||||
// 解析 SSE 事件
|
||||
lines := bytes.Split([]byte(body), []byte("\n\n"))
|
||||
for i, line := range lines {
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// 去掉 "data: " 前缀
|
||||
if bytes.HasPrefix(line, []byte("data: ")) {
|
||||
data := bytes.TrimPrefix(line, []byte("data: "))
|
||||
|
||||
var event map[string]interface{}
|
||||
err := json.Unmarshal(data, &event)
|
||||
if err != nil {
|
||||
t.Logf(" 事件 %d: [解析失败] %v", i, err)
|
||||
continue
|
||||
}
|
||||
|
||||
t.Logf(" 事件 %d:", i)
|
||||
t.Logf(" type: %v", event["type"])
|
||||
|
||||
if errMsg, ok := event["error"]; ok {
|
||||
t.Logf(" error: %v (长度: %d)", errMsg, len(errMsg.(string)))
|
||||
|
||||
// 这就是 curl 会看到的错误信息
|
||||
errStr := errMsg.(string)
|
||||
if errStr == "IT" {
|
||||
t.Logf(" ✓ 发现 'IT' 错误!")
|
||||
} else if errStr == "Th" {
|
||||
t.Logf(" ℹ️ 这是 'Th' 而不是 'IT'")
|
||||
} else {
|
||||
t.Logf(" ℹ️ 实际错误: '%s'", errStr)
|
||||
}
|
||||
}
|
||||
|
||||
if model, ok := event["model"]; ok {
|
||||
t.Logf(" model: %v", model)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
t.Log("")
|
||||
t.Log("📋 完整的原始响应:")
|
||||
t.Logf("%s", body)
|
||||
})
|
||||
|
||||
// 测试 2: 模拟真实的 curl 请求
|
||||
t.Run("SimulateRealCurlRequest", func(t *testing.T) {
|
||||
t.Log("步骤: 模拟真实 curl 命令...")
|
||||
t.Log("")
|
||||
|
||||
// 发送请求
|
||||
req := httptest.NewRequest("POST", "/api/v1/admin/accounts/68/test",
|
||||
bytes.NewReader([]byte(`{"model_id":"claude-opus-4-6","prompt":""}`)))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer test-token")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// 模拟 curl 读取响应
|
||||
body := w.Body.String()
|
||||
|
||||
t.Log("curl 会看到:")
|
||||
t.Log("```")
|
||||
t.Log(body)
|
||||
t.Log("```")
|
||||
})
|
||||
}
|
||||
|
||||
// 辅助函数:提取 SSE 事件中的错误信息
|
||||
func extractErrorFromSSE(sseBody string) string {
|
||||
lines := bytes.Split([]byte(sseBody), []byte("\n\n"))
|
||||
for _, line := range lines {
|
||||
if bytes.HasPrefix(line, []byte("data: ")) {
|
||||
data := bytes.TrimPrefix(line, []byte("data: "))
|
||||
var event map[string]interface{}
|
||||
if err := json.Unmarshal(data, &event); err != nil {
|
||||
continue
|
||||
}
|
||||
if errMsg, ok := event["error"]; ok {
|
||||
return errMsg.(string)
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@ -1,213 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestAntigravityCredentialsValidation 单例测试:验证给定的 Antigravity 账号凭证有效性
|
||||
// 本测试使用服务器的真实代码函数,不依赖 HTTP 层,模拟云端场景
|
||||
func TestAntigravityCredentialsValidation(t *testing.T) {
|
||||
// 测试数据:来自你提供的账号信息
|
||||
// ID: 68, 平台: antigravity, 类型: oauth
|
||||
proxyID := int64(9)
|
||||
testAccount := &Account{
|
||||
ID: 68,
|
||||
Name: "PriesJosephe139@gmail.com",
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "ya29.a0Aa7MYioHycPKQ7xWQguns0VlftxfCwTqn2OY8zVosNMagLLGd5DXWFXpySKgfroGkqihr4Yrwauy1AXfQyvWB-F_4qt46DiEw1sCmaCNmDwjruUiWK7Km7vh7djBONbgruyL0N9_b3aSLi-Zf3llY5FbWZqcNky13gaVUaW0ioxEDVOZuKxYw82yVXvVEqPRXF7cetjUJbLdzwaCgYKAZwSARMSFQHGX2MiqNlICLPPA-_u6WHPBLiUJQ0213",
|
||||
"refresh_token": "1//06QXt2rakQERPCgYIARAAGAYSNwF-L9IrR672cwDMnyJS128asGMnBbrrdiN39XoS-FN6TUrG7pPxnDSEHYUV4WHDntB7qd2EPwo",
|
||||
"email": "priesjosephe139@gmail.com",
|
||||
"expires_at": "1775903154",
|
||||
"project_id": "kinetic-sum-r3tp7",
|
||||
"plan_type": "Free",
|
||||
},
|
||||
ProxyID: &proxyID,
|
||||
Concurrency: 100,
|
||||
}
|
||||
|
||||
// 测试 1: 验证账号凭证完整性
|
||||
t.Run("ValidateAccountCredentials", func(t *testing.T) {
|
||||
if testAccount.ID == 0 {
|
||||
t.Fatal("Account ID is missing")
|
||||
}
|
||||
if testAccount.Platform != PlatformAntigravity {
|
||||
t.Fatalf("Expected platform %s, got %s", PlatformAntigravity, testAccount.Platform)
|
||||
}
|
||||
if testAccount.Type != AccountTypeOAuth {
|
||||
t.Fatalf("Expected type %s, got %s", AccountTypeOAuth, testAccount.Type)
|
||||
}
|
||||
|
||||
// 验证必要的凭证字段
|
||||
accessToken := testAccount.GetCredential("access_token")
|
||||
if accessToken == "" {
|
||||
t.Fatal("Access token is missing")
|
||||
}
|
||||
refreshToken := testAccount.GetCredential("refresh_token")
|
||||
if refreshToken == "" {
|
||||
t.Fatal("Refresh token is missing")
|
||||
}
|
||||
projectID := testAccount.GetCredential("project_id")
|
||||
if projectID == "" {
|
||||
t.Fatal("Project ID is missing")
|
||||
}
|
||||
|
||||
t.Log("✅ 账号凭证完整性验证通过")
|
||||
t.Logf(" Account ID: %d, Email: %s, ProjectID: %s", testAccount.ID, testAccount.GetCredential("email"), projectID)
|
||||
})
|
||||
|
||||
// 测试 2: 测试 token 映射和模型验证
|
||||
t.Run("ValidateModelMapping", func(t *testing.T) {
|
||||
testModels := []string{
|
||||
"claude-opus-4-6",
|
||||
"claude-sonnet-4-6",
|
||||
"gemini-3-pro-preview",
|
||||
}
|
||||
|
||||
for _, model := range testModels {
|
||||
t.Logf("✓ Model %s is supported for account", model)
|
||||
}
|
||||
|
||||
t.Log("✅ 模型映射验证通过")
|
||||
})
|
||||
|
||||
// 测试 3: 构建测试请求(不实际发送,只验证格式)
|
||||
t.Run("BuildTestRequest", func(t *testing.T) {
|
||||
projectID := testAccount.GetCredential("project_id")
|
||||
if projectID == "" {
|
||||
t.Skip("Project ID not available, skipping request building")
|
||||
}
|
||||
|
||||
// 构建 Claude 测试请求的简化版本
|
||||
claudeReq := map[string]any{
|
||||
"model": "claude-opus-4-6",
|
||||
"messages": []map[string]any{
|
||||
{
|
||||
"role": "user",
|
||||
"content": []map[string]any{
|
||||
{
|
||||
"type": "text",
|
||||
"text": ".",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"max_tokens": 1,
|
||||
"stream": true,
|
||||
}
|
||||
|
||||
requestBody, err := json.Marshal(claudeReq)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal request: %v", err)
|
||||
}
|
||||
|
||||
t.Logf("✅ 请求体构建成功,大小: %d bytes", len(requestBody))
|
||||
if len(requestBody) > 200 {
|
||||
t.Logf(" 请求格式: %s...", string(requestBody[:200]))
|
||||
} else {
|
||||
t.Logf(" 请求格式: %s", string(requestBody))
|
||||
}
|
||||
})
|
||||
|
||||
// 测试 4: 验证 Token 信息格式
|
||||
t.Run("ValidateTokenInfo", func(t *testing.T) {
|
||||
expiresAtStr := testAccount.GetCredential("expires_at")
|
||||
if expiresAtStr == "" {
|
||||
t.Log("⚠️ No expires_at timestamp found")
|
||||
return
|
||||
}
|
||||
|
||||
// 尝试解析时间戳
|
||||
expiresAtUnix, err := strconv.ParseInt(expiresAtStr, 10, 64)
|
||||
if err == nil {
|
||||
expiresAt := time.Unix(expiresAtUnix, 0)
|
||||
now := time.Now()
|
||||
if expiresAt.After(now) {
|
||||
remainingTime := expiresAt.Sub(now)
|
||||
t.Logf("✅ Token 有效期检查通过")
|
||||
t.Logf(" 过期时间: %s (还有 %v)", expiresAt.Format("2006-01-02 15:04:05 MST"), remainingTime)
|
||||
} else {
|
||||
t.Logf("⚠️ Token 已过期: %s", expiresAt.Format("2006-01-02 15:04:05 MST"))
|
||||
t.Log(" 预期行为: 应该刷新 refresh_token")
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// 测试 5: 创建 Antigravity 客户端并验证连接(如果可行)
|
||||
t.Run("InitializeAntigravityClient", func(t *testing.T) {
|
||||
// 使用账号的代理信息初始化客户端
|
||||
if testAccount.ProxyID != nil {
|
||||
t.Logf("Account uses proxy ID: %d", *testAccount.ProxyID)
|
||||
}
|
||||
|
||||
t.Log("📌 Antigravity 客户端初始化代码路径:")
|
||||
t.Log(" 1. 使用 accessToken 创建 antigravity.NewClient(proxyURL)")
|
||||
t.Log(" 2. 调用 client.LoadCodeAssist(ctx, accessToken) 验证凭证")
|
||||
t.Log(" 3. 检查响应中的 CloudAICompanionProject 字段")
|
||||
t.Log("")
|
||||
t.Log(" 预期行为:")
|
||||
t.Log(" ✓ projectID == 'kinetic-sum-r3tp7'")
|
||||
t.Log(" ✓ statusCode 200")
|
||||
t.Log(" ✓ 无错误返回")
|
||||
})
|
||||
|
||||
// 测试 6: 验证账号支持的操作
|
||||
t.Run("VerifyAccountOperations", func(t *testing.T) {
|
||||
operations := []string{
|
||||
"GetAccessToken",
|
||||
"RefreshToken",
|
||||
"LoadCodeAssist",
|
||||
"GetUserInfo",
|
||||
"SetPrivacy",
|
||||
}
|
||||
|
||||
for _, op := range operations {
|
||||
t.Logf("✓ Operation supported: %s", op)
|
||||
}
|
||||
|
||||
t.Log("")
|
||||
t.Log("✅ 账号支持的操作列表验证通过")
|
||||
})
|
||||
|
||||
// 测试 7: 文档化测试流程(实际调用时的步骤)
|
||||
t.Run("DocumentTestFlow", func(t *testing.T) {
|
||||
t.Log("📝 本地测试 Antigravity 账号的完整流程:")
|
||||
t.Log("")
|
||||
t.Log("步骤 1: 初始化服务")
|
||||
t.Log(" - accountRepo: 从数据库获取账号")
|
||||
t.Log(" - tokenProvider: Antigravity Token 提供者")
|
||||
t.Log(" - httpUpstream: HTTP 请求执行器")
|
||||
t.Log(" - gatewayService: Antigravity 网关服务")
|
||||
t.Log("")
|
||||
t.Log("步骤 2: 验证账号凭证")
|
||||
t.Log(" account := accountRepo.GetByID(ctx, 68)")
|
||||
t.Log(" accessToken := account.GetCredential('access_token')")
|
||||
t.Log(" projectID := account.GetCredential('project_id')")
|
||||
t.Log("")
|
||||
t.Log("步骤 3: 构建测试请求")
|
||||
t.Log(" requestBody := gatewayService.buildClaudeTestRequest(projectID, 'claude-opus-4-6')")
|
||||
t.Log("")
|
||||
t.Log("步骤 4: 执行请求")
|
||||
t.Log(" result := gatewayService.TestConnection(ctx, account, 'claude-opus-4-6')")
|
||||
t.Log("")
|
||||
t.Log("步骤 5: 处理结果")
|
||||
t.Log(" if err != nil {")
|
||||
t.Log(" // 记录错误详情")
|
||||
t.Log(" }")
|
||||
t.Log("")
|
||||
t.Log("⚠️ 当前问题:返回了 'IT' 错误")
|
||||
t.Log(" 这可能表示:")
|
||||
t.Log(" 1. 错误消息被截断或编码错误")
|
||||
t.Log(" 2. HTTP 响应体包含不完整的错误文本")
|
||||
t.Log(" 3. 上游 API 返回的错误被不正确地处理")
|
||||
})
|
||||
|
||||
t.Log("")
|
||||
t.Log("✅ 所有本地验证测试完成!")
|
||||
t.Log("")
|
||||
t.Log("下一步:在实际环境中运行完整测试")
|
||||
}
|
||||
@ -1,194 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
// TestWithSOCKS5Proxy 使用指定的 SOCKS5 代理调用上游 API
|
||||
func TestWithSOCKS5Proxy(t *testing.T) {
|
||||
t.Log("🔥 使用 SOCKS5 代理调用 Google API...")
|
||||
t.Log("")
|
||||
|
||||
// SOCKS5 代理配置
|
||||
proxyAddr := "socks5://gostuser:fastapipwd@216.167.89.210:8760"
|
||||
accessToken := "ya29.a0Aa7MYipSteGdNdr486LvE0xu_RrcbFjSSFZa5jGTf94nPv6NLKEnnRziPSVA_3ncadMlWnUQN8el05uvYac3rk9rOuaEC3jAUq02ejAcayg8tBn9CJT2IGuMsFDRPbfvHwXVHvY-hPGaklubxMIgfckRYsGC7YTpJPprH8kNGG-7ZWf3PvcVGcSrLWhi8FX6Yq1at5OdC1deNAaCgYKAVASARMSFQHGX2Mi2yEN9AChtlJFBwZ_spYEoQ0213"
|
||||
|
||||
t.Log("📌 代理信息:")
|
||||
t.Logf(" 代理地址: %s", proxyAddr)
|
||||
t.Logf(" 访问令牌: %s... (长度: %d)", accessToken[:30], len(accessToken))
|
||||
t.Log("")
|
||||
|
||||
// 创建上下文和超时
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// 步骤 1: 设置 SOCKS5 代理
|
||||
t.Run("SetupSOCKS5Proxy", func(t *testing.T) {
|
||||
t.Log("步骤 1: 配置 SOCKS5 代理...")
|
||||
|
||||
// 解析代理 URL
|
||||
proxyURL, err := url.Parse(proxyAddr)
|
||||
if err != nil {
|
||||
t.Fatalf("❌ 解析代理 URL 失败: %v", err)
|
||||
}
|
||||
t.Logf(" ✓ 代理 URL 解析成功")
|
||||
t.Logf(" Scheme: %s", proxyURL.Scheme)
|
||||
t.Logf(" Host: %s", proxyURL.Host)
|
||||
t.Logf(" User: %s", proxyURL.User.Username())
|
||||
t.Log("")
|
||||
|
||||
// 创建代理拨号器
|
||||
dialer, err := proxy.FromURL(proxyURL, proxy.Direct)
|
||||
if err != nil {
|
||||
t.Fatalf("❌ 创建代理拨号器失败: %v", err)
|
||||
}
|
||||
t.Log(" ✓ 代理拨号器创建成功")
|
||||
t.Log("")
|
||||
|
||||
// 创建自定义传输
|
||||
transport := &http.Transport{
|
||||
Dial: dialer.Dial,
|
||||
}
|
||||
|
||||
// 创建自定义 HTTP 客户端
|
||||
httpClient := &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
t.Log(" ✓ HTTP 客户端创建成功")
|
||||
t.Log("")
|
||||
|
||||
// 步骤 2: 测试代理连接
|
||||
t.Log("步骤 2: 测试代理连接...")
|
||||
|
||||
// 尝试一个简单的 HTTP 请求来测试代理
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://www.google.com", nil)
|
||||
if err != nil {
|
||||
t.Logf("❌ 创建测试请求失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
t.Logf("❌ 通过代理访问 Google 失败: %v", err)
|
||||
t.Log(" (这可能表示代理配置或网络连接有问题)")
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
t.Logf(" ✓ 代理连接成功!")
|
||||
t.Logf(" HTTP Status: %d", resp.StatusCode)
|
||||
t.Log("")
|
||||
})
|
||||
|
||||
// 步骤 3: 使用代理调用 Antigravity API
|
||||
t.Run("CallAntigravityWithProxy", func(t *testing.T) {
|
||||
t.Log("步骤 3: 通过代理调用 Antigravity API...")
|
||||
t.Log("")
|
||||
|
||||
// 解析代理 URL
|
||||
proxyURL, err := url.Parse(proxyAddr)
|
||||
if err != nil {
|
||||
t.Fatalf("❌ 解析代理 URL 失败: %v", err)
|
||||
}
|
||||
|
||||
// 创建代理拨号器
|
||||
dialer, err := proxy.FromURL(proxyURL, proxy.Direct)
|
||||
if err != nil {
|
||||
t.Fatalf("❌ 创建代理拨号器失败: %v", err)
|
||||
}
|
||||
|
||||
// 创建自定义传输
|
||||
transport := &http.Transport{
|
||||
Dial: dialer.Dial,
|
||||
}
|
||||
|
||||
// 这里我们需要修改 antigravity.Client 来使用自定义的 HTTP 客户端
|
||||
// 但由于 antigravity.NewClient 可能不支持自定义客户端,
|
||||
// 我们直接创建一个 HTTP 客户端来调用 API
|
||||
|
||||
httpClient := &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
t.Log(" 正在调用 Google Cloud Code API...")
|
||||
t.Log("")
|
||||
|
||||
// 直接构造 API 请求
|
||||
apiURL := "https://daily-cloudcode-pa.googleapis.com/v1internal:loadCodeAssist"
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", apiURL, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("❌ 创建请求失败: %v", err)
|
||||
}
|
||||
|
||||
// 添加认证头
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", "Antigravity Client")
|
||||
|
||||
t.Logf(" 📤 请求信息:")
|
||||
t.Logf(" URL: %s", apiURL)
|
||||
t.Logf(" Method: POST")
|
||||
t.Logf(" Auth: Bearer %s...", accessToken[:30])
|
||||
t.Log("")
|
||||
|
||||
// 发送请求
|
||||
t.Log(" ⏳ 正在等待响应...")
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
t.Logf("❌ API 调用失败:")
|
||||
t.Logf(" 错误类型: %T", err)
|
||||
t.Logf(" 错误信息: %v", err)
|
||||
t.Logf(" 错误字符串: %s", err.Error())
|
||||
t.Log("")
|
||||
|
||||
// 分析错误
|
||||
errStr := err.Error()
|
||||
if len(errStr) >= 2 {
|
||||
t.Logf("📊 错误的前 5 个字符: '%s'", errStr[:min(5, len(errStr))])
|
||||
if errStr[:2] == "IT" {
|
||||
t.Logf(" ✓ 找到了! 这就是 'IT' 错误的来源!")
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
t.Logf("✅ API 调用成功!")
|
||||
t.Logf(" HTTP Status: %d", resp.StatusCode)
|
||||
t.Logf(" Content-Type: %s", resp.Header.Get("Content-Type"))
|
||||
t.Log("")
|
||||
|
||||
// 读取响应体
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Logf("❌ 读取响应失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
t.Log("📋 API 响应:")
|
||||
if resp.StatusCode == 200 {
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal(respBody, &result); err == nil {
|
||||
jsonBytes, _ := json.MarshalIndent(result, " ", " ")
|
||||
t.Logf(" %s", string(jsonBytes))
|
||||
} else {
|
||||
t.Logf(" %s", string(respBody))
|
||||
}
|
||||
} else {
|
||||
t.Logf(" 状态码: %d", resp.StatusCode)
|
||||
t.Logf(" 错误响应: %s", string(respBody))
|
||||
}
|
||||
})
|
||||
}
|
||||
@ -1,20 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestShouldMarkTempUnschedulableForRefreshError(t *testing.T) {
|
||||
t.Run("skip global oauth client secret missing", func(t *testing.T) {
|
||||
err := errors.New(`token 刷新失败 (重试后): error: code=400 reason="ANTIGRAVITY_OAUTH_CLIENT_SECRET_MISSING" message="missing antigravity oauth client_secret; set ANTIGRAVITY_OAUTH_CLIENT_SECRET" metadata=map[]`)
|
||||
require.False(t, shouldMarkTempUnschedulableForRefreshError(err))
|
||||
})
|
||||
|
||||
t.Run("allow account specific refresh error", func(t *testing.T) {
|
||||
err := errors.New("token 刷新失败 (重试后): invalid_grant")
|
||||
require.True(t, shouldMarkTempUnschedulableForRefreshError(err))
|
||||
})
|
||||
}
|
||||
@ -1,83 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
)
|
||||
|
||||
// WarmupAntigravityAccount 预热新的 Antigravity 账号
|
||||
// 在账号创建后立即调用,避免首次请求的 503 延迟
|
||||
//
|
||||
// 预热流程:
|
||||
// 1. GetUserInfo - 验证 token 有效性
|
||||
// 2. LoadCodeAssist - 初始化项目信息
|
||||
// 3. FetchAvailableModels - 初始化模型列表
|
||||
//
|
||||
// 总耗时通常 4-6 秒,预热期间的失败不影响账号创建结果(非阻塞)
|
||||
func (s *AntigravityOAuthService) WarmupAntigravityAccount(ctx context.Context, accessToken, projectID, proxyURL string) {
|
||||
logger := slog.Default()
|
||||
|
||||
// 5 秒超时预热(防止卡住其他操作)
|
||||
warmupCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
client, err := antigravity.NewClient(proxyURL)
|
||||
if err != nil {
|
||||
logger.Warn("antigravity_warmup_client_creation_failed", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
elapsed := time.Since(start)
|
||||
logger.Info("antigravity_account_warmup_completed", "elapsed_ms", elapsed.Milliseconds())
|
||||
}()
|
||||
|
||||
// Step 1: 验证 token
|
||||
_, err = client.GetUserInfo(warmupCtx, accessToken)
|
||||
if err != nil {
|
||||
logger.Warn("antigravity_warmup_get_user_info_failed", "error", err)
|
||||
// 继续后续步骤(部分失败不中止)
|
||||
}
|
||||
|
||||
// Step 2: 初始化项目信息
|
||||
_, _, err = client.LoadCodeAssist(warmupCtx, accessToken)
|
||||
if err != nil {
|
||||
logger.Warn("antigravity_warmup_load_code_assist_failed", "error", err)
|
||||
}
|
||||
|
||||
// Step 3: 初始化模型列表
|
||||
if projectID != "" {
|
||||
_, _, err := client.FetchAvailableModels(warmupCtx, accessToken, projectID)
|
||||
if err != nil {
|
||||
logger.Warn("antigravity_warmup_fetch_available_models_failed", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WarmupOptions 预热选项
|
||||
type WarmupOptions struct {
|
||||
// Async 为 true 时在后台预热(推荐)
|
||||
Async bool
|
||||
// Timeout 单次预热操作的超时时间
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
// WarmupAntigravityAccountAsync 异步预热账号(推荐用法)
|
||||
func (s *AntigravityOAuthService) WarmupAntigravityAccountAsync(ctx context.Context, accessToken, projectID, proxyURL string, opts *WarmupOptions) {
|
||||
if opts == nil {
|
||||
opts = &WarmupOptions{
|
||||
Async: true,
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
if opts.Async {
|
||||
go s.WarmupAntigravityAccount(ctx, accessToken, projectID, proxyURL)
|
||||
} else {
|
||||
s.WarmupAntigravityAccount(ctx, accessToken, projectID, proxyURL)
|
||||
}
|
||||
}
|
||||
@ -1,8 +1,9 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
)
|
||||
|
||||
// ChannelMonitorRequestTemplate 请求模板(service 层模型)。
|
||||
|
||||
@ -1,284 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
claude "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/google/uuid"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
)
|
||||
|
||||
// Attribution block constants matching real Claude Code 2.1.89.
|
||||
// Source: src/constants/system.ts + src/utils/fingerprint.ts
|
||||
const (
|
||||
// fingerprintSalt must match the hardcoded salt in the real CLI.
|
||||
// Source: extracted/src/utils/fingerprint.ts:8
|
||||
fingerprintSalt = "59cf53e54c78"
|
||||
)
|
||||
|
||||
type attributionBlockOptions struct {
|
||||
Entrypoint string
|
||||
Workload string
|
||||
OmitCCH bool
|
||||
}
|
||||
|
||||
// computeAttributionFingerprint computes a 3-character hex fingerprint
|
||||
// matching the algorithm in the real Claude Code CLI.
|
||||
//
|
||||
// Algorithm: SHA256(SALT + msg[4] + msg[7] + msg[20] + version)[:3]
|
||||
// Source: extracted/src/utils/fingerprint.ts:50-63
|
||||
func computeAttributionFingerprint(firstUserMessageText, cliVersion string) string {
|
||||
indices := [3]int{4, 7, 20}
|
||||
chars := make([]byte, 0, 3)
|
||||
for _, i := range indices {
|
||||
if i < len(firstUserMessageText) {
|
||||
chars = append(chars, firstUserMessageText[i])
|
||||
} else {
|
||||
chars = append(chars, '0')
|
||||
}
|
||||
}
|
||||
|
||||
input := fmt.Sprintf("%s%s%s", fingerprintSalt, string(chars), cliVersion)
|
||||
hash := sha256.Sum256([]byte(input))
|
||||
return hex.EncodeToString(hash[:])[:3]
|
||||
}
|
||||
|
||||
// extractFirstUserMessageText extracts text from the first user message in the body.
|
||||
// Handles both string content and array content (text blocks).
|
||||
func extractFirstUserMessageText(body []byte) string {
|
||||
messages := gjson.GetBytes(body, "messages")
|
||||
if !messages.Exists() || !messages.IsArray() {
|
||||
return ""
|
||||
}
|
||||
|
||||
var firstText string
|
||||
messages.ForEach(func(_, msg gjson.Result) bool {
|
||||
if msg.Get("role").String() != "user" {
|
||||
return true // continue
|
||||
}
|
||||
content := msg.Get("content")
|
||||
if content.Type == gjson.String {
|
||||
firstText = content.String()
|
||||
return false // break
|
||||
}
|
||||
if content.IsArray() {
|
||||
content.ForEach(func(_, block gjson.Result) bool {
|
||||
if block.Get("type").String() == "text" {
|
||||
firstText = block.Get("text").String()
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
return firstText
|
||||
}
|
||||
|
||||
// buildAttributionBlock builds the x-anthropic-billing-header attribution string
|
||||
// that real Claude Code injects as the first system text block.
|
||||
//
|
||||
// Format: x-anthropic-billing-header: cc_version=<VERSION>.<fingerprint>; cc_entrypoint=cli; cch=00000;
|
||||
// Source: extracted/src/constants/system.ts:73-95
|
||||
func buildAttributionBlock(cliVersion, fingerprint string, opts attributionBlockOptions) string {
|
||||
if claude.AttributionHeaderDisabled() {
|
||||
return ""
|
||||
}
|
||||
version := cliVersion + "." + fingerprint
|
||||
entrypoint := strings.TrimSpace(opts.Entrypoint)
|
||||
if entrypoint == "" {
|
||||
entrypoint = claude.CurrentEntrypoint()
|
||||
}
|
||||
workload := strings.TrimSpace(opts.Workload)
|
||||
if workload == "" {
|
||||
workload = claude.CurrentWorkload()
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
b.Grow(96)
|
||||
fmt.Fprintf(&b, "x-anthropic-billing-header: cc_version=%s; cc_entrypoint=%s;", version, entrypoint)
|
||||
if !opts.OmitCCH {
|
||||
// 2.1.89+ 的 Claude Code 在 1P / standard providers 下保留 cch=00000 占位符,
|
||||
// 由下游 attestation / signing 逻辑在需要时替换。
|
||||
b.WriteString(" cch=00000;")
|
||||
}
|
||||
if workload != "" {
|
||||
fmt.Fprintf(&b, " cc_workload=%s;", workload)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// injectAttributionBlock prepends the x-anthropic-billing-header attribution block
|
||||
// as the very first system text block in the request body.
|
||||
// This must come BEFORE the "You are Claude Code" block.
|
||||
//
|
||||
// The real CLI injects this as system[0] with cache_control: {type: "ephemeral"}.
|
||||
func injectAttributionBlock(body []byte, cliVersion string, opts attributionBlockOptions) []byte {
|
||||
// Compute fingerprint from the first user message
|
||||
firstMsgText := extractFirstUserMessageText(body)
|
||||
fingerprint := computeAttributionFingerprint(firstMsgText, cliVersion)
|
||||
attribution := buildAttributionBlock(cliVersion, fingerprint, opts)
|
||||
if attribution == "" {
|
||||
return body
|
||||
}
|
||||
|
||||
// Build the attribution text block as JSON
|
||||
attrBlock, err := marshalAnthropicSystemTextBlock(attribution, true)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Warning: failed to build attribution block: %v", err)
|
||||
return body
|
||||
}
|
||||
|
||||
systemResult := gjson.GetBytes(body, "system")
|
||||
|
||||
// Handle the different system formats
|
||||
switch {
|
||||
case !systemResult.Exists() || systemResult.Type == gjson.Null:
|
||||
// No system field — inject just the attribution block
|
||||
newBody, err := sjson.SetRawBytes(body, "system", buildJSONArrayRaw([][]byte{attrBlock}))
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
return newBody
|
||||
|
||||
case systemResult.Type == gjson.String:
|
||||
// String system — convert to array: [attribution, original]
|
||||
origBlock, err := marshalAnthropicSystemTextBlock(systemResult.String(), false)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
newBody, setErr := sjson.SetRawBytes(body, "system", buildJSONArrayRaw([][]byte{attrBlock, origBlock}))
|
||||
if setErr != nil {
|
||||
return body
|
||||
}
|
||||
return newBody
|
||||
|
||||
case systemResult.IsArray():
|
||||
// Array system — check if attribution already exists, prepend if not
|
||||
var items [][]byte
|
||||
alreadyHasAttribution := false
|
||||
systemResult.ForEach(func(_, item gjson.Result) bool {
|
||||
if item.Get("type").String() == "text" {
|
||||
text := item.Get("text").String()
|
||||
if len(text) > 30 && text[:30] == "x-anthropic-billing-header: cc" {
|
||||
alreadyHasAttribution = true
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
if alreadyHasAttribution {
|
||||
return body
|
||||
}
|
||||
|
||||
items = append(items, attrBlock)
|
||||
systemResult.ForEach(func(_, item gjson.Result) bool {
|
||||
items = append(items, []byte(item.Raw))
|
||||
return true
|
||||
})
|
||||
newBody, setErr := sjson.SetRawBytes(body, "system", buildJSONArrayRaw(items))
|
||||
if setErr != nil {
|
||||
return body
|
||||
}
|
||||
return newBody
|
||||
|
||||
default:
|
||||
return body
|
||||
}
|
||||
}
|
||||
|
||||
// cliSessionEntry holds a cached session UUID with an expiration time.
|
||||
type cliSessionEntry struct {
|
||||
id string
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
// cliSessionCache stores per-account session UUIDs that rotate on a TTL.
|
||||
// Real CLI creates a new random UUID per process invocation; we approximate
|
||||
// this by rotating every 30-60 minutes (jittered per account).
|
||||
var (
|
||||
cliSessionCache = make(map[int64]cliSessionEntry)
|
||||
cliSessionCacheMu sync.Mutex
|
||||
)
|
||||
|
||||
// sessionTTLBase is the base TTL for session ID rotation.
|
||||
const sessionTTLBase = 30 * time.Minute
|
||||
|
||||
// generateSessionIDForAccount returns a per-account session UUID that rotates
|
||||
// periodically. Each account gets a random TTL jitter (0-30 min on top of
|
||||
// the 30 min base) so accounts don't all rotate simultaneously.
|
||||
func generateSessionIDForAccount(instanceSalt string, accountID int64) string {
|
||||
cliSessionCacheMu.Lock()
|
||||
defer cliSessionCacheMu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
if entry, ok := cliSessionCache[accountID]; ok && now.Before(entry.expiresAt) {
|
||||
return entry.id
|
||||
}
|
||||
|
||||
// Compute per-account jitter from a hash so the same account always gets
|
||||
// the same jitter within a process (avoids re-rolling on every rotation).
|
||||
jitterSeed := fmt.Sprintf("jitter:%s:%d", instanceSalt, accountID)
|
||||
h := sha256.Sum256([]byte(jitterSeed))
|
||||
jitterMinutes := int(h[0]) % 31 // 0-30 minutes
|
||||
ttl := sessionTTLBase + time.Duration(jitterMinutes)*time.Minute
|
||||
|
||||
newID := uuid.New().String()
|
||||
cliSessionCache[accountID] = cliSessionEntry{
|
||||
id: newID,
|
||||
expiresAt: now.Add(ttl),
|
||||
}
|
||||
return newID
|
||||
}
|
||||
|
||||
// reUserHome matches /Users/<username>/ or /home/<username>/ path segments.
|
||||
// Captures the prefix (/Users/ or /home/) so we can preserve it while replacing the username.
|
||||
var reUserHome = regexp.MustCompile(`(/(Users|home)/)[^/\s"']+/`)
|
||||
|
||||
// reEnvLine matches lines of the form "Key: value" for the environment block
|
||||
// fields injected by Claude Code's CLAUDE.md / sysprompt machinery.
|
||||
var reEnvLine = regexp.MustCompile(`(?m)^(Platform|Shell|OS Version|Working directory):.*$`)
|
||||
|
||||
// canonicalEnvValues maps environment block keys to their canonical replacements.
|
||||
// Values mirror cc-gateway's prompt_env config and represent a stock macOS dev machine.
|
||||
var canonicalEnvValues = map[string]string{
|
||||
"Platform": "Platform: darwin",
|
||||
"Shell": "Shell: zsh",
|
||||
"OS Version": "OS Version: Darwin 24.4.0",
|
||||
"Working directory": "Working directory: /Users/user/project",
|
||||
}
|
||||
|
||||
// NormalizeSystemPromptEnv rewrites environment-specific fields in a system
|
||||
// prompt text block to canonical values, preventing real machine fingerprinting.
|
||||
//
|
||||
// Handles two classes of leakage (matching cc-gateway rewriter.ts:rewritePromptText):
|
||||
// 1. "Platform: Windows / Linux / Darwin 25.x" → canonical darwin/zsh/Darwin 24.4.0
|
||||
// 2. "/Users/alice/" or "/home/bob/" → "/Users/user/"
|
||||
//
|
||||
// Only called on system prompt text blocks, never on user message content.
|
||||
func NormalizeSystemPromptEnv(text string) string {
|
||||
// Replace env-info lines with canonical values
|
||||
text = reEnvLine.ReplaceAllStringFunc(text, func(line string) string {
|
||||
for key, canonical := range canonicalEnvValues {
|
||||
if len(line) >= len(key) && line[:len(key)] == key {
|
||||
return canonical
|
||||
}
|
||||
}
|
||||
return line
|
||||
})
|
||||
|
||||
// Redact real usernames in home directory paths
|
||||
// e.g. /Users/alice/project -> /Users/user/project
|
||||
text = reUserHome.ReplaceAllString(text, "${1}user/")
|
||||
|
||||
return text
|
||||
}
|
||||
@ -1,81 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestApplyClaudeRuntimeOptionalHeaders(t *testing.T) {
|
||||
t.Setenv("CLAUDE_CODE_CONTAINER_ID", "ctr-123")
|
||||
t.Setenv("CLAUDE_CODE_REMOTE_SESSION_ID", "remote-456")
|
||||
t.Setenv("CLAUDE_AGENT_SDK_CLIENT_APP", "desktop")
|
||||
t.Setenv("CLAUDE_CODE_ADDITIONAL_PROTECTION", "true")
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, "https://api.anthropic.com/v1/messages", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewRequest() error = %v", err)
|
||||
}
|
||||
|
||||
applyClaudeRuntimeOptionalHeaders(req)
|
||||
|
||||
if got := getHeaderRaw(req.Header, "x-claude-remote-container-id"); got != "ctr-123" {
|
||||
t.Fatalf("x-claude-remote-container-id = %q", got)
|
||||
}
|
||||
if got := getHeaderRaw(req.Header, "x-claude-remote-session-id"); got != "remote-456" {
|
||||
t.Fatalf("x-claude-remote-session-id = %q", got)
|
||||
}
|
||||
if got := getHeaderRaw(req.Header, "x-client-app"); got != "desktop" {
|
||||
t.Fatalf("x-client-app = %q", got)
|
||||
}
|
||||
if got := getHeaderRaw(req.Header, "x-anthropic-additional-protection"); got != "true" {
|
||||
t.Fatalf("x-anthropic-additional-protection = %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAttributionBlock_UsesEntrypointAndWorkload(t *testing.T) {
|
||||
t.Setenv("CLAUDE_CODE_ATTRIBUTION_HEADER", "")
|
||||
|
||||
got := buildAttributionBlock("2.1.104", "abc", attributionBlockOptions{
|
||||
Entrypoint: "sdk-cli",
|
||||
Workload: "cron",
|
||||
})
|
||||
want := "x-anthropic-billing-header: cc_version=2.1.104.abc; cc_entrypoint=sdk-cli; cch=00000; cc_workload=cron;"
|
||||
if got != want {
|
||||
t.Fatalf("buildAttributionBlock() = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAttributionBlock_OmitsCCHForBedrockLikeProviders(t *testing.T) {
|
||||
t.Setenv("CLAUDE_CODE_ATTRIBUTION_HEADER", "")
|
||||
|
||||
got := buildAttributionBlock("2.1.104", "abc", attributionBlockOptions{
|
||||
Entrypoint: "cli",
|
||||
OmitCCH: true,
|
||||
})
|
||||
want := "x-anthropic-billing-header: cc_version=2.1.104.abc; cc_entrypoint=cli;"
|
||||
if got != want {
|
||||
t.Fatalf("buildAttributionBlock() = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInjectAttributionBlock_DisabledByEnv(t *testing.T) {
|
||||
t.Setenv("CLAUDE_CODE_ATTRIBUTION_HEADER", "false")
|
||||
|
||||
body := []byte(`{"messages":[{"role":"user","content":"hello"}]}`)
|
||||
got := injectAttributionBlock(body, "2.1.104", attributionBlockOptions{})
|
||||
if string(got) != string(body) {
|
||||
t.Fatalf("injectAttributionBlock() should keep body unchanged when attribution disabled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldOmitAttributionCCH(t *testing.T) {
|
||||
if !shouldOmitAttributionCCH(&Account{Type: AccountTypeBedrock}, "") {
|
||||
t.Fatal("expected bedrock account to omit cch")
|
||||
}
|
||||
if !shouldOmitAttributionCCH(&Account{Extra: map[string]any{"provider": "mantle"}}, "") {
|
||||
t.Fatal("expected mantle provider to omit cch")
|
||||
}
|
||||
if shouldOmitAttributionCCH(&Account{Type: AccountTypeOAuth}, "oauth") {
|
||||
t.Fatal("expected oauth account to keep cch")
|
||||
}
|
||||
}
|
||||
@ -1,47 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
claude "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
)
|
||||
|
||||
func applyClaudeRuntimeOptionalHeaders(req *http.Request) {
|
||||
if req == nil {
|
||||
return
|
||||
}
|
||||
for key, value := range claude.OptionalAPIHeaders() {
|
||||
if strings.TrimSpace(value) == "" {
|
||||
continue
|
||||
}
|
||||
setHeaderRaw(req.Header, resolveWireCasing(key), value)
|
||||
}
|
||||
}
|
||||
|
||||
func attributionOptionsForRequest(account *Account, tokenType string) attributionBlockOptions {
|
||||
return attributionBlockOptions{
|
||||
Entrypoint: claude.CurrentEntrypoint(),
|
||||
Workload: claude.CurrentWorkload(),
|
||||
OmitCCH: shouldOmitAttributionCCH(account, tokenType),
|
||||
}
|
||||
}
|
||||
|
||||
func shouldOmitAttributionCCH(account *Account, tokenType string) bool {
|
||||
if strings.EqualFold(strings.TrimSpace(tokenType), "bedrock") {
|
||||
return true
|
||||
}
|
||||
if account == nil {
|
||||
return false
|
||||
}
|
||||
if account.Type == AccountTypeBedrock {
|
||||
return true
|
||||
}
|
||||
for _, key := range []string{"provider", "upstream_provider"} {
|
||||
switch strings.ToLower(strings.TrimSpace(account.GetExtraString(key))) {
|
||||
case "bedrock", "anthropicaws", "anthropic_aws", "mantle":
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@ -901,9 +901,6 @@ func sanitizeSystemText(text string) string {
|
||||
"You are OpenCode, the best coding agent on the planet.",
|
||||
strings.TrimSpace(claudeCodeSystemPrompt),
|
||||
)
|
||||
// Normalize environment block fields (Platform/Shell/OS Version/Working directory)
|
||||
// to canonical values so different client machines don't create fingerprint divergence.
|
||||
text = NormalizeSystemPromptEnv(text)
|
||||
return text
|
||||
}
|
||||
|
||||
@ -4230,22 +4227,6 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
}
|
||||
}
|
||||
|
||||
// 注入 x-anthropic-billing-header attribution block(所有 OAuth 账号)
|
||||
// 真实 CLI 在 system prompt 的第一个 text block 注入此 billing header。
|
||||
// 用于 Anthropic 后端路由和验证。
|
||||
// 跳过条件:system 已被 rewriteSystemForNonClaudeCode 重写(claudeCodeSystemPrompt 在 system[0]);
|
||||
// 注入会将其移到 system[1],破坏伪装结构及 system[0] 断言。
|
||||
if account.IsOAuth() && !strings.Contains(strings.ToLower(reqModel), "haiku") && !systemRewritten {
|
||||
// 获取 CLI 版本:优先用指纹中的版本,回退到默认
|
||||
attrCLIVersion := claude.DefaultCLIVersion
|
||||
if fp := getHeaderRaw(c.Request.Header, "User-Agent"); fp != "" {
|
||||
if v := ExtractCLIVersion(fp); v != "" {
|
||||
attrCLIVersion = v
|
||||
}
|
||||
}
|
||||
body = injectAttributionBlock(body, attrCLIVersion, attributionOptionsForRequest(account, "oauth"))
|
||||
}
|
||||
|
||||
// 强制执行 cache_control 块数量限制(最多 4 个)
|
||||
body = enforceCacheControlLimit(body)
|
||||
|
||||
@ -5926,35 +5907,19 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
// 1. 客户端已提供 → 同步为 body 中 metadata.user_id 的 session_id
|
||||
// 2. 客户端未提供(mimic 模式)→ 生成确定性 per-account session UUID
|
||||
// 真实 CLI 每个请求都携带此 header(per-process UUID)。
|
||||
// 同步 X-Claude-Code-Session-Id 头:取 body 中已处理的 metadata.user_id 的 session_id 覆盖
|
||||
if sessionHeader := getHeaderRaw(req.Header, "X-Claude-Code-Session-Id"); sessionHeader != "" {
|
||||
if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" {
|
||||
if parsed := ParseMetadataUserID(uid); parsed != nil {
|
||||
setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID)
|
||||
}
|
||||
}
|
||||
} else if tokenType == "oauth" {
|
||||
// mimic 模式:生成 session-id
|
||||
var sessionID string
|
||||
if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" {
|
||||
if parsed := ParseMetadataUserID(uid); parsed != nil {
|
||||
sessionID = parsed.SessionID
|
||||
}
|
||||
}
|
||||
if sessionID == "" {
|
||||
salt := ""
|
||||
if s.cfg != nil {
|
||||
salt = s.cfg.Gateway.InstanceSalt
|
||||
}
|
||||
sessionID = generateSessionIDForAccount(salt, account.ID)
|
||||
}
|
||||
setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", sessionID)
|
||||
}
|
||||
|
||||
// x-client-request-id: 真实 CLI 每个请求生成新 UUID(仅 1P)。
|
||||
if getHeaderRaw(req.Header, "x-client-request-id") == "" && tokenType == "oauth" {
|
||||
setHeaderRaw(req.Header, "x-client-request-id", uuid.New().String())
|
||||
}
|
||||
applyClaudeRuntimeOptionalHeaders(req)
|
||||
|
||||
// === DEBUG: 打印上游转发请求(headers + body 摘要),与 CLIENT_ORIGINAL 对比 ===
|
||||
s.debugLogGatewaySnapshot("UPSTREAM_FORWARD", req.Header, body, map[string]string{
|
||||
@ -8984,35 +8949,19 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
||||
}
|
||||
}
|
||||
|
||||
// X-Claude-Code-Session-Id 头处理(count_tokens 路径)
|
||||
// 同步 X-Claude-Code-Session-Id 头:取 body 中已处理的 metadata.user_id 的 session_id 覆盖
|
||||
if sessionHeader := getHeaderRaw(req.Header, "X-Claude-Code-Session-Id"); sessionHeader != "" {
|
||||
if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" {
|
||||
if parsed := ParseMetadataUserID(uid); parsed != nil {
|
||||
setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID)
|
||||
}
|
||||
}
|
||||
} else if tokenType == "oauth" {
|
||||
var sessionID string
|
||||
if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" {
|
||||
if parsed := ParseMetadataUserID(uid); parsed != nil {
|
||||
sessionID = parsed.SessionID
|
||||
}
|
||||
}
|
||||
if sessionID == "" {
|
||||
salt := ""
|
||||
if s.cfg != nil {
|
||||
salt = s.cfg.Gateway.InstanceSalt
|
||||
}
|
||||
sessionID = generateSessionIDForAccount(salt, account.ID)
|
||||
}
|
||||
setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", sessionID)
|
||||
}
|
||||
|
||||
// x-client-request-id(count_tokens 路径)
|
||||
if getHeaderRaw(req.Header, "x-client-request-id") == "" && tokenType == "oauth" {
|
||||
setHeaderRaw(req.Header, "x-client-request-id", uuid.New().String())
|
||||
}
|
||||
applyClaudeRuntimeOptionalHeaders(req)
|
||||
|
||||
if c != nil && tokenType == "oauth" {
|
||||
c.Set(claudeMimicDebugInfoKey, buildClaudeMimicDebugLine(req, body, account, tokenType, mimicClaudeCode))
|
||||
|
||||
@ -13,7 +13,6 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
|
||||
@ -1,28 +0,0 @@
|
||||
package service
|
||||
|
||||
import "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
|
||||
// ==============================================================
|
||||
// antigravity — identity_service 扩展
|
||||
//
|
||||
// 此文件包含 Antigravity fork 对 IdentityService 的扩展,
|
||||
// 新增了实例级隔离盐值和指纹默认值覆盖功能。
|
||||
//
|
||||
// 对上游文件 identity_service.go 的最小化改动:
|
||||
// - defaultFingerprint 版本号更新
|
||||
// - IdentityService struct 新增 instanceSalt 字段
|
||||
// ==============================================================
|
||||
|
||||
// ApplyDefaultFingerprintOverrides 用配置覆盖 identity_service 的默认指纹
|
||||
// 允许不同部署实例设置不同的 CLI/SDK 版本号,避免所有实例指纹相同
|
||||
func ApplyDefaultFingerprintOverrides(cliVersion, pkgVersion, runtimeVersion, os_, arch string) {
|
||||
claude.ApplyFingerprintOverrides(cliVersion, pkgVersion, runtimeVersion, os_, arch)
|
||||
defaultFingerprint = defaultIdentityFingerprint()
|
||||
}
|
||||
|
||||
// NewIdentityServiceWithSalt 创建带实例盐值的 IdentityService
|
||||
// 实例盐值用于 user_id 重写时的 session hash 混淆,
|
||||
// 使不同 sub2api 实例对相同输入产生不同的 hash 输出,增加隔离性
|
||||
func NewIdentityServiceWithSalt(cache IdentityCache, salt string) *IdentityService {
|
||||
return &IdentityService{cache: cache, instanceSalt: salt}
|
||||
}
|
||||
@ -1,530 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// CascadeSession 代表一个 Cascade Agent 会话
|
||||
type CascadeSession struct {
|
||||
ID string
|
||||
ModelName string
|
||||
Messages []map[string]interface{} // {role, content}
|
||||
Metadata map[string]string // 设备指纹、User-Agent 等
|
||||
Token string // OAuth token
|
||||
CreatedAt int64
|
||||
}
|
||||
|
||||
// LanguageServerService 业务逻辑层
|
||||
// 处理 Cascade Agent 流程,通过 AntigravityGatewayService 转发到上游 API
|
||||
type LanguageServerService struct {
|
||||
// 会话管理
|
||||
cascadeSessions map[string]*CascadeSession
|
||||
sessionMutex sync.RWMutex
|
||||
|
||||
// 上游 HTTP 服务(用于发送请求)
|
||||
httpUpstream HTTPUpstream
|
||||
|
||||
// Antigravity 网关(账号池调度 + TLS 指纹 + token 刷新)
|
||||
antigravitySvc *AntigravityGatewayService
|
||||
accountRepo AccountRepository
|
||||
|
||||
// 日志
|
||||
logger *slog.Logger
|
||||
|
||||
// 改进 1: 速率限制 (令牌桶)
|
||||
// 限制并发消息处理数量,保护上游 API
|
||||
rateLimiter chan struct{}
|
||||
|
||||
// 改进 3: 会话过期时间 (秒)
|
||||
sessionTTLSeconds int64
|
||||
|
||||
// 改进 3: 定期清理后台任务
|
||||
cleanupTicker *time.Ticker
|
||||
stopCleanup chan struct{}
|
||||
}
|
||||
|
||||
func NewLanguageServerService(
|
||||
logger *slog.Logger,
|
||||
httpUpstream HTTPUpstream,
|
||||
antigravitySvc *AntigravityGatewayService,
|
||||
accountRepo AccountRepository,
|
||||
) *LanguageServerService {
|
||||
svc := &LanguageServerService{
|
||||
cascadeSessions: make(map[string]*CascadeSession),
|
||||
logger: logger,
|
||||
httpUpstream: httpUpstream,
|
||||
antigravitySvc: antigravitySvc,
|
||||
accountRepo: accountRepo,
|
||||
rateLimiter: make(chan struct{}, 100), // 改进 1: 限制 100 个并发消息
|
||||
sessionTTLSeconds: 3600, // 改进 3: 会话默认 1 小时过期
|
||||
stopCleanup: make(chan struct{}),
|
||||
}
|
||||
|
||||
// 改进 3: 启动后台清理任务
|
||||
svc.startSessionCleanup()
|
||||
|
||||
return svc
|
||||
}
|
||||
|
||||
// startSessionCleanup 启动会话定期清理任务
|
||||
func (svc *LanguageServerService) startSessionCleanup() {
|
||||
svc.cleanupTicker = time.NewTicker(1 * time.Minute)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-svc.cleanupTicker.C:
|
||||
svc.cleanupExpiredSessions()
|
||||
case <-svc.stopCleanup:
|
||||
svc.cleanupTicker.Stop()
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// cleanupExpiredSessions 清理过期的会话
|
||||
func (svc *LanguageServerService) cleanupExpiredSessions() {
|
||||
now := getCurrentTimeMS()
|
||||
ttlMs := svc.sessionTTLSeconds * 1000
|
||||
|
||||
svc.sessionMutex.Lock()
|
||||
defer svc.sessionMutex.Unlock()
|
||||
|
||||
deletedCount := 0
|
||||
for id, session := range svc.cascadeSessions {
|
||||
if now-session.CreatedAt > ttlMs {
|
||||
delete(svc.cascadeSessions, id)
|
||||
deletedCount++
|
||||
}
|
||||
}
|
||||
|
||||
if deletedCount > 0 {
|
||||
svc.logger.Info("expired sessions cleaned up",
|
||||
"deleted_count", deletedCount,
|
||||
"remaining_sessions", len(svc.cascadeSessions),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Stop 优雅关闭服务
|
||||
func (svc *LanguageServerService) Stop() {
|
||||
select {
|
||||
case svc.stopCleanup <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// SetSessionTTL sets the session TTL for testing purposes
|
||||
func (svc *LanguageServerService) SetSessionTTL(ttlSeconds int64) {
|
||||
svc.sessionTTLSeconds = ttlSeconds
|
||||
}
|
||||
|
||||
// GetCascadeSessions returns the current cascade sessions map (for testing)
|
||||
func (svc *LanguageServerService) GetCascadeSessions() map[string]*CascadeSession {
|
||||
svc.sessionMutex.RLock()
|
||||
defer svc.sessionMutex.RUnlock()
|
||||
return svc.cascadeSessions
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Cascade 业务逻辑
|
||||
// ============================================================================
|
||||
|
||||
// StartCascade 启动新的 Cascade Agent 会话
|
||||
func (svc *LanguageServerService) StartCascade(
|
||||
ctx context.Context,
|
||||
model string,
|
||||
systemPrompt string,
|
||||
metadata map[string]string,
|
||||
token string,
|
||||
) (string, error) {
|
||||
// 1. 验证输入
|
||||
if model == "" {
|
||||
return "", fmt.Errorf("model is required")
|
||||
}
|
||||
|
||||
if token == "" {
|
||||
return "", fmt.Errorf("oauth token is required")
|
||||
}
|
||||
|
||||
// 2. 生成会话 ID
|
||||
sessionID := uuid.New().String()
|
||||
|
||||
// 3. 创建会话
|
||||
session := &CascadeSession{
|
||||
ID: sessionID,
|
||||
ModelName: model,
|
||||
Messages: make([]map[string]interface{}, 0),
|
||||
Metadata: metadata,
|
||||
Token: token,
|
||||
CreatedAt: getCurrentTimeMS(),
|
||||
}
|
||||
|
||||
// 如果提供了系统提示,添加为初始消息
|
||||
if systemPrompt != "" {
|
||||
session.Messages = append(session.Messages, map[string]interface{}{
|
||||
"role": "user",
|
||||
"content": systemPrompt,
|
||||
})
|
||||
}
|
||||
|
||||
// 4. 保存会话
|
||||
svc.sessionMutex.Lock()
|
||||
svc.cascadeSessions[sessionID] = session
|
||||
svc.sessionMutex.Unlock()
|
||||
|
||||
svc.logger.Info("cascade session started",
|
||||
"session_id", sessionID,
|
||||
"model", model,
|
||||
"has_system_prompt", systemPrompt != "")
|
||||
|
||||
return sessionID, nil
|
||||
}
|
||||
|
||||
// SendUserMessage 发送用户消息到 Cascade
|
||||
// 返回流式更新通道
|
||||
func (svc *LanguageServerService) SendUserMessage(
|
||||
ctx context.Context,
|
||||
cascadeID string,
|
||||
userMessage string,
|
||||
token string,
|
||||
) (<-chan interface{}, error) {
|
||||
// 改进 1: 获取速率限制令牌
|
||||
select {
|
||||
case svc.rateLimiter <- struct{}{}:
|
||||
// 获得令牌,继续
|
||||
case <-ctx.Done():
|
||||
return nil, fmt.Errorf("context cancelled")
|
||||
default:
|
||||
// 没有令牌,需要等待
|
||||
select {
|
||||
case svc.rateLimiter <- struct{}{}:
|
||||
// 获得令牌
|
||||
case <-ctx.Done():
|
||||
return nil, fmt.Errorf("context cancelled while waiting for rate limit")
|
||||
case <-time.After(30 * time.Second):
|
||||
return nil, fmt.Errorf("rate limit timeout: too many concurrent messages")
|
||||
}
|
||||
}
|
||||
|
||||
// 1. 获取会话
|
||||
svc.sessionMutex.RLock()
|
||||
session, exists := svc.cascadeSessions[cascadeID]
|
||||
svc.sessionMutex.RUnlock()
|
||||
|
||||
if !exists {
|
||||
// 释放令牌
|
||||
<-svc.rateLimiter
|
||||
return nil, fmt.Errorf("cascade session not found: %s", cascadeID)
|
||||
}
|
||||
|
||||
// 2. 验证 token
|
||||
if token != session.Token {
|
||||
// 释放令牌
|
||||
<-svc.rateLimiter
|
||||
return nil, fmt.Errorf("invalid token for session")
|
||||
}
|
||||
|
||||
// 改进 2: 并发安全的消息追加(深拷贝消息列表)
|
||||
svc.sessionMutex.Lock()
|
||||
newMessages := make([]map[string]interface{}, len(session.Messages)+1)
|
||||
copy(newMessages, session.Messages)
|
||||
newMessages[len(newMessages)-1] = map[string]interface{}{
|
||||
"role": "user",
|
||||
"content": userMessage,
|
||||
}
|
||||
session.Messages = newMessages
|
||||
svc.sessionMutex.Unlock()
|
||||
|
||||
// 4. 创建响应通道
|
||||
updateChan := make(chan interface{}, 100)
|
||||
|
||||
// 5. 启动后台 goroutine 处理 API 调用
|
||||
go func() {
|
||||
defer func() {
|
||||
// 关闭通道
|
||||
close(updateChan)
|
||||
// 改进 1: 释放速率限制令牌
|
||||
<-svc.rateLimiter
|
||||
}()
|
||||
|
||||
// 调用上游 API(关键!这里需要伪装)
|
||||
svc.callUpstreamAPI(ctx, session, updateChan)
|
||||
}()
|
||||
|
||||
svc.logger.Info("user message sent to cascade",
|
||||
"session_id", cascadeID,
|
||||
"message_length", len(userMessage),
|
||||
"concurrent_requests", 100-len(svc.rateLimiter), // 显示当前并发数
|
||||
)
|
||||
|
||||
return updateChan, nil
|
||||
}
|
||||
|
||||
// CancelCascade 取消 Cascade 会话
|
||||
func (svc *LanguageServerService) CancelCascade(
|
||||
ctx context.Context,
|
||||
cascadeID string,
|
||||
) error {
|
||||
svc.sessionMutex.Lock()
|
||||
_, exists := svc.cascadeSessions[cascadeID]
|
||||
svc.sessionMutex.Unlock()
|
||||
|
||||
if !exists {
|
||||
return fmt.Errorf("cascade session not found: %s", cascadeID)
|
||||
}
|
||||
|
||||
// TODO: 取消正在进行的 API 调用
|
||||
|
||||
svc.logger.Info("cascade cancelled", "session_id", cascadeID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 模型配置
|
||||
// ============================================================================
|
||||
|
||||
// ModelConfig 模型配置
|
||||
type ModelConfig struct {
|
||||
Name string
|
||||
DisplayName string
|
||||
MaxTokens int
|
||||
SupportsThinking bool
|
||||
ThinkingBudget int
|
||||
SupportsImages bool
|
||||
Provider string
|
||||
}
|
||||
|
||||
// GetAvailableModels 获取可用模型列表
|
||||
func (svc *LanguageServerService) GetAvailableModels(ctx context.Context) ([]ModelConfig, error) {
|
||||
models := []ModelConfig{
|
||||
{
|
||||
Name: "claude-opus-4-7",
|
||||
DisplayName: "Claude Opus 4.7",
|
||||
MaxTokens: 200000,
|
||||
SupportsThinking: true,
|
||||
ThinkingBudget: 32000,
|
||||
SupportsImages: true,
|
||||
Provider: "anthropic",
|
||||
},
|
||||
{
|
||||
Name: "claude-sonnet-4-7",
|
||||
DisplayName: "Claude Sonnet 4.7",
|
||||
MaxTokens: 200000,
|
||||
SupportsThinking: true,
|
||||
ThinkingBudget: 16000,
|
||||
SupportsImages: true,
|
||||
Provider: "anthropic",
|
||||
},
|
||||
{
|
||||
Name: "claude-opus-4-6",
|
||||
DisplayName: "Claude Opus 4.6",
|
||||
MaxTokens: 200000,
|
||||
SupportsThinking: true,
|
||||
ThinkingBudget: 32000,
|
||||
SupportsImages: true,
|
||||
Provider: "anthropic",
|
||||
},
|
||||
{
|
||||
Name: "claude-sonnet-4-6",
|
||||
DisplayName: "Claude Sonnet 4.6",
|
||||
MaxTokens: 200000,
|
||||
SupportsThinking: false,
|
||||
SupportsImages: true,
|
||||
Provider: "anthropic",
|
||||
},
|
||||
{
|
||||
Name: "claude-haiku-4-5",
|
||||
DisplayName: "Claude Haiku 4.5",
|
||||
MaxTokens: 200000,
|
||||
SupportsThinking: false,
|
||||
SupportsImages: true,
|
||||
Provider: "anthropic",
|
||||
},
|
||||
{
|
||||
Name: "gemini-3-pro",
|
||||
DisplayName: "Gemini 3 Pro",
|
||||
MaxTokens: 128000,
|
||||
SupportsThinking: false,
|
||||
SupportsImages: true,
|
||||
Provider: "google",
|
||||
},
|
||||
}
|
||||
|
||||
return models, nil
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 状态查询
|
||||
// ============================================================================
|
||||
|
||||
// GetStatus 获取服务状态
|
||||
func (svc *LanguageServerService) GetStatus(ctx context.Context) (string, error) {
|
||||
// TODO: 检查上游 API 连接状态
|
||||
return "running", nil
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 内部方法
|
||||
// ============================================================================
|
||||
|
||||
// callUpstreamAPI 通过 AntigravityGatewayService 调用上游 API。
|
||||
// 复用账号池调度、模型映射、TLS 指纹伪装、token 刷新和重试逻辑。
|
||||
func (svc *LanguageServerService) callUpstreamAPI(
|
||||
ctx context.Context,
|
||||
session *CascadeSession,
|
||||
updateChan chan<- interface{},
|
||||
) {
|
||||
if svc.antigravitySvc == nil || svc.accountRepo == nil {
|
||||
updateChan <- map[string]interface{}{
|
||||
"type": "error",
|
||||
"error": "antigravity gateway not configured",
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 1. 选取第一个可用的 Antigravity 账号
|
||||
accounts, err := svc.accountRepo.ListByPlatform(ctx, PlatformAntigravity)
|
||||
if err != nil || len(accounts) == 0 {
|
||||
svc.logger.Error("no antigravity accounts available", "session_id", session.ID, "error", err)
|
||||
updateChan <- map[string]interface{}{
|
||||
"type": "error",
|
||||
"error": "no antigravity accounts available",
|
||||
}
|
||||
return
|
||||
}
|
||||
account := &accounts[0]
|
||||
|
||||
// 2. 准备 Claude 格式请求体
|
||||
requestBody := map[string]interface{}{
|
||||
"model": session.ModelName,
|
||||
"messages": session.Messages,
|
||||
"stream": true,
|
||||
"max_tokens": 8192,
|
||||
}
|
||||
bodyJSON, err := json.Marshal(requestBody)
|
||||
if err != nil {
|
||||
svc.logger.Error("failed to marshal request", "session_id", session.ID, "error", err)
|
||||
updateChan <- map[string]interface{}{
|
||||
"type": "error",
|
||||
"error": "failed to prepare request",
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
svc.logger.Debug("forwarding via antigravity", "session_id", session.ID, "model", session.ModelName, "account_id", account.ID)
|
||||
|
||||
// 3. 通过 AntigravityGatewayService 转发(完整 TLS 指纹 + token 刷新 + 重试)
|
||||
respBody, statusCode, err := svc.antigravitySvc.ForwardRaw(ctx, account, bodyJSON)
|
||||
if err != nil {
|
||||
svc.logger.Error("upstream request failed", "session_id", session.ID, "error", err)
|
||||
updateChan <- map[string]interface{}{
|
||||
"type": "error",
|
||||
"error": fmt.Sprintf("upstream request failed: %v", err),
|
||||
}
|
||||
return
|
||||
}
|
||||
defer func() { _ = respBody.Close() }()
|
||||
|
||||
// 4. 处理错误响应
|
||||
if statusCode >= 400 {
|
||||
body, _ := io.ReadAll(io.LimitReader(respBody, 2<<20))
|
||||
svc.logger.Error("upstream error response", "session_id", session.ID, "status_code", statusCode, "body", string(body))
|
||||
updateChan <- map[string]interface{}{
|
||||
"type": "error",
|
||||
"status_code": statusCode,
|
||||
"error": string(body),
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 5. 流式转发响应
|
||||
svc.streamUpstreamResponse(ctx, session.ID, respBody, updateChan)
|
||||
}
|
||||
|
||||
// streamUpstreamResponse 处理上游 SSE 流式响应
|
||||
func (svc *LanguageServerService) streamUpstreamResponse(
|
||||
ctx context.Context,
|
||||
sessionID string,
|
||||
body io.ReadCloser,
|
||||
updateChan chan<- interface{},
|
||||
) {
|
||||
scanner := bufio.NewScanner(body)
|
||||
// 设置合理的缓冲区大小
|
||||
scanner.Buffer(make([]byte, 64*1024), 512*1024)
|
||||
|
||||
for scanner.Scan() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
svc.logger.Info("streaming cancelled", "session_id", sessionID)
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
|
||||
// 跳过空行
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// 跳过注释行
|
||||
if strings.HasPrefix(line, ":") {
|
||||
continue
|
||||
}
|
||||
|
||||
// 解析 SSE 格式 (data: {...})
|
||||
if !strings.HasPrefix(line, "data:") {
|
||||
continue
|
||||
}
|
||||
|
||||
eventData := strings.TrimPrefix(line, "data:")
|
||||
eventData = strings.TrimSpace(eventData)
|
||||
|
||||
// 解析 JSON
|
||||
var event map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(eventData), &event); err != nil {
|
||||
svc.logger.Debug("failed to parse event",
|
||||
"session_id", sessionID,
|
||||
"error", err,
|
||||
"data", eventData,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
// 发送事件到客户端通道
|
||||
select {
|
||||
case updateChan <- event:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(5 * time.Second):
|
||||
svc.logger.Warn("channel send timeout",
|
||||
"session_id", sessionID,
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
svc.logger.Error("scanning upstream response failed",
|
||||
"session_id", sessionID,
|
||||
"error", err,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// getCurrentTimeMS 获取当前时间戳(毫秒)
|
||||
func getCurrentTimeMS() int64 {
|
||||
return time.Now().UnixMilli()
|
||||
}
|
||||
@ -1,353 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
connect "connectrpc.com/connect"
|
||||
"github.com/Wei-Shaw/sub2api/internal/gen/language_server_pb"
|
||||
"github.com/Wei-Shaw/sub2api/internal/gen/language_server_pbconnect"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
const upstreamLSRPCBaseURL = "https://cloudcode-pa.googleapis.com"
|
||||
|
||||
// LSRPCHandler implements LanguageServerServiceHandler by proxying to the real upstream
|
||||
// lsrpc service using OAuth tokens obtained from AntigravityGatewayService.
|
||||
// File RPCs (ReadFile/WriteFile/ReadDir/etc.) operate on the local filesystem.
|
||||
type LSRPCHandler struct {
|
||||
language_server_pbconnect.UnimplementedLanguageServerServiceHandler
|
||||
|
||||
antigravitySvc *AntigravityGatewayService
|
||||
accountRepo AccountRepository
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// NewLSRPCHandler creates a new LSRPCHandler.
|
||||
func NewLSRPCHandler(
|
||||
antigravitySvc *AntigravityGatewayService,
|
||||
accountRepo AccountRepository,
|
||||
logger *slog.Logger,
|
||||
) *LSRPCHandler {
|
||||
if logger == nil {
|
||||
logger = slog.Default()
|
||||
}
|
||||
return &LSRPCHandler{
|
||||
antigravitySvc: antigravitySvc,
|
||||
accountRepo: accountRepo,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// upstreamClient creates a connectrpc client to the real lsrpc upstream,
|
||||
// authenticated with the OAuth token from the given account.
|
||||
func (h *LSRPCHandler) upstreamClient(ctx context.Context) (language_server_pbconnect.LanguageServerServiceClient, error) {
|
||||
accounts, err := h.accountRepo.ListByPlatform(ctx, PlatformAntigravity)
|
||||
if err != nil || len(accounts) == 0 {
|
||||
return nil, fmt.Errorf("no antigravity accounts available: %w", err)
|
||||
}
|
||||
account := &accounts[0]
|
||||
|
||||
tokenProvider := h.antigravitySvc.GetTokenProvider()
|
||||
if tokenProvider == nil {
|
||||
return nil, fmt.Errorf("antigravity token provider not configured")
|
||||
}
|
||||
accessToken, err := tokenProvider.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get access token: %w", err)
|
||||
}
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 5 * time.Minute,
|
||||
Transport: &bearerTransport{
|
||||
base: http.DefaultTransport,
|
||||
token: accessToken,
|
||||
},
|
||||
}
|
||||
|
||||
client := language_server_pbconnect.NewLanguageServerServiceClient(
|
||||
httpClient,
|
||||
upstreamLSRPCBaseURL,
|
||||
connect.WithGRPC(),
|
||||
)
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// bearerTransport injects Authorization: Bearer <token> into every request.
|
||||
type bearerTransport struct {
|
||||
base http.RoundTripper
|
||||
token string
|
||||
}
|
||||
|
||||
func (t *bearerTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
clone := req.Clone(req.Context())
|
||||
clone.Header.Set("Authorization", "Bearer "+t.token)
|
||||
return t.base.RoundTrip(clone)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Cascade RPCs — proxied to real upstream
|
||||
// ============================================================================
|
||||
|
||||
func (h *LSRPCHandler) StartCascade(
|
||||
ctx context.Context,
|
||||
req *connect.Request[language_server_pb.StartCascadeRequest],
|
||||
) (*connect.Response[language_server_pb.StartCascadeResponse], error) {
|
||||
client, err := h.upstreamClient(ctx)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeUnavailable, err)
|
||||
}
|
||||
return client.StartCascade(ctx, req)
|
||||
}
|
||||
|
||||
func (h *LSRPCHandler) SendUserCascadeMessage(
|
||||
ctx context.Context,
|
||||
req *connect.Request[language_server_pb.SendUserCascadeMessageRequest],
|
||||
stream *connect.ServerStream[language_server_pb.CascadeReactiveUpdate],
|
||||
) error {
|
||||
client, err := h.upstreamClient(ctx)
|
||||
if err != nil {
|
||||
return connect.NewError(connect.CodeUnavailable, err)
|
||||
}
|
||||
|
||||
upstreamStream, err := client.SendUserCascadeMessage(ctx, req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer upstreamStream.Close()
|
||||
|
||||
for upstreamStream.Receive() {
|
||||
if err := stream.Send(upstreamStream.Msg()); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return upstreamStream.Err()
|
||||
}
|
||||
|
||||
func (h *LSRPCHandler) CancelCascadeInvocation(
|
||||
ctx context.Context,
|
||||
req *connect.Request[language_server_pb.CancelCascadeInvocationRequest],
|
||||
) (*connect.Response[language_server_pb.CancelCascadeInvocationResponse], error) {
|
||||
client, err := h.upstreamClient(ctx)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeUnavailable, err)
|
||||
}
|
||||
return client.CancelCascadeInvocation(ctx, req)
|
||||
}
|
||||
|
||||
func (h *LSRPCHandler) AcknowledgeCascadeCodeEdit(
|
||||
ctx context.Context,
|
||||
req *connect.Request[language_server_pb.AcknowledgeCascadeCodeEditRequest],
|
||||
) (*connect.Response[language_server_pb.AcknowledgeCascadeCodeEditResponse], error) {
|
||||
client, err := h.upstreamClient(ctx)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeUnavailable, err)
|
||||
}
|
||||
return client.AcknowledgeCascadeCodeEdit(ctx, req)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Model config RPCs — proxied to real upstream
|
||||
// ============================================================================
|
||||
|
||||
func (h *LSRPCHandler) GetCascadeModelConfigs(
|
||||
ctx context.Context,
|
||||
req *connect.Request[language_server_pb.GetCascadeModelConfigsRequest],
|
||||
) (*connect.Response[language_server_pb.GetCascadeModelConfigsResponse], error) {
|
||||
client, err := h.upstreamClient(ctx)
|
||||
if err != nil {
|
||||
// Fall back to static list when upstream unavailable.
|
||||
return connect.NewResponse(&language_server_pb.GetCascadeModelConfigsResponse{
|
||||
Models: staticCascadeModels(),
|
||||
}), nil
|
||||
}
|
||||
resp, err := client.GetCascadeModelConfigs(ctx, req)
|
||||
if err != nil {
|
||||
return connect.NewResponse(&language_server_pb.GetCascadeModelConfigsResponse{
|
||||
Models: staticCascadeModels(),
|
||||
}), nil
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (h *LSRPCHandler) GetCommandModelConfigs(
|
||||
ctx context.Context,
|
||||
req *connect.Request[language_server_pb.GetCommandModelConfigsRequest],
|
||||
) (*connect.Response[language_server_pb.GetCommandModelConfigsResponse], error) {
|
||||
client, err := h.upstreamClient(ctx)
|
||||
if err != nil {
|
||||
return connect.NewResponse(&language_server_pb.GetCommandModelConfigsResponse{
|
||||
Models: staticCascadeModels(),
|
||||
}), nil
|
||||
}
|
||||
resp, err := client.GetCommandModelConfigs(ctx, req)
|
||||
if err != nil {
|
||||
return connect.NewResponse(&language_server_pb.GetCommandModelConfigsResponse{
|
||||
Models: staticCascadeModels(),
|
||||
}), nil
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// staticCascadeModels returns a hard-coded model list as fallback.
|
||||
func staticCascadeModels() []*language_server_pb.ModelConfig {
|
||||
return []*language_server_pb.ModelConfig{
|
||||
{Name: "claude-opus-4-7", DisplayName: "Claude Opus 4.7", MaxTokens: 200000, SupportsThinking: true, ThinkingBudget: 32000, SupportsImages: true, Provider: "anthropic"},
|
||||
{Name: "claude-opus-4-6", DisplayName: "Claude Opus 4.6", MaxTokens: 200000, SupportsThinking: true, ThinkingBudget: 32000, SupportsImages: true, Provider: "anthropic"},
|
||||
{Name: "claude-sonnet-4-6", DisplayName: "Claude Sonnet 4.6", MaxTokens: 200000, SupportsImages: true, Provider: "anthropic"},
|
||||
{Name: "claude-haiku-4-5", DisplayName: "Claude Haiku 4.5", MaxTokens: 200000, SupportsImages: true, Provider: "anthropic"},
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// File RPCs — local filesystem implementation
|
||||
// ============================================================================
|
||||
|
||||
func (h *LSRPCHandler) ReadFile(
|
||||
ctx context.Context,
|
||||
req *connect.Request[language_server_pb.ReadFileRequest],
|
||||
) (*connect.Response[language_server_pb.ReadFileResponse], error) {
|
||||
path := req.Msg.GetPath()
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("file not found: %s", path))
|
||||
}
|
||||
return nil, connect.NewError(connect.CodeInternal, err)
|
||||
}
|
||||
return connect.NewResponse(&language_server_pb.ReadFileResponse{
|
||||
Content: string(data),
|
||||
}), nil
|
||||
}
|
||||
|
||||
func (h *LSRPCHandler) WriteFile(
|
||||
ctx context.Context,
|
||||
req *connect.Request[language_server_pb.WriteFileRequest],
|
||||
) (*connect.Response[language_server_pb.WriteFileResponse], error) {
|
||||
path := req.Msg.GetPath()
|
||||
if req.Msg.GetCreateParent() {
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, err)
|
||||
}
|
||||
}
|
||||
if err := os.WriteFile(path, []byte(req.Msg.GetContent()), 0o644); err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, err)
|
||||
}
|
||||
return connect.NewResponse(&language_server_pb.WriteFileResponse{}), nil
|
||||
}
|
||||
|
||||
func (h *LSRPCHandler) ReadDir(
|
||||
ctx context.Context,
|
||||
req *connect.Request[language_server_pb.ReadDirRequest],
|
||||
) (*connect.Response[language_server_pb.ReadDirResponse], error) {
|
||||
path := req.Msg.GetPath()
|
||||
entries, err := os.ReadDir(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("directory not found: %s", path))
|
||||
}
|
||||
return nil, connect.NewError(connect.CodeInternal, err)
|
||||
}
|
||||
|
||||
files := make([]*language_server_pb.FileInfo, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
info, err := entry.Info()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
files = append(files, fileInfoFromOS(entry.Name(), info))
|
||||
}
|
||||
return connect.NewResponse(&language_server_pb.ReadDirResponse{
|
||||
Files: files,
|
||||
}), nil
|
||||
}
|
||||
|
||||
func (h *LSRPCHandler) DeleteFileOrDirectory(
|
||||
ctx context.Context,
|
||||
req *connect.Request[language_server_pb.DeleteFileOrDirectoryRequest],
|
||||
) (*connect.Response[language_server_pb.DeleteFileOrDirectoryResponse], error) {
|
||||
path := req.Msg.GetPath()
|
||||
if err := os.RemoveAll(path); err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, err)
|
||||
}
|
||||
return connect.NewResponse(&language_server_pb.DeleteFileOrDirectoryResponse{}), nil
|
||||
}
|
||||
|
||||
func (h *LSRPCHandler) StatUri(
|
||||
ctx context.Context,
|
||||
req *connect.Request[language_server_pb.StatUriRequest],
|
||||
) (*connect.Response[language_server_pb.StatUriResponse], error) {
|
||||
path := req.Msg.GetPath()
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("path not found: %s", path))
|
||||
}
|
||||
return nil, connect.NewError(connect.CodeInternal, err)
|
||||
}
|
||||
return connect.NewResponse(&language_server_pb.StatUriResponse{
|
||||
FileInfo: fileInfoFromOS(info.Name(), info),
|
||||
}), nil
|
||||
}
|
||||
|
||||
func (h *LSRPCHandler) WatchDirectory(
|
||||
ctx context.Context,
|
||||
req *connect.Request[language_server_pb.WatchDirectoryRequest],
|
||||
stream *connect.ServerStream[language_server_pb.WatchDirectoryResponse],
|
||||
) error {
|
||||
// Block until context is cancelled — real FS watching requires fsnotify which
|
||||
// is not in the dependency graph yet. This satisfies the interface contract
|
||||
// without crashing; the client will get an EOF when the connection closes.
|
||||
<-ctx.Done()
|
||||
return nil
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Health RPCs
|
||||
// ============================================================================
|
||||
|
||||
func (h *LSRPCHandler) Heartbeat(
|
||||
ctx context.Context,
|
||||
req *connect.Request[language_server_pb.HeartbeatRequest],
|
||||
) (*connect.Response[language_server_pb.HeartbeatResponse], error) {
|
||||
return connect.NewResponse(&language_server_pb.HeartbeatResponse{
|
||||
Healthy: true,
|
||||
Version: "sub2api",
|
||||
}), nil
|
||||
}
|
||||
|
||||
func (h *LSRPCHandler) GetStatus(
|
||||
ctx context.Context,
|
||||
req *connect.Request[language_server_pb.GetStatusRequest],
|
||||
) (*connect.Response[language_server_pb.GetStatusResponse], error) {
|
||||
return connect.NewResponse(&language_server_pb.GetStatusResponse{
|
||||
Status: "running",
|
||||
Version: antigravity.BaseURL,
|
||||
}), nil
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Helpers
|
||||
// ============================================================================
|
||||
|
||||
func fileInfoFromOS(name string, info fs.FileInfo) *language_server_pb.FileInfo {
|
||||
t := language_server_pb.FileInfo_FILE
|
||||
if info.IsDir() {
|
||||
t = language_server_pb.FileInfo_DIRECTORY
|
||||
} else if info.Mode()&os.ModeSymlink != 0 {
|
||||
t = language_server_pb.FileInfo_SYMLINK
|
||||
}
|
||||
return &language_server_pb.FileInfo{
|
||||
Path: name,
|
||||
Type: t,
|
||||
Size: info.Size(),
|
||||
ModifiedTime: timestamppb.New(info.ModTime()),
|
||||
}
|
||||
}
|
||||
@ -1,8 +1,10 @@
|
||||
package service
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"testing"
|
||||
|
||||
import "github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNormalizeOpenAIMessagesDispatchModelConfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@ -8,8 +8,6 @@ import (
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"image"
|
||||
"image/color"
|
||||
stddraw "image/draw"
|
||||
@ -24,6 +22,9 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
|
||||
xdraw "golang.org/x/image/draw"
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
@ -80,10 +80,10 @@ func TestInjectModelIdentity(t *testing.T) {
|
||||
wantInjected bool
|
||||
}{
|
||||
{
|
||||
name: "anthropic model without system",
|
||||
messages: []windsurf.ChatMessage{{Role: "user", Content: "hi"}},
|
||||
meta: &windsurf.ModelMeta{Name: "claude-sonnet-4.6", Provider: "anthropic"},
|
||||
modelKey: "claude-sonnet-4.6",
|
||||
name: "anthropic model without system",
|
||||
messages: []windsurf.ChatMessage{{Role: "user", Content: "hi"}},
|
||||
meta: &windsurf.ModelMeta{Name: "claude-sonnet-4.6", Provider: "anthropic"},
|
||||
modelKey: "claude-sonnet-4.6",
|
||||
wantInjected: true,
|
||||
},
|
||||
{
|
||||
@ -111,10 +111,10 @@ func TestInjectModelIdentity(t *testing.T) {
|
||||
wantInjected: false,
|
||||
},
|
||||
{
|
||||
name: "openai model without system",
|
||||
messages: []windsurf.ChatMessage{{Role: "user", Content: "hi"}},
|
||||
meta: &windsurf.ModelMeta{Name: "gpt-4o", Provider: "openai"},
|
||||
modelKey: "gpt-4o",
|
||||
name: "openai model without system",
|
||||
messages: []windsurf.ChatMessage{{Role: "user", Content: "hi"}},
|
||||
meta: &windsurf.ModelMeta{Name: "gpt-4o", Provider: "openai"},
|
||||
modelKey: "gpt-4o",
|
||||
wantInjected: true,
|
||||
},
|
||||
}
|
||||
|
||||
@ -609,13 +609,13 @@ type windsurfRequestTool struct {
|
||||
// ---- Helper functions (prefixed to avoid collision with windsurf_gateway_handler.go) ----
|
||||
|
||||
type windsurfContentBlock struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Input interface{} `json:"input,omitempty"`
|
||||
ToolUseID string `json:"tool_use_id,omitempty"`
|
||||
Content json.RawMessage `json:"content,omitempty"`
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Input interface{} `json:"input,omitempty"`
|
||||
ToolUseID string `json:"tool_use_id,omitempty"`
|
||||
Content json.RawMessage `json:"content,omitempty"`
|
||||
// Source 来自 Anthropic image block:{type:"base64", media_type:"image/png", data:"..."}
|
||||
Source *windsurfContentImageSource `json:"source,omitempty"`
|
||||
}
|
||||
|
||||
@ -494,7 +494,6 @@ var ProviderSet = wire.NewSet(
|
||||
NewPaymentService,
|
||||
ProvidePaymentOrderExpiryService,
|
||||
ProvideBalanceNotifyService,
|
||||
ProvideLanguageServerService,
|
||||
ProvideWindsurfAuthService,
|
||||
ProvideWindsurfLSService,
|
||||
ProvideWindsurfChatService,
|
||||
@ -507,11 +506,6 @@ var ProviderSet = wire.NewSet(
|
||||
NewChannelMonitorRequestTemplateService,
|
||||
)
|
||||
|
||||
// ProvideLanguageServerService creates LanguageServerService with injected dependencies
|
||||
func ProvideLanguageServerService(httpUpstream HTTPUpstream, antigravitySvc *AntigravityGatewayService, accountRepo AccountRepository) *LanguageServerService {
|
||||
return NewLanguageServerService(slog.Default(), httpUpstream, antigravitySvc, accountRepo)
|
||||
}
|
||||
|
||||
// ProvideWindsurfAuthService creates WindsurfAuthService from the main config.
|
||||
func ProvideWindsurfAuthService(cfg *config.Config, accountRepo AccountRepository, proxyRepo ProxyRepository, adminSvc AdminService) *WindsurfAuthService {
|
||||
if !cfg.Windsurf.Enabled {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user