// Package lspool provides an HTTPUpstream adapter that routes
// streamGenerateContent requests through real Language Server instances.
//
// Flow:
//
//	sub2api → LSPoolUpstream.Do() → StartCascade → SendUserCascadeMessage
//	→ LS internally calls cloudcode-pa (with authentic TLS fingerprint)
//	→ Poll GetCascadeTrajectory for incremental text
//	→ Format as SSE and stream back to sub2api service layer
//
// The model is extracted from the original request body, not hardcoded.
//
// The cascade flow above is conditional: unless CurrentLSStrategy() reports
// LSStrategyJSParity, stream requests are forwarded directly through the
// wrapped fallback upstream (the LS instance is only started/reused to keep it
// alive), and JS-parity requests that fail the route checks are forwarded
// directly as well. All other request paths always use the fallback.
package lspool

import (
	"bytes"
	"context"
	"crypto/sha256"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"net/http"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
)

// Upstream is the interface matching service.HTTPUpstream.
//
// Do sends req for the given account; DoWithTLS does the same but additionally
// receives a TLS fingerprint profile for the outbound connection.
type Upstream interface {
	Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error)
	DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*http.Response, error)
}

// LSPoolUpstream wraps an existing HTTPUpstream and intercepts
// streamGenerateContent requests to route them through the LS pool.
type LSPoolUpstream struct {
	// pool supplies per-account LS instances (GetOrCreate) and receives the
	// account tokens/credit settings extracted from request headers.
	pool Backend
	// fallback is the wrapped upstream used for every request that is not
	// answered by an LS cascade.
	fallback Upstream
	logger   *slog.Logger

	// sessionMu presumably guards sessions (the accessors getSessionState /
	// putSessionState are not visible here — verify).
	sessionMu sync.Mutex
	// sessions caches the cascade backing each client session, keyed by the
	// session cache key built in forwardChatViaLS.
	sessions map[string]*cascadeSessionState
}

// NewLSPoolUpstream creates an LS pool upstream wrapper around pool, using
// fallback for every request that the LS pool does not serve itself.
func NewLSPoolUpstream(pool Backend, fallback Upstream) *LSPoolUpstream { return &LSPoolUpstream{ pool: pool, fallback: fallback, logger: slog.Default().With("component", "lspool-upstream"), sessions: make(map[string]*cascadeSessionState), } } const ( userNamespaceHeader = "X-Sub2API-User-Key" useAICreditsHeader = "X-Antigravity-Use-AI-Credits" availableCreditsHeader = "X-Antigravity-Available-Credits" minimumCreditAmountHeader = "X-Antigravity-Minimum-Credit-Amount" sessionStateTTL = 30 * time.Minute lsSendMessageTimeout = 20 * time.Second lsModelConfigTimeout = 20 * time.Second ) var ( errLSRouteDirect = errors.New("request should use direct upstream") errLSTranscriptDrift = errors.New("request transcript diverged from cached cascade session") errLSQuotaExhausted = errors.New("ls cascade returned quota exhausted") errLSModelMapPending = errors.New("model mapping not ready") errLSModelMapDenied = errors.New("model mapping unavailable") ) // IsLSQuotaExhaustedError reports whether err originated from an LS cascade // quota/capacity exhaustion signal. func IsLSQuotaExhaustedError(err error) bool { return errors.Is(err, errLSQuotaExhausted) } // LSQuotaExhaustedMessage extracts the original LS error message, if present. 
func LSQuotaExhaustedMessage(err error) string { if err == nil { return "" } msg := strings.TrimSpace(err.Error()) if msg == "" { return "" } prefix := errLSQuotaExhausted.Error() if msg == prefix { return "" } if strings.HasPrefix(msg, prefix+":") { return strings.TrimSpace(strings.TrimPrefix(msg, prefix+":")) } return msg } func isPermanentModelMappingError(err error) bool { if err == nil { return false } return strings.Contains(strings.ToLower(err.Error()), "unauthorized_client") } func modelMappingDeniedReason(err error) string { if err == nil { return "" } return truncate(strings.TrimSpace(err.Error()), 200) } type cascadeSessionState struct { CascadeID string SystemText string History []geminiConversationTurn UpdatedAt time.Time } type geminiEnvelope struct { Model string `json:"model"` Request json.RawMessage `json:"request"` } type geminiRequestPayload struct { Contents []geminiWireContent `json:"contents"` SystemInstruction *geminiWireContent `json:"systemInstruction,omitempty"` GenerationConfig *geminiWireGenerationConfig `json:"generationConfig,omitempty"` SessionID string `json:"sessionId,omitempty"` } type geminiWireGenerationConfig struct { ResponseModalities []string `json:"responseModalities,omitempty"` ImageConfig json.RawMessage `json:"imageConfig,omitempty"` } type geminiWireContent struct { Role string `json:"role"` Parts []geminiWirePart `json:"parts"` } type geminiWirePart struct { Text string `json:"text,omitempty"` Thought bool `json:"thought,omitempty"` ThoughtSignature string `json:"thoughtSignature,omitempty"` InlineData *geminiWireInlineData `json:"inlineData,omitempty"` FunctionCall map[string]any `json:"functionCall,omitempty"` FunctionResponse map[string]any `json:"functionResponse,omitempty"` } type geminiWireInlineData struct { MimeType string `json:"mimeType"` Data string `json:"data"` } type geminiParsedRequest struct { Model string SessionID string SystemText string Turns []geminiConversationTurn ResponseModalities []string 
HasImageConfig bool HasUnsupported bool } type geminiConversationTurn struct { Role string Parts []geminiConversationPart } type geminiConversationPart struct { Kind string Text string MimeType string Data string } type lsRouteDecision struct { UseLS bool Reason string } type lsRequestTrace struct { StartedAt time.Time AccountID int64 Model string SessionIDHash string Replica int CascadeID string NewSession bool InflightAtAcquire int64 TurnCount int GetOrCreateDuration time.Duration StartCascadeDuration time.Duration BuildInputDuration time.Duration SendMessageDuration time.Duration FirstPollLatency time.Duration FirstTextLatency time.Duration PollCount int } // Do routes streamGenerateContent through LS, everything else through fallback. func (u *LSPoolUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) { u.extractAndStripInternalHeaders(req, strconv.FormatInt(accountID, 10)) if !isStreamGenerate(req.URL.Path) { return u.fallback.Do(req, proxyURL, accountID, accountConcurrency) } body, err := snapshotRequestBody(req) if err != nil { return nil, fmt.Errorf("snapshot request body: %w", err) } if len(bytes.TrimSpace(body)) == 0 { return u.fallback.Do(req, proxyURL, accountID, accountConcurrency) } resp, err := u.doViaLS(req, body, accountID, proxyURL) if err != nil { if shouldFallbackDirect(err) { u.logger.Warn("[LS-POOL] LS fell back to direct", "account", accountID, "err", err) req.Body = io.NopCloser(bytes.NewReader(body)) return u.fallback.Do(req, proxyURL, accountID, accountConcurrency) } return nil, err } return resp, nil } // DoWithTLS — LS handles its own TLS, so profile is ignored for LS requests. 
func (u *LSPoolUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*http.Response, error) { u.extractAndStripInternalHeaders(req, strconv.FormatInt(accountID, 10)) if !isStreamGenerate(req.URL.Path) { return u.fallback.DoWithTLS(req, proxyURL, accountID, accountConcurrency, profile) } body, err := snapshotRequestBody(req) if err != nil { return nil, fmt.Errorf("snapshot request body: %w", err) } if len(bytes.TrimSpace(body)) == 0 { return u.fallback.DoWithTLS(req, proxyURL, accountID, accountConcurrency, profile) } resp, err := u.doViaLS(req, body, accountID, proxyURL) if err != nil { if shouldFallbackDirect(err) { u.logger.Warn("[LS-POOL] LS fell back to direct+TLS", "account", accountID, "err", err) req.Body = io.NopCloser(bytes.NewReader(body)) return u.fallback.DoWithTLS(req, proxyURL, accountID, accountConcurrency, profile) } return nil, err } return resp, nil } func (u *LSPoolUpstream) doViaLS(req *http.Request, body []byte, accountID int64, proxyURL string) (*http.Response, error) { accountKey := strconv.FormatInt(accountID, 10) if CurrentLSStrategy() != LSStrategyJSParity { return u.forwardDirectWithKeepalive(req, body, accountKey, accountID, proxyURL) } parsed, err := parseGeminiRequest(body) if err != nil { return u.forwardDirect(req, body, proxyURL, accountID, "parse request failed") } decision := decideJSParityRoute(parsed, body) if !decision.UseLS { return u.forwardDirect(req, body, proxyURL, accountID, decision.Reason) } resp, err := u.forwardChatViaLS(req, body, parsed, accountKey, accountID, proxyURL) if err != nil { if shouldFallbackDirect(err) { return u.forwardDirect(req, body, proxyURL, accountID, err.Error()) } return nil, err } return resp, nil } func shouldFallbackDirect(err error) bool { return errors.Is(err, errLSRouteDirect) || errors.Is(err, errLSTranscriptDrift) || errors.Is(err, errLSModelMapDenied) } func (u *LSPoolUpstream) forwardDirectWithKeepalive(req 
*http.Request, body []byte, accountKey string, accountID int64, proxyURL string) (*http.Response, error) { // Start/reuse LS instance — keeps heartbeat alive, authenticates with // cloudcode-pa, and refreshes model mapping. The LS process itself is NOT // used as a proxy; we forward the original HTTP request directly to // cloudcode-pa, bypassing Cascade entirely. This avoids the IDE agent // system prompt that Cascade injects. _, err := u.pool.GetOrCreate(accountKey, "", proxyURL) if err != nil { return nil, fmt.Errorf("get LS instance: %w", err) } u.logger.Info("[LS-POOL] Forwarding via direct HTTP (LS keepalive active)", "account", accountID, "path", req.URL.Path) return u.forwardDirect(req, body, proxyURL, accountID, "strategy=direct") } func (u *LSPoolUpstream) forwardDirect(req *http.Request, body []byte, proxyURL string, accountID int64, reason string) (*http.Response, error) { u.logger.Info("[LS-POOL] Forwarding via direct HTTP", "account", accountID, "path", req.URL.Path, "reason", reason) req.Header.Del(userNamespaceHeader) req.Body = io.NopCloser(bytes.NewReader(body)) return u.fallback.Do(req, proxyURL, accountID, 1) } func (u *LSPoolUpstream) extractAndStripInternalHeaders(req *http.Request, accountKey string) { if auth := req.Header.Get("Authorization"); strings.HasPrefix(auth, "Bearer ") { accessToken := strings.TrimPrefix(auth, "Bearer ") refreshToken := req.Header.Get("X-Antigravity-Refresh-Token") var expiresAt time.Time if raw := req.Header.Get("X-Antigravity-Token-Expiry"); raw != "" { if parsed, err := time.Parse(time.RFC3339, raw); err == nil { expiresAt = parsed } } u.pool.SetAccountToken(accountKey, accessToken, refreshToken, expiresAt) } useAICredits, hasUseAICredits := parseBoolHeader(req.Header.Get(useAICreditsHeader)) availableCredits, hasAvailableCredits := parseOptionalInt32Header(req.Header.Get(availableCreditsHeader)) minimumCreditAmount, hasMinimumCreditAmount := parseOptionalInt32Header(req.Header.Get(minimumCreditAmountHeader)) if 
hasUseAICredits || hasAvailableCredits || hasMinimumCreditAmount { u.pool.SetAccountModelCredits(accountKey, useAICredits, availableCredits, minimumCreditAmount) } req.Header.Del("X-Antigravity-Refresh-Token") req.Header.Del("X-Antigravity-Token-Expiry") req.Header.Del(useAICreditsHeader) req.Header.Del(availableCreditsHeader) req.Header.Del(minimumCreditAmountHeader) } func parseBoolHeader(raw string) (bool, bool) { raw = strings.TrimSpace(raw) if raw == "" { return false, false } val, err := strconv.ParseBool(raw) if err != nil { return false, false } return val, true } func parseOptionalInt32Header(raw string) (*int32, bool) { raw = strings.TrimSpace(raw) if raw == "" { return nil, false } val, err := strconv.ParseInt(raw, 10, 32) if err != nil { return nil, false } parsed := int32(val) return &parsed, true } func shortTraceID(raw string) string { raw = strings.TrimSpace(raw) if raw == "" { return "none" } sum := sha256.Sum256([]byte(raw)) return fmt.Sprintf("%x", sum[:4]) } func durationMS(d time.Duration) int64 { if d <= 0 { return 0 } return d.Milliseconds() } func (u *LSPoolUpstream) logTraceSummary(level slog.Level, msg string, trace *lsRequestTrace, extra ...any) { if trace == nil { u.logger.Log(context.Background(), level, msg, extra...) return } args := []any{ "account", trace.AccountID, "model", trace.Model, "session", trace.SessionIDHash, "replica", trace.Replica, "cascade", shortTraceID(trace.CascadeID), "new_session", trace.NewSession, "turns", trace.TurnCount, "inflight", trace.InflightAtAcquire, "get_or_create_ms", durationMS(trace.GetOrCreateDuration), "start_cascade_ms", durationMS(trace.StartCascadeDuration), "build_input_ms", durationMS(trace.BuildInputDuration), "send_message_ms", durationMS(trace.SendMessageDuration), "first_poll_ms", durationMS(trace.FirstPollLatency), "first_token_ms", durationMS(trace.FirstTextLatency), "polls", trace.PollCount, "total_ms", durationMS(time.Since(trace.StartedAt)), } args = append(args, extra...) 
u.logger.Log(context.Background(), level, msg, args...) } func (u *LSPoolUpstream) forwardChatViaLS(req *http.Request, body []byte, parsed *geminiParsedRequest, accountKey string, accountID int64, proxyURL string) (*http.Response, error) { trace := &lsRequestTrace{ StartedAt: time.Now(), AccountID: accountID, Model: parsed.Model, SessionIDHash: shortTraceID(parsed.SessionID), TurnCount: len(parsed.Turns), } getOrCreateStartedAt := time.Now() sessionKey := buildSessionCacheKey(accountID, userNamespace(req), parsed.SessionID) inst, err := u.pool.GetOrCreate(accountKey, sessionKey, proxyURL) if err != nil { trace.GetOrCreateDuration = time.Since(getOrCreateStartedAt) u.logTraceSummary(slog.LevelWarn, "[LS-POOL] get instance failed", trace, "err", err) return nil, fmt.Errorf("get LS instance: %w", err) } trace.GetOrCreateDuration = time.Since(getOrCreateStartedAt) trace.Replica = inst.Replica if inst.HasModelMappingUnavailable() { reason := inst.ModelMappingUnavailableReason() u.logTraceSummary(slog.LevelInfo, "[LS-POOL] model mapping unavailable, routing direct", trace, "reason", reason) return nil, fmt.Errorf("%w: %s", errLSModelMapDenied, reason) } if !inst.HasModelMappingReady() { u.logTraceSummary(slog.LevelInfo, "[LS-POOL] model mapping pending, routing direct", trace) return nil, errLSModelMapPending } if !inst.AcquireConcurrency() { u.logTraceSummary(slog.LevelWarn, "[LS-POOL] instance busy", trace, "err", fmt.Sprintf("ls instance busy for account %d", accountID), "current_inflight", inst.ConcurrentCount(), "max_inflight", maxConcurrencyPerInstance) return nil, fmt.Errorf("ls instance busy for account %d", accountID) } trace.InflightAtAcquire = inst.ConcurrentCount() state := u.getSessionState(sessionKey) if state != nil && !systemTextCompatible(state.SystemText, parsed.SystemText) { inst.ReleaseConcurrency() u.logTraceSummary(slog.LevelInfo, "[LS-POOL] transcript drift, routing direct", trace) return nil, errLSTranscriptDrift } cascadeID := "" newSession := 
false sendTurn := geminiConversationTurn{} contextPrefix := "" switch { case state == nil: if len(parsed.Turns) == 0 { inst.ReleaseConcurrency() return nil, errLSRouteDirect } lastTurn := parsed.Turns[len(parsed.Turns)-1] if lastTurn.Role != "user" { inst.ReleaseConcurrency() u.logTraceSummary(slog.LevelInfo, "[LS-POOL] invalid first turn for LS, routing direct", trace) return nil, errLSRouteDirect } sendTurn = lastTurn contextPrefix = renderConversationContext(parsed.SystemText, parsed.Turns[:len(parsed.Turns)-1]) startCascadeStartedAt := time.Now() cascadeID, err = u.startCascade(inst) trace.StartCascadeDuration = time.Since(startCascadeStartedAt) if err != nil { inst.ReleaseConcurrency() u.logTraceSummary(slog.LevelWarn, "[LS-POOL] start cascade failed", trace, "err", err) return nil, err } newSession = true case !conversationPrefixEqual(parsed.Turns, state.History): inst.ReleaseConcurrency() u.logTraceSummary(slog.LevelInfo, "[LS-POOL] transcript prefix mismatch, routing direct", trace) return nil, errLSTranscriptDrift default: delta := parsed.Turns[len(state.History):] if len(delta) != 1 || delta[0].Role != "user" { inst.ReleaseConcurrency() u.logTraceSummary(slog.LevelInfo, "[LS-POOL] unsupported transcript delta, routing direct", trace) return nil, errLSRouteDirect } sendTurn = delta[0] cascadeID = state.CascadeID } trace.NewSession = newSession trace.CascadeID = cascadeID buildInputStartedAt := time.Now() items, media, err := buildLSInputFromTurn(sendTurn, contextPrefix) trace.BuildInputDuration = time.Since(buildInputStartedAt) if err != nil { inst.ReleaseConcurrency() u.logTraceSummary(slog.LevelWarn, "[LS-POOL] build input failed", trace, "err", err) return nil, fmt.Errorf("build ls input: %w", err) } if len(items) == 0 && len(media) == 0 { inst.ReleaseConcurrency() u.logTraceSummary(slog.LevelInfo, "[LS-POOL] empty LS input, routing direct", trace) return nil, errLSRouteDirect } sendReq := map[string]any{ "metadata": buildLSRequestMetadata(), 
"cascadeId": cascadeID, "items": items, "blocking": false, } if len(media) > 0 { sendReq["media"] = media } if cfg := buildCascadeConfig(parsed.Model); cfg != nil { sendReq["cascadeConfig"] = cfg } sendStartedAt := time.Now() sendCtx, sendCancel := context.WithTimeout(req.Context(), lsSendMessageTimeout) defer sendCancel() if _, err := inst.CallUnaryJSON(sendCtx, LSService, "SendUserCascadeMessage", sendReq); err != nil { trace.SendMessageDuration = time.Since(sendStartedAt) if newSession { u.cancelCascade(inst, cascadeID) } inst.ReleaseConcurrency() u.logTraceSummary(slog.LevelWarn, "[LS-POOL] send user message failed", trace, "err", err) return nil, fmt.Errorf("send user cascade message: %w", err) } trace.SendMessageDuration = time.Since(sendStartedAt) pr, pw := io.Pipe() resp := &http.Response{ StatusCode: http.StatusOK, Header: http.Header{ "Content-Type": []string{"text/event-stream"}, "Cache-Control": []string{"no-cache"}, "X-Accel-Buffering": []string{"no"}, }, Body: pr, Request: req, } go func() { defer inst.ReleaseConcurrency() ctx, cancel := context.WithCancel(req.Context()) defer cancel() u.streamCascadeResponse(ctx, inst, cascadeID, pw, trace, func(finalText string) { u.putSessionState(sessionKey, &cascadeSessionState{ CascadeID: cascadeID, SystemText: parsed.SystemText, History: appendModelTurn(cloneConversationTurns(parsed.Turns), finalText), UpdatedAt: time.Now(), }) }) }() return resp, nil } func (u *LSPoolUpstream) startCascade(inst *Instance) (string, error) { resp, err := inst.CallUnaryJSON(context.Background(), LSService, "StartCascade", map[string]any{ "metadata": buildLSRequestMetadata(), }) if err != nil { return "", fmt.Errorf("start cascade: %w", err) } var decoded struct { CascadeID string `json:"cascadeId"` } if err := json.Unmarshal(resp, &decoded); err != nil { return "", fmt.Errorf("decode start cascade: %w", err) } if decoded.CascadeID == "" { return "", errors.New("start cascade returned empty cascadeId") } return decoded.CascadeID, 
nil } // cancelCascade tells the LS to stop processing a cascade invocation. // Uses a short timeout — best-effort, don't block shutdown. func (u *LSPoolUpstream) cancelCascade(inst *Instance, cascadeID string) { ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() _, err := inst.CallUnaryJSON(ctx, LSService, "CancelCascadeInvocation", map[string]any{ "cascadeId": cascadeID, }) if err != nil { // Try force stop as fallback _, _ = inst.CallUnaryJSON(ctx, LSService, "ForceStopCascadeTree", map[string]any{ "cascadeId": cascadeID, }) } } // streamCascadeResponse polls GetCascadeTrajectory with adaptive interval. // Fast (50ms) when model is generating, slow (150ms) when waiting. // We also issue an immediate first poll so the first token is not delayed by // the initial ticker interval. func (u *LSPoolUpstream) streamCascadeResponse(ctx context.Context, inst *Instance, cascadeID string, w *io.PipeWriter, trace *lsRequestTrace, onDone func(string)) { const ( fastInterval = 50 * time.Millisecond slowInterval = 150 * time.Millisecond maxDuration = 5 * time.Minute maxIdleTimeout = 30 * time.Second ) ticker := time.NewTicker(slowInterval) defer ticker.Stop() timeout := time.After(maxDuration) lastText := "" generating := false lastProgressAt := time.Time{} pollOnce := func() bool { if trace != nil { trace.PollCount++ } trajResp, err := inst.CallUnaryJSON(ctx, LSService, "GetCascadeTrajectory", map[string]any{ "cascadeId": cascadeID, }) if err != nil { if ctx.Err() != nil { u.logTraceSummary(slog.LevelInfo, "[LS-POOL] request context canceled", trace) _ = w.Close() return true } return false } if trace != nil && trace.FirstPollLatency == 0 { trace.FirstPollLatency = time.Since(trace.StartedAt) } state := extractPlannerResponseState(trajResp) text, isGenerating, status := state.Text, state.Generating, state.Status if state.ErrorMessage != "" { u.logTraceSummary(slog.LevelWarn, "[LS-POOL] Cascade terminated with error", trace, "error", 
state.ErrorMessage) if isQuotaExhaustedError(state.ErrorMessage) { _ = w.CloseWithError(fmt.Errorf("%w: %s", errLSQuotaExhausted, state.ErrorMessage)) } else { _ = w.CloseWithError(errors.New(state.ErrorMessage)) } return true } // Adaptive interval: fast when generating, slow when idle. if isGenerating && !generating { ticker.Reset(fastInterval) generating = true } else if !isGenerating && generating { ticker.Reset(slowInterval) generating = false } // Emit new text as SSE. if text != lastText && len(text) > len(lastText) { newPart := text[len(lastText):] sseEvent := buildGeminiSSEChunk(newPart) if _, err := w.Write([]byte(sseEvent)); err != nil { u.logTraceSummary(slog.LevelWarn, "[LS-POOL] write SSE failed", trace, "err", err) _ = w.CloseWithError(err) return true } lastText = text lastProgressAt = time.Now() if trace != nil && trace.FirstTextLatency == 0 { trace.FirstTextLatency = time.Since(trace.StartedAt) } } // Check if done. if status == "CASCADE_RUN_STATUS_IDLE" && text != "" && !isGenerating { usage := extractUsageFromTrajectory(trajResp) if usage != nil { finalEvent := buildGeminiSSEFinalChunk(usage) if _, err := w.Write([]byte(finalEvent)); err != nil { u.logTraceSummary(slog.LevelWarn, "[LS-POOL] write final SSE failed", trace, "err", err) _ = w.CloseWithError(err) return true } } if onDone != nil { onDone(lastText) } u.logTraceSummary(slog.LevelInfo, "[LS-POOL] request completed", trace) _ = w.Close() return true } if !lastProgressAt.IsZero() && time.Since(lastProgressAt) > maxIdleTimeout { u.logTraceSummary(slog.LevelWarn, "[LS-POOL] No progress, stopping", trace) _ = w.Close() return true } return false } if pollOnce() { return } for { select { case <-ctx.Done(): u.logTraceSummary(slog.LevelInfo, "[LS-POOL] request context canceled", trace) _ = w.Close() return case <-timeout: u.logTraceSummary(slog.LevelWarn, "[LS-POOL] Cascade timeout", trace) _ = w.Close() return case <-ticker.C: if pollOnce() { return } } } } // 
============================================================ // SSE builders — match Gemini v1internal:streamGenerateContent?alt=sse format // ============================================================ func buildGeminiSSEChunk(text string) string { // cloudcode-pa v1internal 格式: {"response": {"candidates": [...]}} chunk := map[string]any{ "response": map[string]any{ "candidates": []map[string]any{ { "content": map[string]any{ "parts": []map[string]string{{"text": text}}, "role": "model", }, }, }, }, } data, _ := json.Marshal(chunk) return "data: " + string(data) + "\n\n" } func buildGeminiSSEFinalChunk(usage map[string]any) string { chunk := map[string]any{ "response": map[string]any{ "candidates": []map[string]any{ { "content": map[string]any{ "parts": []map[string]string{{"text": ""}}, "role": "model", }, "finishReason": "STOP", }, }, "usageMetadata": usage, }, } data, _ := json.Marshal(chunk) return "data: " + string(data) + "\n\n" } // ============================================================ // Trajectory parsing // ============================================================ type cascadePlannerState struct { Text string Generating bool Status string ErrorMessage string } func extractPlannerResponseState(trajResp []byte) cascadePlannerState { var raw map[string]any if err := json.Unmarshal(trajResp, &raw); err != nil { return cascadePlannerState{} } state := cascadePlannerState{} state.Status, _ = raw["status"].(string) state.ErrorMessage = findCascadeErrorMessage(raw) traj, ok := raw["trajectory"].(map[string]any) if !ok { return state } steps, ok := traj["steps"].([]any) if !ok { return state } for _, s := range steps { sm, ok := s.(map[string]any) if !ok { continue } if sm["type"] != "CORTEX_STEP_TYPE_PLANNER_RESPONSE" { continue } if sm["status"] == "CORTEX_STEP_STATUS_GENERATING" { state.Generating = true } if pr, ok := sm["plannerResponse"].(map[string]any); ok { if r, ok := pr["response"].(string); ok { state.Text = r } } } return state } func 
extractPlannerResponseText(trajResp []byte) (text string, generating bool, status string) { state := extractPlannerResponseState(trajResp) return state.Text, state.Generating, state.Status } func findCascadeErrorMessage(value any) string { switch v := value.(type) { case map[string]any: if msg := summarizeCascadeErrorMap(v); msg != "" { return msg } for _, child := range v { if msg := findCascadeErrorMessage(child); msg != "" { return msg } } case []any: for _, child := range v { if msg := findCascadeErrorMessage(child); msg != "" { return msg } } } return "" } func summarizeCascadeErrorMap(m map[string]any) string { if full := cascadeStringField(m, "fullError"); full != "" { return full } if user := cascadeStringField(m, "userErrorMessage"); user != "" { return user } _, hasErrorCode := m["errorCode"] short := cascadeStringField(m, "shortError") details := cascadeStringField(m, "details") message := cascadeStringField(m, "message") if hasErrorCode { parts := make([]string, 0, 3) if short != "" { parts = append(parts, short) } if message != "" && message != short { parts = append(parts, message) } if details != "" && details != short && details != message { parts = append(parts, details) } if len(parts) > 0 { return strings.Join(parts, ": ") } return fmt.Sprintf("cascade error: %v", m["errorCode"]) } if reason := cascadeStringField(m, "terminationReason"); strings.Contains(strings.ToUpper(reason), "ERROR") { if message != "" { return message } if short != "" { return short } } return "" } func cascadeStringField(m map[string]any, key string) string { raw, ok := m[key] if !ok { return "" } str, ok := raw.(string) if !ok { return "" } return strings.TrimSpace(str) } func extractUsageFromTrajectory(trajResp []byte) map[string]any { var raw map[string]any if err := json.Unmarshal(trajResp, &raw); err != nil { return nil } traj, ok := raw["trajectory"].(map[string]any) if !ok { return nil } steps, ok := traj["steps"].([]any) if !ok { return nil } for _, s := range steps 
{ sm, ok := s.(map[string]any) if !ok { continue } if sm["type"] != "CORTEX_STEP_TYPE_PLANNER_RESPONSE" { continue } meta, ok := sm["metadata"].(map[string]any) if !ok { continue } mu, ok := meta["modelUsage"].(map[string]any) if !ok { continue } input, _ := mu["inputTokens"].(string) output, _ := mu["outputTokens"].(string) inputN, _ := strconv.Atoi(input) outputN, _ := strconv.Atoi(output) if inputN > 0 || outputN > 0 { return map[string]any{ "promptTokenCount": inputN, "candidatesTokenCount": outputN, "totalTokenCount": inputN + outputN, } } } return nil } // ============================================================ // Request parsing — dynamic model, no hardcoding // ============================================================ func parseGeminiRequest(body []byte) (*geminiParsedRequest, error) { var envelope geminiEnvelope if err := json.Unmarshal(body, &envelope); err != nil { return nil, err } reqBody := body if len(envelope.Request) > 0 { reqBody = envelope.Request } var payload geminiRequestPayload if err := json.Unmarshal(reqBody, &payload); err != nil { return nil, err } parsed := &geminiParsedRequest{ Model: envelope.Model, SessionID: payload.SessionID, ResponseModalities: append([]string(nil), payload.GenerationConfig.GetResponseModalities()...), HasImageConfig: payload.GenerationConfig != nil && len(bytes.TrimSpace(payload.GenerationConfig.ImageConfig)) > 0 && string(bytes.TrimSpace(payload.GenerationConfig.ImageConfig)) != "null", } if parsed.Model == "" { var top map[string]json.RawMessage if err := json.Unmarshal(body, &top); err == nil { _ = json.Unmarshal(top["model"], &parsed.Model) } } if payload.SystemInstruction != nil { parsed.SystemText = collectTextParts(payload.SystemInstruction.Parts) } for _, content := range payload.Contents { turn := geminiConversationTurn{Role: normalizeTurnRole(content.Role)} for _, part := range content.Parts { switch { case part.Thought || part.ThoughtSignature != "": parsed.HasUnsupported = true case 
len(part.FunctionCall) > 0 || len(part.FunctionResponse) > 0:
				parsed.HasUnsupported = true
			case part.InlineData != nil:
				turn.Parts = append(turn.Parts, geminiConversationPart{
					Kind:     "media",
					MimeType: part.InlineData.MimeType,
					Data:     part.InlineData.Data,
				})
			case part.Text != "":
				turn.Parts = append(turn.Parts, geminiConversationPart{
					Kind: "text",
					Text: part.Text,
				})
			}
		}
		if len(turn.Parts) > 0 {
			parsed.Turns = append(parsed.Turns, turn)
		}
	}
	return parsed, nil
}

// collectTextParts joins the non-empty Text fields of parts with "\n".
func collectTextParts(parts []geminiWirePart) string {
	var texts []string
	for _, part := range parts {
		if part.Text != "" {
			texts = append(texts, part.Text)
		}
	}
	return strings.Join(texts, "\n")
}

// GetResponseModalities returns the configured response modalities. It is
// safe to call on a nil config, in which case it returns nil.
func (g *geminiWireGenerationConfig) GetResponseModalities() []string {
	if g == nil {
		return nil
	}
	return g.ResponseModalities
}

// normalizeTurnRole returns "model" when role is "model" (case-insensitive,
// surrounding whitespace ignored) and "user" for every other value, including "".
func normalizeTurnRole(role string) string {
	if strings.EqualFold(strings.TrimSpace(role), "model") {
		return "model"
	}
	return "user"
}

// decideJSParityRoute decides whether a parsed request may be served through
// the LS cascade. Checks run in order and the first failing one supplies the
// Reason; UseLS is true only when every check passes.
func decideJSParityRoute(parsed *geminiParsedRequest, body []byte) lsRouteDecision {
	if parsed == nil {
		return lsRouteDecision{Reason: "nil parsed request"}
	}
	if requestHasTools(body) {
		return lsRouteDecision{Reason: "tools are not supported through cascade"}
	}
	if parsed.SessionID == "" {
		return lsRouteDecision{Reason: "missing sessionId"}
	}
	if parsed.HasUnsupported {
		return lsRouteDecision{Reason: "request contains unsupported Gemini parts"}
	}
	if isImageGenerationModelName(parsed.Model) {
		return lsRouteDecision{Reason: "image generation model"}
	}
	if parsed.HasImageConfig {
		return lsRouteDecision{Reason: "request has imageConfig"}
	}
	for _, modality := range parsed.ResponseModalities {
		if strings.EqualFold(strings.TrimSpace(modality), "IMAGE") {
			return lsRouteDecision{Reason: "responseModalities contains IMAGE"}
		}
	}
	if len(parsed.Turns) == 0 {
		return lsRouteDecision{Reason: "empty conversation"}
	}
	return lsRouteDecision{UseLS: true, Reason: "js-parity cascade chat"}
}

// extractPromptAndModel returns the flattened prompt text and the top-level
// "model" string of a request body. The prompt is taken from the wrapped
// "request" object when present, otherwise from the body itself. Invalid JSON
// yields ("", "").
func extractPromptAndModel(body []byte) (string, string) {
	var outer map[string]json.RawMessage
	if err := json.Unmarshal(body, &outer); err != nil {
		return "", ""
	}
	var model string
	if m, ok := outer["model"]; ok {
		// Best effort: a non-string "model" value simply leaves model empty.
		_ = json.Unmarshal(m, &model)
	}
	if reqRaw, ok := outer["request"]; ok {
		return extractPromptFromGeminiRequest(reqRaw), model
	}
	return extractPromptFromGeminiRequest(body), model
}

// extractPromptFromGeminiRequest flattens a Gemini request into plain text.
// System-instruction parts are prefixed "[System]\n". Content parts are
// prefixed "[Assistant]\n" for role "model" and "[User]\n" otherwise (an empty
// role counts as user). Sections are joined with a blank line. A lone user
// message with no systemInstruction is returned without its prefix. Invalid
// JSON or a request without text yields "".
func extractPromptFromGeminiRequest(data []byte) string {
	var req struct {
		Contents []struct {
			Parts []struct {
				Text string `json:"text"`
			} `json:"parts"`
			Role string `json:"role"`
		} `json:"contents"`
		SystemInstruction *struct {
			Parts []struct {
				Text string `json:"text"`
			} `json:"parts"`
		} `json:"systemInstruction"`
	}
	if err := json.Unmarshal(data, &req); err != nil {
		return ""
	}

	var parts []string
	// Include system instruction if present.
	if req.SystemInstruction != nil {
		for _, p := range req.SystemInstruction.Parts {
			if p.Text != "" {
				parts = append(parts, "[System]\n"+p.Text)
			}
		}
	}
	// Include full conversation history.
	for _, c := range req.Contents {
		role := c.Role
		if role == "" {
			role = "user"
		}
		for _, p := range c.Parts {
			if p.Text == "" {
				continue
			}
			if role == "model" {
				parts = append(parts, "[Assistant]\n"+p.Text)
			} else {
				parts = append(parts, "[User]\n"+p.Text)
			}
		}
	}

	if len(parts) == 0 {
		return ""
	}
	// A single part with no system instruction is returned as raw text.
	if len(parts) == 1 && req.SystemInstruction == nil {
		text := parts[0]
		// Strip the [User]\n prefix for the simple single-message case.
		if strings.HasPrefix(text, "[User]\n") {
			return strings.TrimPrefix(text, "[User]\n")
		}
		return text
	}
	return strings.Join(parts, "\n\n")
}

// buildLSInputFromTurn converts one conversation turn into LS input items.
// It returns the text items, starting with contextPrefix when that prefix is
// non-blank, and the media items, whose "inlineData" holds the
// base64-decoded bytes. A part whose Data is not valid standard base64
// produces an error.
func buildLSInputFromTurn(turn geminiConversationTurn, contextPrefix string) ([]map[string]any, []map[string]any, error) {
	items := make([]map[string]any, 0, len(turn.Parts)+1)
	media := make([]map[string]any, 0)
	if strings.TrimSpace(contextPrefix) != "" {
		items = append(items, map[string]any{"text": contextPrefix})
	}
	for _, part := range turn.Parts {
		switch part.Kind {
		case "text":
			if part.Text != "" {
				items = append(items, map[string]any{"text": part.Text})
			}
		case "media":
			decoded, err := base64.StdEncoding.DecodeString(part.Data)
			if err != nil {
				return nil, nil, fmt.Errorf("decode inlineData: %w", err)
			}
			media = append(media, map[string]any{
				"mimeType":   part.MimeType,
				"inlineData": decoded,
			})
		}
	}
	return items, media, nil
}

// renderConversationContext renders a plain-text transcript. It starts with an
// optional "[System]" section (trimmed), followed by one "[User]" or
// "[Assistant]" section per turn. Media parts become placeholders such as
// "[image attachment: image/png]". Turns with nothing to render are skipped,
// and sections are separated by a blank line.
func renderConversationContext(systemText string, turns []geminiConversationTurn) string {
	var parts []string
	if strings.TrimSpace(systemText) != "" {
		parts = append(parts, "[System]\n"+strings.TrimSpace(systemText))
	}
	for _, turn := range turns {
		var rendered []string
		for _, part := range turn.Parts {
			switch part.Kind {
			case "text":
				if strings.TrimSpace(part.Text) != "" {
					rendered = append(rendered, part.Text)
				}
			case "media":
				label := "attachment"
				switch {
				case strings.HasPrefix(part.MimeType, "image/"):
					label = "image attachment"
				case strings.HasPrefix(part.MimeType, "video/"):
					label = "video attachment"
				case strings.HasPrefix(part.MimeType, "audio/"):
					label = "audio attachment"
				}
				rendered = append(rendered, fmt.Sprintf("[%s: %s]", label, part.MimeType))
			}
		}
		if len(rendered) == 0 {
			continue
		}
		roleLabel := "User"
		if turn.Role == "model" {
			roleLabel = "Assistant"
		}
		parts = append(parts, fmt.Sprintf("[%s]\n%s", roleLabel, strings.Join(rendered, "\n")))
	}
	return strings.Join(parts, "\n\n")
}

// buildCascadeConfig returns the cascade config that pins the planner to the
// enum resolved for model. It returns nil when the normalized model name is
// empty.
func buildCascadeConfig(model string) map[string]any {
	normalizedModel := normalizeRequestedModelName(model)
	if normalizedModel == "" {
		return nil
	}
	modelEnum := resolveModelEnum(normalizedModel)
	return map[string]any{
		"plannerConfig": map[string]any{
			"requestedModel": map[string]any{
				"model": modelEnum,
			},
		},
	}
}

// buildLSRequestMetadata returns the fixed IDE identification metadata.
func buildLSRequestMetadata() map[string]any {
	return map[string]any{
		"ideName":    "antigravity",
		"ideVersion": "1.107.0",
	}
}

// appendModelTurn appends a "model" turn holding modelText as a single text
// part. Blank text leaves turns unchanged. Like append, the result may share
// turns' backing array.
func appendModelTurn(turns []geminiConversationTurn, modelText string) []geminiConversationTurn {
	if strings.TrimSpace(modelText) == "" {
		return turns
	}
	return append(turns, geminiConversationTurn{
		Role: "model",
		Parts: []geminiConversationPart{{
			Kind: "text",
			Text: modelText,
		}},
	})
}

// cloneConversationTurns copies src so that neither the turn slice nor any
// turn's Parts slice shares a backing array with the input. A nil input
// yields an empty, non-nil slice.
func cloneConversationTurns(src []geminiConversationTurn) []geminiConversationTurn {
	out := make([]geminiConversationTurn, 0, len(src))
	for _, turn := range src {
		out = append(out, geminiConversationTurn{
			Role:  turn.Role,
			Parts: append([]geminiConversationPart(nil), turn.Parts...),
		})
	}
	return out
}

// conversationPrefixEqual reports whether prefix equals the first len(prefix)
// turns of full. Roles must match, and each part must be equal field by field
// (==).
func conversationPrefixEqual(full, prefix []geminiConversationTurn) bool {
	if len(prefix) > len(full) {
		return false
	}
	for i := range prefix {
		if prefix[i].Role != full[i].Role {
			return false
		}
		if len(prefix[i].Parts) != len(full[i].Parts) {
			return false
		}
		for j := range prefix[i].Parts {
			if prefix[i].Parts[j] != full[i].Parts[j] {
				return false
			}
		}
	}
	return true
}

// ResolveModelEnumPublic is the exported version of resolveModelEnum for testing.
func ResolveModelEnumPublic(model string) int {
	return resolveModelEnum(model)
}

// knownModelPlaceholders maps requested model names to MODEL_PLACEHOLDER_Mn
// enum values (1000+n). These were verified on Mac with LS v1.107.0 and are
// only consulted when the dynamic mapping has no direct match.
var knownModelPlaceholders = map[string]int{
	"gemini-3-flash":             1047,
	"gemini-3.1-pro-high":        1037,
	"gemini-3.1-pro-low":         1036,
	"claude-sonnet-4-6-thinking": 1035,
	"claude-opus-4-6-thinking":   1026,
	"gpt-oss-120b-medium":        342,
}

// modelFamilies lists the substrings used for same-family fallback matching.
var modelFamilies = []string{"claude", "gemini", "gpt"}

// sameModelFamily reports whether model and candidate both contain one of the
// modelFamilies substrings (e.g. both contain "claude").
func sameModelFamily(model, candidate string) bool {
	for _, family := range modelFamilies {
		if strings.Contains(model, family) && strings.Contains(candidate, family) {
			return true
		}
	}
	return false
}

// deterministicModelMatch returns the value stored under the lexicographically
// smallest key of m that satisfies match. Picking a fixed winner keeps model
// resolution independent of Go's randomized map iteration order.
func deterministicModelMatch(m map[string]int, match func(key string) bool) (int, bool) {
	bestKey, bestVal, found := "", 0, false
	for k, v := range m {
		if !match(k) {
			continue
		}
		if !found || k < bestKey {
			bestKey, bestVal, found = k, v, true
		}
	}
	return bestVal, found
}

// resolveModelEnum maps a Gemini/Claude model name to its proto enum number.
// Priority: dynamic mapping (from LS) > static fallback.
// The LS uses MODEL_PLACEHOLDER_Mn enum values (1000+n) that are dynamically
// assigned by the server — only these are guaranteed to work.
//
// Lookup order:
//  1. dynamic map: exact key, then normalized-label match, then prefix match;
//  2. knownModelPlaceholders: exact key, then substring match;
//  3. same-family fallback (claude/gemini/gpt), dynamic entries before static;
//  4. any dynamic entry, else 312 (MODEL_GOOGLE_GEMINI_2_5_FLASH).
//
// When several entries qualify at one step, the smallest key wins, so the
// result is deterministic. An empty model name skips the substring/prefix steps.
func resolveModelEnum(model string) int {
	model = normalizeRequestedModelName(model)

	dynamicModelMapMu.RLock()
	defer dynamicModelMapMu.RUnlock()

	// 1. Dynamic mapping (populated by RefreshModelMapping from the LS).
	if v, ok := dynamicModelMap[model]; ok {
		return v
	}
	// "" is a prefix/substring of every name, so fuzzy steps would match
	// arbitrary entries for an empty model.
	if model != "" {
		if v, ok := deterministicModelMatch(dynamicModelMap, func(label string) bool {
			return labelMatchesModel(label, model)
		}); ok {
			return v
		}
		if v, ok := deterministicModelMatch(dynamicModelMap, func(label string) bool {
			// Labels normalizing to "" would prefix-match everything; skip them.
			normalized := normalizeLabel(label)
			return normalized != "" &&
				(strings.HasPrefix(model, normalized) || strings.HasPrefix(normalized, model))
		}); ok {
			return v
		}
	}

	// 2. Known working placeholders.
	if v, ok := knownModelPlaceholders[model]; ok {
		return v
	}
	if model != "" {
		if v, ok := deterministicModelMatch(knownModelPlaceholders, func(name string) bool {
			return strings.Contains(model, name) || strings.Contains(name, model)
		}); ok {
			return v
		}
	}

	// 3. Same-family fallback. Server-assigned entries are tried first because
	// only those are guaranteed to exist on this LS.
	if v, ok := deterministicModelMatch(dynamicModelMap, func(label string) bool {
		return sameModelFamily(model, normalizeLabel(label))
	}); ok {
		return v
	}
	if v, ok := deterministicModelMatch(knownModelPlaceholders, func(name string) bool {
		return sameModelFamily(model, name)
	}); ok {
		return v
	}

	// 4. Last resort: any model the LS reported.
	if v, ok := deterministicModelMatch(dynamicModelMap, func(string) bool { return true }); ok {
		return v
	}
	// No dynamic mapping at all (LS not started yet?): MODEL_GOOGLE_GEMINI_2_5_FLASH.
	return 312
}

// labelMatchesModel does fuzzy matching between LS display label and sub2api model name.
// e.g.
"Gemini 3 Flash" matches "gemini-3-flash", "Claude Sonnet 4.6 (Thinking)" matches "claude-sonnet-4-6-thinking" func labelMatchesModel(label, model string) bool { normalize := func(s string) string { s = strings.ToLower(s) s = strings.ReplaceAll(s, " ", "-") s = strings.ReplaceAll(s, ".", "-") s = strings.ReplaceAll(s, "(", "") s = strings.ReplaceAll(s, ")", "") s = strings.ReplaceAll(s, "--", "-") return strings.TrimRight(s, "-") } return normalize(label) == normalize(model) } // Dynamic model mapping — refreshed from LS at startup var ( dynamicModelMapMu sync.RWMutex dynamicModelMap = map[string]int{} // label -> enum value ) // HasDynamicModelMappingPublic is exported for testing. func HasDynamicModelMappingPublic() bool { return hasDynamicModelMapping() } // hasDynamicModelMapping returns true if at least one model has been loaded from the LS. func hasDynamicModelMapping() bool { dynamicModelMapMu.RLock() defer dynamicModelMapMu.RUnlock() return len(dynamicModelMap) > 0 } // RefreshModelMapping queries the LS for available models and builds the mapping. // Called automatically when an LS instance starts. 
func RefreshModelMapping(inst *Instance) bool { if inst == nil { return false } startedAt := time.Now() ctx, cancel := context.WithTimeout(context.Background(), lsModelConfigTimeout) defer cancel() resp, err := inst.CallUnaryJSON(ctx, LSService, "GetCascadeModelConfigData", map[string]any{}) if err != nil { inst.SetModelMappingReady(false) if isPermanentModelMappingError(err) { reason := modelMappingDeniedReason(err) inst.SetModelMappingUnavailable(reason) slog.Warn("[LS-POOL] Model mapping unavailable", "account", inst.AccountID, "replica", inst.Replica, "address", inst.Address, "elapsed", time.Since(startedAt).Truncate(time.Millisecond), "reason", reason) return false } inst.ClearModelMappingUnavailable() slog.Warn("[LS-POOL] Failed to get model config", "account", inst.AccountID, "replica", inst.Replica, "address", inst.Address, "elapsed", time.Since(startedAt).Truncate(time.Millisecond), "err", err) return false } var data struct { ClientModelConfigs []struct { Label string `json:"label"` ModelOrAlias map[string]any `json:"modelOrAlias"` } `json:"clientModelConfigs"` } if err := json.Unmarshal(resp, &data); err != nil { inst.SetModelMappingReady(false) inst.ClearModelMappingUnavailable() return false } newMap := make(map[string]int) for _, cfg := range data.ClientModelConfigs { label := cfg.Label if label == "" { continue } // modelOrAlias is {"model": "MODEL_PLACEHOLDER_M37"} in JSON modelStr, _ := cfg.ModelOrAlias["model"].(string) if modelStr == "" { continue } // Parse "MODEL_PLACEHOLDER_M37" → 1037 enumVal := parseModelEnumString(modelStr) if enumVal > 0 { // Store both the display label and a normalized form newMap[label] = enumVal // Also store kebab-case version: "Gemini 3 Flash" → "gemini-3-flash" normalized := normalizeLabel(label) if normalized != "" { newMap[normalized] = enumVal } } } if len(newMap) > 0 { dynamicModelMapMu.Lock() dynamicModelMap = newMap dynamicModelMapMu.Unlock() inst.SetModelMappingReady(true) inst.ClearModelMappingUnavailable() 
slog.Info("[LS-POOL] Model mapping refreshed", "account", inst.AccountID, "replica", inst.Replica, "address", inst.Address, "count", len(newMap)/2, "elapsed", time.Since(startedAt).Truncate(time.Millisecond)) return true } inst.SetModelMappingReady(false) inst.ClearModelMappingUnavailable() return false } func parseModelEnumString(s string) int { // Named enums named := map[string]int{ "MODEL_CLAUDE_4_SONNET": 281, "MODEL_CLAUDE_4_SONNET_THINKING": 282, "MODEL_CLAUDE_4_OPUS": 290, "MODEL_CLAUDE_4_OPUS_THINKING": 291, "MODEL_CLAUDE_4_5_SONNET": 333, "MODEL_CLAUDE_4_5_SONNET_THINKING": 334, "MODEL_CLAUDE_4_5_HAIKU": 340, "MODEL_CLAUDE_4_5_HAIKU_THINKING": 341, "MODEL_OPENAI_GPT_OSS_120B_MEDIUM": 342, "MODEL_GOOGLE_GEMINI_2_5_FLASH": 312, "MODEL_GOOGLE_GEMINI_2_5_FLASH_THINKING": 313, "MODEL_GOOGLE_GEMINI_2_5_FLASH_LITE": 330, "MODEL_GOOGLE_GEMINI_2_5_PRO": 246, } if v, ok := named[s]; ok { return v } // "MODEL_PLACEHOLDER_M37" → 1037 if strings.HasPrefix(s, "MODEL_PLACEHOLDER_M") { numStr := strings.TrimPrefix(s, "MODEL_PLACEHOLDER_M") n, err := strconv.Atoi(numStr) if err == nil { return 1000 + n } } return 0 } func normalizeLabel(label string) string { s := strings.ToLower(label) s = strings.ReplaceAll(s, " ", "-") s = strings.ReplaceAll(s, ".", "-") s = strings.ReplaceAll(s, "(", "") s = strings.ReplaceAll(s, ")", "") s = strings.ReplaceAll(s, "--", "-") return strings.TrimRight(s, "-") } func normalizeRequestedModelName(model string) string { normalized := strings.ToLower(strings.TrimSpace(model)) normalized = strings.TrimPrefix(normalized, "models/") return normalized } func isGeminiPlannerModel(model string) bool { return strings.Contains(normalizeRequestedModelName(model), "gemini") } func systemTextCompatible(stored, current string) bool { stored = strings.TrimSpace(stored) current = strings.TrimSpace(current) return current == "" || current == stored } // ============================================================ // Helpers // 
============================================================ func buildSessionCacheKey(accountID int64, namespace, sessionID string) string { return fmt.Sprintf("%d:%s:%s", accountID, namespace, sessionID) } func userNamespace(req *http.Request) string { if req == nil { return "anon" } for _, value := range []string{ req.Header.Get(userNamespaceHeader), req.Header.Get("X-Api-Key"), req.Header.Get("X-Goog-Api-Key"), req.Header.Get("Authorization"), } { if strings.TrimSpace(value) != "" { sum := sha256.Sum256([]byte(value)) return fmt.Sprintf("%x", sum[:8]) } } return "anon" } func (u *LSPoolUpstream) getSessionState(key string) *cascadeSessionState { u.sessionMu.Lock() defer u.sessionMu.Unlock() u.pruneExpiredSessionsLocked() state := u.sessions[key] if state == nil { return nil } cloned := &cascadeSessionState{ CascadeID: state.CascadeID, SystemText: state.SystemText, History: cloneConversationTurns(state.History), UpdatedAt: state.UpdatedAt, } return cloned } func (u *LSPoolUpstream) putSessionState(key string, state *cascadeSessionState) { if state == nil { return } u.sessionMu.Lock() defer u.sessionMu.Unlock() u.pruneExpiredSessionsLocked() u.sessions[key] = &cascadeSessionState{ CascadeID: state.CascadeID, SystemText: state.SystemText, History: cloneConversationTurns(state.History), UpdatedAt: state.UpdatedAt, } } func (u *LSPoolUpstream) pruneExpiredSessionsLocked() { now := time.Now() for key, state := range u.sessions { if state == nil || now.Sub(state.UpdatedAt) > sessionStateTTL { delete(u.sessions, key) } } } func isStreamGenerate(path string) bool { return strings.Contains(path, "streamGenerateContent") } // isQuotaExhaustedError detects 429 QUOTA_EXHAUSTED errors from LS cascade trajectory. // When detected, the caller should fall back to direct HTTP so the gateway can // inject enabledCreditTypes for AI Credits retry. 
func isQuotaExhaustedError(msg string) bool { lower := strings.ToLower(msg) return (strings.Contains(lower, "resource_exhausted") || strings.Contains(lower, "quota_exhausted")) && (strings.Contains(lower, "429") || strings.Contains(lower, "exhausted your capacity")) } func isImageGenerationModelName(model string) bool { modelLower := normalizeRequestedModelName(model) return modelLower == "gemini-3.1-flash-image" || modelLower == "gemini-3.1-flash-image-preview" || strings.HasPrefix(modelLower, "gemini-3.1-flash-image-") || modelLower == "gemini-3-pro-image" || modelLower == "gemini-3-pro-image-preview" || strings.HasPrefix(modelLower, "gemini-3-pro-image-") || modelLower == "gemini-2.5-flash-image" || modelLower == "gemini-2.5-flash-image-preview" || strings.HasPrefix(modelLower, "gemini-2.5-flash-image-") } // requestHasTools checks if the Gemini request body contains tools/function declarations. // These are not supported through the Cascade path and must use direct HTTP. func requestHasTools(body []byte) bool { // Check both the wrapped format {"request": {"tools": [...]}} and direct {"tools": [...]} var outer map[string]json.RawMessage if err := json.Unmarshal(body, &outer); err != nil { return false } // Check in wrapped request if reqRaw, ok := outer["request"]; ok { var inner map[string]json.RawMessage if json.Unmarshal(reqRaw, &inner) == nil { if tools, ok := inner["tools"]; ok && len(tools) > 4 { // > "[]" or "null" return true } } } // Check at top level if tools, ok := outer["tools"]; ok && len(tools) > 4 { return true } return false } func snapshotRequestBody(req *http.Request) ([]byte, error) { if req.Body == nil { return nil, nil } body, err := io.ReadAll(req.Body) if err != nil { return nil, err } req.Body.Close() req.Body = io.NopCloser(bytes.NewReader(body)) return body, nil } // unused but needed for compilation var _ sync.Mutex