// HTTP client for Windsurf upstream JSON/Connect-RPC endpoints. // Portions derived from windsurf-tools (MIT 2025 shaoyu521). See ./LICENSE. package windsurf import ( "bytes" "context" "encoding/json" "fmt" "io" "net/http" "net/url" "strings" "time" ) // Client wraps an *http.Client and the Windsurf base URL. type Client struct { BaseURL string HTTP *http.Client CSRFToken string } // NewClient builds a Client. proxyURL may be empty. func NewClient(baseURL, proxyURL string, csrfToken ...string) (*Client, error) { if baseURL == "" { baseURL = DefaultBaseURL } transport := &http.Transport{ ForceAttemptHTTP2: true, IdleConnTimeout: 90 * time.Second, ResponseHeaderTimeout: 60 * time.Second, ExpectContinueTimeout: 1 * time.Second, } if proxyURL != "" { u, err := url.Parse(proxyURL) if err != nil { return nil, fmt.Errorf("parse proxy: %w", err) } transport.Proxy = http.ProxyURL(u) } var csrf string if len(csrfToken) > 0 { csrf = csrfToken[0] } return &Client{ BaseURL: baseURL, CSRFToken: csrf, HTTP: &http.Client{ Transport: transport, Timeout: 180 * time.Second, }, }, nil } // CheckChatCapacity returns hasCapacity flag from server. func (c *Client) CheckChatCapacity(ctx context.Context, token string) (bool, string, error) { rawJWT := StripDevinPrefix(token) body := map[string]any{ "metadata": map[string]any{ "apiKey": token, "ideName": AppName, "ideVersion": AppVersion, "extensionName": AppName, "extensionVersion": "0.2.0", "sessionId": generateUUID(), "requestId": randomUint64String(), }, } resp, err := c.unaryJSON(ctx, "/exa.api_server_pb.ApiServerService/CheckChatCapacity", body, rawJWT) if err != nil { return false, "", err } var out struct { HasCapacity bool `json:"hasCapacity"` } if err := json.Unmarshal(resp, &out); err != nil { return false, string(resp), fmt.Errorf("decode: %w", err) } return out.HasCapacity, string(resp), nil } // UserStatus holds the fields from GetUserStatus. type UserStatus struct { UserID string `json:"userId"` TeamID string `json:"teamId"` Name string `json:"name"` Email string `json:"email"` PlanName string `json:"planName,omitempty"` DailyPercent *float64 `json:"dailyPercent,omitempty"` WeeklyPercent *float64 `json:"weeklyPercent,omitempty"` MonthlyPromptCredits *float64 `json:"monthlyPromptCredits,omitempty"` UsedPromptCredits *float64 `json:"usedPromptCredits,omitempty"` MonthlyFlexCredits *float64 `json:"monthlyFlexCredits,omitempty"` UsedFlexCredits *float64 `json:"usedFlexCredits,omitempty"` } // GetUserStatus fetches the user's plan status from server.codeium.com. func (c *Client) GetUserStatus(ctx context.Context, token string) (*UserStatus, error) { rawJWT := StripDevinPrefix(token) body := map[string]any{ "metadata": map[string]any{ "apiKey": token, "ideName": AppName, "ideVersion": AppVersion, "extensionName": AppName, "extensionVersion": "0.2.0", "sessionId": generateUUID(), "requestId": randomUint64String(), }, } resp, err := c.unaryJSONURL(ctx, "https://server.codeium.com/exa.api_server_pb.ApiServerService/GetUserStatus", body, rawJWT) if err != nil { return nil, err } var out struct { UserStatus struct { UserID string `json:"userId"` TeamID string `json:"teamId"` Name string `json:"name"` Email string `json:"email"` PlanStatus struct { PlanInfo struct { // 上游可能返回字符串(如 "Trial")或数字,统一用 json.RawMessage 兜底 // 再按需解析为字符串展示;避免 json.Number 遇字符串时解码失败导致整个 userStatus 拉取失败。 PlanName json.RawMessage `json:"planName"` MonthlyPromptCredits json.Number `json:"monthlyPromptCredits"` MonthlyFlexCredits json.Number `json:"monthlyFlexCreditPurchaseAmount"` } `json:"planInfo"` DailyQuotaRemainingPercent *float64 `json:"dailyQuotaRemainingPercent"` WeeklyQuotaRemainingPercent *float64 `json:"weeklyQuotaRemainingPercent"` UsedPromptCredits json.Number `json:"usedPromptCredits"` UsedFlexCredits json.Number `json:"usedFlexCredits"` } `json:"planStatus"` } `json:"userStatus"` } if err := json.Unmarshal(resp, &out); err != nil { return nil, fmt.Errorf("decode: %w (body=%s)", err, truncate(string(resp), 300)) } us := out.UserStatus ps := us.PlanStatus numPtr := func(n json.Number) *float64 { if n.String() == "" { return nil } v, err := n.Float64() if err != nil { return nil } // Legacy values come in hundredths v /= 100 return &v } return &UserStatus{ UserID: us.UserID, TeamID: us.TeamID, Name: us.Name, Email: us.Email, PlanName: planNameString(ps.PlanInfo.PlanName), DailyPercent: ps.DailyQuotaRemainingPercent, WeeklyPercent: ps.WeeklyQuotaRemainingPercent, MonthlyPromptCredits: numPtr(ps.PlanInfo.MonthlyPromptCredits), UsedPromptCredits: numPtr(ps.UsedPromptCredits), MonthlyFlexCredits: numPtr(ps.PlanInfo.MonthlyFlexCredits), UsedFlexCredits: numPtr(ps.UsedFlexCredits), }, nil } // planNameString 把上游 planName 字段(可能是字符串也可能是数字)统一还原为字符串。 func planNameString(raw json.RawMessage) string { if len(raw) == 0 { return "" } var s string if err := json.Unmarshal(raw, &s); err == nil { return s } return strings.Trim(string(raw), "\"") } // ModelInfo is one entry of GetCascadeModelConfigs response. type ModelInfo struct { ModelUID string `json:"modelUid"` Label string `json:"label"` CreditMultiplier float64 `json:"creditMultiplier"` IsRecommended bool `json:"isRecommended"` IsNew bool `json:"isNew"` } // ListModels returns the cascade model catalog. func (c *Client) ListModels(ctx context.Context, token string) ([]ModelInfo, error) { rawJWT := StripDevinPrefix(token) body := map[string]any{ "metadata": map[string]any{ "apiKey": token, "ideName": AppName, "ideVersion": AppVersion, "extensionName": AppName, "extensionVersion": "0.2.0", "sessionId": generateUUID(), "requestId": randomUint64String(), }, } resp, err := c.unaryJSON(ctx, "/exa.api_server_pb.ApiServerService/GetCascadeModelConfigs", body, rawJWT) if err != nil { return nil, err } var out struct { ClientModelConfigs []ModelInfo `json:"clientModelConfigs"` } if err := json.Unmarshal(resp, &out); err != nil { return nil, fmt.Errorf("decode: %w (body=%s)", err, truncate(string(resp), 300)) } return out.ClientModelConfigs, nil } // HasModel reports whether models contains the given uid. func HasModel(models []ModelInfo, uid string) bool { for _, m := range models { if strings.EqualFold(m.ModelUID, uid) { return true } } return false } func (c *Client) unaryJSON(ctx context.Context, path string, body any, rawJWT string) ([]byte, error) { return c.unaryJSONURL(ctx, c.BaseURL+path, body, rawJWT) } func (c *Client) unaryJSONURL(ctx context.Context, fullURL string, body any, rawJWT string) ([]byte, error) { jsonBody, err := json.Marshal(body) if err != nil { return nil, err } req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(jsonBody)) if err != nil { return nil, err } req.Header.Set("Content-Type", "application/json") req.Header.Set("Connect-Protocol-Version", "1") req.Header.Set("User-Agent", UserAgent) if rawJWT != "" { req.Header.Set("Authorization", "Bearer "+rawJWT) } resp, err := c.HTTP.Do(req) if err != nil { return nil, err } defer resp.Body.Close() respBody, _ := io.ReadAll(resp.Body) if resp.StatusCode >= 400 { return respBody, fmt.Errorf("HTTP %d: %s", resp.StatusCode, truncate(string(respBody), 300)) } return respBody, nil } func randomUint64String() string { var b [8]byte _, _ = readRandom(b[:]) var v uint64 for _, x := range b { v = (v << 8) | uint64(x) } v &^= 1 << 63 return fmt.Sprintf("%d", v) } func truncate(s string, n int) string { if len(s) <= n { return s } return s[:n] + "...(truncated)" }