diff --git a/backend/internal/pkg/antigravity/client.go b/backend/internal/pkg/antigravity/client.go index 16aff9f8..e318d1cd 100644 --- a/backend/internal/pkg/antigravity/client.go +++ b/backend/internal/pkg/antigravity/client.go @@ -254,6 +254,8 @@ const ( proxyTLSHandshakeTimeout = 5 * time.Second // clientTimeout 整体请求超时(含连接、发送、等待响应、读取 body) clientTimeout = 10 * time.Second + // fetchAvailableModelsBodyLimit limits model-list responses to avoid unbounded memory use. + fetchAvailableModelsBodyLimit int64 = 8 << 20 ) func NewClient(proxyURL string) (*Client, error) { @@ -655,6 +657,10 @@ type FetchAvailableModelsResponse struct { // FetchAvailableModels 获取可用模型和配额信息,返回解析后的结构体和原始 JSON // 支持 URL fallback:sandbox → daily → prod func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectID string) (*FetchAvailableModelsResponse, map[string]any, error) { + if c == nil || c.httpClient == nil { + return nil, nil, errors.New("antigravity client is not configured") + } + reqBody := FetchAvailableModelsRequest{Project: projectID} bodyBytes, err := json.Marshal(reqBody) if err != nil { @@ -664,6 +670,7 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI // 固定顺序:prod -> daily availableURLs := BaseURLs + fetchClient := c.fetchAvailableModelsHTTPClient() var lastErr error for urlIdx, baseURL := range availableURLs { apiURL := baseURL + "/v1internal:fetchAvailableModels" @@ -676,7 +683,7 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI req.Header.Set("Content-Type", "application/json") req.Header.Set("User-Agent", GetUserAgentForContext(ctx)) - resp, err := c.httpClient.Do(req) + resp, err := fetchClient.Do(req) if err != nil { lastErr = fmt.Errorf("fetchAvailableModels 请求失败: %w", err) if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 { @@ -686,11 +693,14 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI return nil, nil, lastErr } - respBodyBytes, err := io.ReadAll(resp.Body) + respBodyBytes, err := io.ReadAll(io.LimitReader(resp.Body, fetchAvailableModelsBodyLimit+1)) _ = resp.Body.Close() // 立即关闭,避免循环内 defer 导致的资源泄漏 if err != nil { return nil, nil, fmt.Errorf("读取响应失败: %w", err) } + if int64(len(respBodyBytes)) > fetchAvailableModelsBodyLimit { + return nil, nil, fmt.Errorf("响应超过 %d 字节", fetchAvailableModelsBodyLimit) + } // 检查是否需要 URL 降级 if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 { @@ -726,6 +736,42 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI return nil, nil, lastErr } +func (c *Client) fetchAvailableModelsHTTPClient() *http.Client { + fetchClient := *c.httpClient + fetchClient.CheckRedirect = checkFetchAvailableModelsRedirect + return &fetchClient +} + +func checkFetchAvailableModelsRedirect(req *http.Request, via []*http.Request) error { + if len(via) >= 10 { + return errors.New("stopped after 10 redirects") + } + if req == nil || req.URL == nil { + return errors.New("redirect url is nil") + } + if !isAllowedFetchAvailableModelsRedirectHost(req.URL.Hostname()) { + return fmt.Errorf("redirect to unsupported host: %s", req.URL.Hostname()) + } + return nil +} + +func isAllowedFetchAvailableModelsRedirectHost(host string) bool { + host = strings.ToLower(strings.TrimSpace(host)) + if host == "" { + return false + } + for _, baseURL := range BaseURLs { + parsed, err := url.Parse(baseURL) + if err != nil { + continue + } + if strings.EqualFold(host, parsed.Hostname()) { + return true + } + } + return false +} + // ── Privacy API ────────────────────────────────────────────────────── // privacyBaseURL 隐私设置 API 仅使用 daily 端点(与 Antigravity 客户端行为一致)