package main import ( "context" "flag" "fmt" "log" "sync" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" ) // TestScenario 定义一个测试场景 type TestScenario struct { name string description string testFunc func(ctx context.Context, token, projectID string) (bool, string) } var scenarios []TestScenario func init() { scenarios = []TestScenario{ { name: "single_request", description: "单次请求 - 检查是否立即成功", testFunc: testSingleRequest, }, { name: "sequential_requests", description: "顺序发送 10 个请求 - 找到稳定点", testFunc: testSequentialRequests, }, { name: "concurrent_requests", description: "并发发送 5 个请求 - 检查并发初始化行为", testFunc: testConcurrentRequests, }, { name: "warmup_then_request", description: "预热(模型列表请求) + 业务请求 - 验证预热效果", testFunc: testWarmupThenRequest, }, { name: "delayed_request", description: "延迟 5 秒后请求 - 检查账号初始化时间", testFunc: testDelayedRequest, }, } } // testSingleRequest 单次请求 func testSingleRequest(ctx context.Context, token, projectID string) (bool, string) { client, err := antigravity.NewClient("") if err != nil { return false, fmt.Sprintf("创建客户端失败: %v", err) } start := time.Now() resp, _, err := client.FetchAvailableModels(ctx, token, projectID) elapsed := time.Since(start) if err != nil { return false, fmt.Sprintf("请求失败 (%v): %v", elapsed, err) } if resp == nil { return false, fmt.Sprintf("响应为空 (%v)", elapsed) } return true, fmt.Sprintf("✓ 单次请求成功 - 耗时 %v", elapsed) } // testSequentialRequests 顺序发送多个请求 func testSequentialRequests(ctx context.Context, token, projectID string) (bool, string) { client, err := antigravity.NewClient("") if err != nil { return false, fmt.Sprintf("创建客户端失败: %v", err) } var firstFailIdx = -1 var firstSuccessIdx = -1 var timings []time.Duration for i := 0; i < 10; i++ { start := time.Now() resp, _, err := client.FetchAvailableModels(ctx, token, projectID) elapsed := time.Since(start) timings = append(timings, elapsed) success := err == nil && resp != nil fmt.Printf(" [%d] 耗时: %6v, 状态: %v\n", i+1, elapsed, map[bool]string{true: "✓", false: "✗"}[success]) if !success && firstFailIdx == -1 { firstFailIdx = i } if success && firstSuccessIdx == -1 { firstSuccessIdx = i } } var report string if firstSuccessIdx == -1 { report = "✗ 全部失败" } else if firstSuccessIdx == 0 { report = fmt.Sprintf("✓ 首次即成功 (耗时 %v)", timings[0]) } else { report = fmt.Sprintf("⚠ 第 %d 次才成功 (失败 %d 次), 首次耗时 %v", firstSuccessIdx+1, firstSuccessIdx, timings[firstSuccessIdx]) } return firstSuccessIdx >= 0, report } // testConcurrentRequests 并发请求 func testConcurrentRequests(ctx context.Context, token, projectID string) (bool, string) { client, err := antigravity.NewClient("") if err != nil { return false, fmt.Sprintf("创建客户端失败: %v", err) } var wg sync.WaitGroup results := make([]bool, 5) timings := make([]time.Duration, 5) mu := sync.Mutex{} for i := 0; i < 5; i++ { wg.Add(1) go func(idx int) { defer wg.Done() start := time.Now() resp, _, err := client.FetchAvailableModels(ctx, token, projectID) elapsed := time.Since(start) mu.Lock() results[idx] = err == nil && resp != nil timings[idx] = elapsed mu.Unlock() fmt.Printf(" [%d] 耗时: %6v, 状态: %v\n", idx+1, elapsed, map[bool]string{true: "✓", false: "✗"}[results[idx]]) }(i) } wg.Wait() successCount := 0 for _, ok := range results { if ok { successCount++ } } return successCount > 0, fmt.Sprintf("%d/5 并发请求成功", successCount) } // testWarmupThenRequest 预热测试 func testWarmupThenRequest(ctx context.Context, token, projectID string) (bool, string) { client, err := antigravity.NewClient("") if err != nil { return false, fmt.Sprintf("创建客户端失败: %v", err) } // 第 1 步:预热 - 调用 LoadCodeAssist(获取项目信息) fmt.Println(" [Warmup] 调用 LoadCodeAssist 预热...") warmupStart := time.Now() _, _, warmupErr := client.LoadCodeAssist(ctx, token) warmupElapsed := time.Since(warmupStart) fmt.Printf(" [Warmup] 耗时 %v, 状态: %v\n", warmupElapsed, map[bool]string{true: "✓", false: "✗"}[warmupErr == nil]) // 第 2 步:实际请求 fmt.Println(" [Request] 发送业务请求...") reqStart := time.Now() resp, _, err := client.FetchAvailableModels(ctx, token, projectID) reqElapsed := time.Since(reqStart) success := err == nil && resp != nil fmt.Printf(" [Request] 耗时 %v, 状态: %v\n", reqElapsed, map[bool]string{true: "✓", false: "✗"}[success]) return success, fmt.Sprintf("预热 %v + 请求 %v = 总耗时 %v", warmupElapsed, reqElapsed, warmupElapsed+reqElapsed) } // testDelayedRequest 延迟请求 func testDelayedRequest(ctx context.Context, token, projectID string) (bool, string) { client, err := antigravity.NewClient("") if err != nil { return false, fmt.Sprintf("创建客户端失败: %v", err) } fmt.Println(" 等待 5 秒...") time.Sleep(5 * time.Second) start := time.Now() resp, _, err := client.FetchAvailableModels(ctx, token, projectID) elapsed := time.Since(start) success := err == nil && resp != nil return success, fmt.Sprintf("延迟 5s 后请求 - 耗时 %v, 状态: %v", elapsed, map[bool]string{true: "✓", false: "✗"}[success]) } // testOAuthTokenRefresh OAuth Token 刷新测试 func testOAuthTokenRefresh(ctx context.Context, refreshToken string) (bool, string) { client, err := antigravity.NewClient("") if err != nil { return false, fmt.Sprintf("创建客户端失败: %v", err) } start := time.Now() tokenInfo, err := client.RefreshToken(ctx, refreshToken, false) elapsed := time.Since(start) if err != nil { return false, fmt.Sprintf("Token 刷新失败 (%v): %v", elapsed, err) } return true, fmt.Sprintf("✓ Token 刷新成功 - 耗时 %v, 新 Token 有效期: %d 秒", elapsed, tokenInfo.ExpiresIn) } // testAccountInitializationWarmup 账号初始化预热 func testAccountInitializationWarmup(ctx context.Context, token, projectID string) (bool, string) { client, err := antigravity.NewClient("") if err != nil { return false, fmt.Sprintf("创建客户端失败: %v", err) } fmt.Println(" 执行完整的账号初始化流程...") // 1. GetUserInfo fmt.Println(" 1. GetUserInfo...") start := time.Now() _, err1 := client.GetUserInfo(ctx, token) fmt.Printf(" 耗时: %v\n", time.Since(start)) // 2. LoadCodeAssist fmt.Println(" 2. LoadCodeAssist...") start = time.Now() _, _, err2 := client.LoadCodeAssist(ctx, token) fmt.Printf(" 耗时: %v\n", time.Since(start)) // 3. FetchAvailableModels fmt.Println(" 3. FetchAvailableModels...") start = time.Now() _, _, err3 := client.FetchAvailableModels(ctx, token, projectID) elapsed := time.Since(start) fmt.Printf(" 耗时: %v\n", elapsed) success := err1 == nil && err2 == nil && err3 == nil return success, fmt.Sprintf("账号初始化预热 - 状态: %v", map[bool]string{true: "✓", false: "✗"}[success]) } func main() { accessToken := flag.String("token", "", "OAuth access token") projectID := flag.String("project", "", "Project ID") refreshToken := flag.String("refresh", "", "Refresh token (optional)") testName := flag.String("test", "all", "测试名称 (all, single_request, sequential_requests, etc.)") flag.Parse() if *accessToken == "" || *projectID == "" { log.Fatal("缺少必需参数: -token 和 -project") } ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) defer cancel() fmt.Println("\n" + repeatStr("=", 80)) fmt.Println("Antigravity 账号初始化诊断测试套件") fmt.Println(repeatStr("=", 80) + "\n") // Token 刷新测试 if *refreshToken != "" { fmt.Println("[Token 刷新测试]") _, report := testOAuthTokenRefresh(ctx, *refreshToken) fmt.Printf("%s\n\n", report) } // 账号初始化预热测试 fmt.Println("[账号初始化预热]") _, report := testAccountInitializationWarmup(ctx, *accessToken, *projectID) fmt.Printf("%s\n\n", report) // 运行指定的测试 if *testName == "all" { for _, scenario := range scenarios { fmt.Printf("[%s]\n%s\n", scenario.name, scenario.description) _, report := scenario.testFunc(ctx, *accessToken, *projectID) fmt.Printf("结果: %s\n\n", report) } } else { found := false for _, scenario := range scenarios { if scenario.name == *testName { found = true fmt.Printf("[%s]\n%s\n", scenario.name, scenario.description) _, report := scenario.testFunc(ctx, *accessToken, *projectID) fmt.Printf("结果: %s\n\n", report) break } } if !found { log.Fatalf("未找到测试: %s", *testName) } } fmt.Println(repeatStr("=", 80)) fmt.Println("诊断完成") fmt.Println(repeatStr("=", 80)) } func repeatStr(s string, count int) string { result := "" for i := 0; i < count; i++ { result += s } return result }