From 041d138f76d020936a5193a3d7ab7e741f56c0c6 Mon Sep 17 00:00:00 2001 From: wucm667 Date: Thu, 14 May 2026 11:48:00 +0800 Subject: [PATCH] fix(gateway): route Gemini chat completions upstream --- .../gateway_handler_chat_completions.go | 35 +- .../gemini_chat_completions_compat_service.go | 888 ++++++++++++++++++ .../gemini_messages_compat_service_test.go | 123 +++ 3 files changed, 1044 insertions(+), 2 deletions(-) create mode 100644 backend/internal/service/gemini_chat_completions_compat_service.go diff --git a/backend/internal/handler/gateway_handler_chat_completions.go b/backend/internal/handler/gateway_handler_chat_completions.go index c6b73190..5b37995f 100644 --- a/backend/internal/handler/gateway_handler_chat_completions.go +++ b/backend/internal/handler/gateway_handler_chat_completions.go @@ -161,12 +161,23 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) { APIKeyID: apiKey.ID, } sessionHash := h.gatewayService.GenerateSessionHash(parsedReq) + groupPlatform := "" + if apiKey.Group != nil { + groupPlatform = apiKey.Group.Platform + } + selectionSessionHash := sessionHash + if groupPlatform == service.PlatformGemini && selectionSessionHash != "" { + selectionSessionHash = "gemini:" + selectionSessionHash + } // 3. Account selection + failover loop fs := NewFailoverState(h.maxAccountSwitches, false) + if groupPlatform == service.PlatformGemini { + fs = NewFailoverState(h.maxAccountSwitchesGemini, false) + } for { - selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "", int64(0)) + selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, selectionSessionHash, reqModel, fs.FailedAccountIDs, "", int64(0)) if err != nil { if len(fs.FailedAccountIDs) == 0 { h.chatCompletionsErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error()) @@ -213,13 +224,33 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) { } accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) + if groupPlatform == service.PlatformGemini && account.Platform != service.PlatformGemini { + if accountReleaseFunc != nil { + accountReleaseFunc() + } + fs.FailedAccountIDs[account.ID] = struct{}{} + continue + } + // 5. Forward request writerSizeBeforeForward := c.Writer.Size() forwardBody := body if channelMapping.Mapped { forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel) } - result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, parsedReq) + var result *service.ForwardResult + if account.Platform == service.PlatformGemini { + if h.geminiCompatService == nil { + h.chatCompletionsErrorResponse(c, http.StatusBadGateway, "upstream_error", "Gemini compatibility service is not configured") + if accountReleaseFunc != nil { + accountReleaseFunc() + } + return + } + result, err = h.geminiCompatService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody) + } else { + result, err = h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, parsedReq) + } if accountReleaseFunc != nil { accountReleaseFunc() diff --git a/backend/internal/service/gemini_chat_completions_compat_service.go b/backend/internal/service/gemini_chat_completions_compat_service.go new file mode 100644 index 00000000..595bb0bb --- /dev/null +++ b/backend/internal/service/gemini_chat_completions_compat_service.go @@ -0,0 +1,888 @@ +package service + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" + "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" + "github.com/gin-gonic/gin" +) + +// ForwardAsChatCompletions serves OpenAI Chat Completions clients through +// Gemini accounts. It keeps the client-facing response in Chat Completions +// format while routing the upstream call through Gemini native endpoints. +func (s *GeminiMessagesCompatService) ForwardAsChatCompletions( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, +) (*ForwardResult, error) { + startTime := time.Now() + + var ccReq apicompat.ChatCompletionsRequest + if err := json.Unmarshal(body, &ccReq); err != nil { + return nil, s.writeChatCompletionsError(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") + } + if strings.TrimSpace(ccReq.Model) == "" { + return nil, s.writeChatCompletionsError(c, http.StatusBadRequest, "invalid_request_error", "model is required") + } + + originalModel := ccReq.Model + clientStream := ccReq.Stream + includeUsage := ccReq.StreamOptions != nil && ccReq.StreamOptions.IncludeUsage + + responsesReq, err := apicompat.ChatCompletionsToResponses(&ccReq) + if err != nil { + return nil, s.writeChatCompletionsError(c, http.StatusBadRequest, "invalid_request_error", err.Error()) + } + + anthropicReq, err := apicompat.ResponsesToAnthropicRequest(responsesReq) + if err != nil { + return nil, s.writeChatCompletionsError(c, http.StatusBadRequest, "invalid_request_error", err.Error()) + } + anthropicReq.Stream = clientStream + + claudeBody, err := json.Marshal(anthropicReq) + if err != nil { + return nil, fmt.Errorf("marshal chat completions compat request: %w", err) + } + + return s.forwardClaudeBodyAsChatCompletions(ctx, c, account, claudeBody, originalModel, clientStream, includeUsage, startTime, body) +} + +func (s *GeminiMessagesCompatService) forwardClaudeBodyAsChatCompletions( + ctx context.Context, + c *gin.Context, + account *Account, + claudeBody []byte, + originalModel string, + clientStream bool, + includeUsage bool, + startTime time.Time, + originalChatBody []byte, +) (*ForwardResult, error) { + var req struct { + Model string `json:"model"` + Stream bool `json:"stream"` + } + if err := json.Unmarshal(claudeBody, &req); err != nil { + return nil, s.writeChatCompletionsError(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") + } + if strings.TrimSpace(req.Model) == "" { + return nil, s.writeChatCompletionsError(c, http.StatusBadRequest, "invalid_request_error", "model is required") + } + + mappedModel := req.Model + if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount { + mappedModel = account.GetMappedModel(req.Model) + } + + geminiReq, err := convertClaudeMessagesToGeminiGenerateContent(claudeBody) + if err != nil { + return nil, s.writeChatCompletionsError(c, http.StatusBadRequest, "invalid_request_error", err.Error()) + } + geminiReq = ensureGeminiFunctionCallThoughtSignatures(geminiReq) + + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + useUpstreamStream := clientStream + if account.Type == AccountTypeOAuth && !clientStream && strings.TrimSpace(account.GetCredential("project_id")) != "" { + useUpstreamStream = true + } + + buildReq, requestIDHeader := s.buildGeminiChatCompletionsUpstreamRequestFunc( + account, + mappedModel, + geminiReq, + clientStream, + useUpstreamStream, + ) + + var resp *http.Response + for attempt := 1; attempt <= geminiMaxRetries; attempt++ { + upstreamReq, idHeader, err := buildReq(ctx) + if err != nil { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return nil, err + } + return nil, s.writeChatCompletionsError(c, http.StatusBadGateway, "upstream_error", err.Error()) + } + requestIDHeader = idHeader + + if c != nil { + c.Set(OpsUpstreamRequestBodyKey, string(geminiReq)) + } + + resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + if err != nil { + safeErr := sanitizeUpstreamErrorMessage(err.Error()) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + Kind: "request_error", + Message: safeErr, + }) + if attempt < geminiMaxRetries { + logger.LegacyPrintf("service.gemini_chat_completions", "Gemini account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, geminiMaxRetries, err) + sleepGeminiBackoff(attempt) + continue + } + setOpsUpstreamError(c, 0, safeErr, "") + return nil, s.writeChatCompletionsError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries: "+safeErr) + } + + if matched, rebuilt := s.checkErrorPolicyInLoop(ctx, account, resp); matched { + resp = rebuilt + break + } else { + resp = rebuilt + } + + if resp.StatusCode >= 400 && s.shouldRetryGeminiUpstreamError(account, resp.StatusCode) { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + if resp.StatusCode == http.StatusForbidden && isGeminiInsufficientScope(resp.Header, respBody) { + resp = &http.Response{ + StatusCode: resp.StatusCode, + Header: resp.Header.Clone(), + Body: io.NopCloser(bytes.NewReader(respBody)), + } + break + } + if resp.StatusCode == http.StatusTooManyRequests { + s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + } + if attempt < geminiMaxRetries { + upstreamReqID := resp.Header.Get(requestIDHeader) + if upstreamReqID == "" { + upstreamReqID = resp.Header.Get("x-goog-request-id") + } + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: upstreamReqID, + Kind: "retry", + Message: upstreamMsg, + }) + logger.LegacyPrintf("service.gemini_chat_completions", "Gemini account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, geminiMaxRetries) + sleepGeminiBackoff(attempt) + continue + } + resp = &http.Response{ + StatusCode: resp.StatusCode, + Header: resp.Header.Clone(), + Body: io.NopCloser(bytes.NewReader(respBody)), + } + break + } + + break + } + defer func() { _ = resp.Body.Close() }() + + requestID := resp.Header.Get(requestIDHeader) + if requestID == "" { + requestID = resp.Header.Get("x-goog-request-id") + } + if requestID != "" { + c.Header("x-request-id", requestID) + } + + reasoningEffort := extractCCReasoningEffortFromBody(originalChatBody) + + if resp.StatusCode >= 400 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + evBody := unwrapIfNeeded(account.Type == AccountTypeOAuth, respBody) + + if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) { + upstreamMsg := sanitizeUpstreamErrorMessage(strings.TrimSpace(extractUpstreamErrorMessage(evBody))) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: requestID, + Kind: "failover", + Message: upstreamMsg, + }) + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: evBody} + } + + return nil, s.writeGeminiChatCompletionsMappedError(c, account, resp.StatusCode, requestID, evBody) + } + + var usage *ClaudeUsage + var firstTokenMs *int + if clientStream { + streamRes, err := s.handleChatCompletionsStreamingResponseFromGemini(c, resp, startTime, originalModel, account.Type == AccountTypeOAuth, includeUsage) + if err != nil { + return nil, err + } + usage = streamRes.usage + firstTokenMs = streamRes.firstTokenMs + } else if useUpstreamStream { + collected, usageObj, err := collectGeminiSSE(resp.Body, account.Type == AccountTypeOAuth) + if err != nil { + return nil, s.writeChatCompletionsError(c, http.StatusBadGateway, "upstream_error", "Failed to read upstream stream") + } + collectedBytes, _ := json.Marshal(collected) + chatResp, usageObj2, err := geminiResponseToChatCompletions(collected, originalModel, collectedBytes, usageObj) + if err != nil { + return nil, s.writeChatCompletionsError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response") + } + c.JSON(http.StatusOK, chatResp) + usage = usageObj2 + } else { + usageResp, err := s.handleChatCompletionsNonStreamingResponseFromGemini(c, resp, originalModel, account.Type == AccountTypeOAuth) + if err != nil { + return nil, err + } + usage = usageResp + } + + if usage == nil { + usage = &ClaudeUsage{} + } + + imageCount := 0 + imageSize := s.extractImageSize(claudeBody) + if isImageGenerationModel(originalModel) { + imageCount = 1 + } + + return &ForwardResult{ + RequestID: requestID, + Usage: *usage, + Model: originalModel, + UpstreamModel: mappedModel, + Stream: clientStream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + ReasoningEffort: reasoningEffort, + ImageCount: imageCount, + ImageSize: imageSize, + ClientDisconnect: false, + }, nil +} + +func (s *GeminiMessagesCompatService) buildGeminiChatCompletionsUpstreamRequestFunc( + account *Account, + mappedModel string, + geminiReq []byte, + clientStream bool, + useUpstreamStream bool, +) (func(context.Context) (*http.Request, string, error), string) { + switch account.Type { + case AccountTypeAPIKey: + return func(ctx context.Context) (*http.Request, string, error) { + apiKey := account.GetCredential("api_key") + if strings.TrimSpace(apiKey) == "" { + return nil, "", errors.New("gemini api_key not configured") + } + + baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL) + normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return nil, "", err + } + + action := "generateContent" + if clientStream { + action = "streamGenerateContent" + } + fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, action) + if clientStream { + fullURL += "?alt=sse" + } + + restGeminiReq := normalizeGeminiRequestForAIStudio(geminiReq) + upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(restGeminiReq)) + if err != nil { + return nil, "", err + } + upstreamReq.Header.Set("Content-Type", "application/json") + upstreamReq.Header.Set("x-goog-api-key", apiKey) + return upstreamReq, "x-request-id", nil + }, "x-request-id" + + case AccountTypeOAuth: + return func(ctx context.Context) (*http.Request, string, error) { + if s.tokenProvider == nil { + return nil, "", errors.New("gemini token provider not configured") + } + accessToken, err := s.tokenProvider.GetAccessToken(ctx, account) + if err != nil { + return nil, "", err + } + + projectID := strings.TrimSpace(account.GetCredential("project_id")) + action := "generateContent" + if useUpstreamStream { + action = "streamGenerateContent" + } + + if projectID != "" { + baseURL, err := s.validateUpstreamBaseURL(geminicli.GeminiCliBaseURL) + if err != nil { + return nil, "", err + } + fullURL := fmt.Sprintf("%s/v1internal:%s", strings.TrimRight(baseURL, "/"), action) + if useUpstreamStream { + fullURL += "?alt=sse" + } + + var inner any + if err := json.Unmarshal(geminiReq, &inner); err != nil { + return nil, "", fmt.Errorf("failed to parse gemini request: %w", err) + } + wrappedBytes, _ := json.Marshal(map[string]any{ + "model": mappedModel, + "project": projectID, + "request": inner, + }) + + upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(wrappedBytes)) + if err != nil { + return nil, "", err + } + upstreamReq.Header.Set("Content-Type", "application/json") + upstreamReq.Header.Set("Authorization", "Bearer "+accessToken) + upstreamReq.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent) + return upstreamReq, "x-request-id", nil + } + + baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL) + normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return nil, "", err + } + + fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, action) + if useUpstreamStream { + fullURL += "?alt=sse" + } + + restGeminiReq := normalizeGeminiRequestForAIStudio(geminiReq) + upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(restGeminiReq)) + if err != nil { + return nil, "", err + } + upstreamReq.Header.Set("Content-Type", "application/json") + upstreamReq.Header.Set("Authorization", "Bearer "+accessToken) + return upstreamReq, "x-request-id", nil + }, "x-request-id" + + case AccountTypeServiceAccount: + return func(ctx context.Context) (*http.Request, string, error) { + if s.tokenProvider == nil { + return nil, "", errors.New("gemini token provider not configured") + } + accessToken, err := s.tokenProvider.GetAccessToken(ctx, account) + if err != nil { + return nil, "", err + } + + action := "generateContent" + if clientStream { + action = "streamGenerateContent" + } + fullURL, err := buildVertexGeminiURL(account.VertexProjectID(), account.VertexLocation(mappedModel), mappedModel, action, clientStream) + if err != nil { + return nil, "", err + } + + restGeminiReq := normalizeGeminiRequestForAIStudio(geminiReq) + upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(restGeminiReq)) + if err != nil { + return nil, "", err + } + upstreamReq.Header.Set("Content-Type", "application/json") + upstreamReq.Header.Set("Authorization", "Bearer "+accessToken) + return upstreamReq, "x-request-id", nil + }, "x-request-id" + + default: + return func(context.Context) (*http.Request, string, error) { + return nil, "", fmt.Errorf("unsupported account type: %s", account.Type) + }, "x-request-id" + } +} + +func (s *GeminiMessagesCompatService) handleChatCompletionsNonStreamingResponseFromGemini( + c *gin.Context, + resp *http.Response, + originalModel string, + isOAuth bool, +) (*ClaudeUsage, error) { + respBody, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError) + if err != nil { + return nil, err + } + if isOAuth { + if unwrappedBody, uwErr := unwrapGeminiResponse(respBody); uwErr == nil { + respBody = unwrappedBody + } + } + + var geminiResp map[string]any + if err := json.Unmarshal(respBody, &geminiResp); err != nil { + return nil, s.writeChatCompletionsError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response") + } + + chatResp, usage, err := geminiResponseToChatCompletions(geminiResp, originalModel, respBody, nil) + if err != nil { + return nil, s.writeChatCompletionsError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response") + } + + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + c.JSON(http.StatusOK, chatResp) + return usage, nil +} + +func geminiResponseToChatCompletions( + geminiResp map[string]any, + originalModel string, + rawData []byte, + usageOverride *ClaudeUsage, +) (*apicompat.ChatCompletionsResponse, *ClaudeUsage, error) { + claudeRespMap, usage := convertGeminiToClaudeMessage(geminiResp, originalModel, rawData) + if usageOverride != nil && (usageOverride.InputTokens > 0 || usageOverride.OutputTokens > 0 || usageOverride.CacheReadInputTokens > 0) { + usage = usageOverride + if usageMap, ok := claudeRespMap["usage"].(map[string]any); ok { + usageMap["input_tokens"] = usage.InputTokens + usageMap["output_tokens"] = usage.OutputTokens + usageMap["cache_read_input_tokens"] = usage.CacheReadInputTokens + } + } + + claudeBytes, err := json.Marshal(claudeRespMap) + if err != nil { + return nil, nil, err + } + var anthropicResp apicompat.AnthropicResponse + if err := json.Unmarshal(claudeBytes, &anthropicResp); err != nil { + return nil, nil, err + } + responsesResp := apicompat.AnthropicToResponsesResponse(&anthropicResp) + return apicompat.ResponsesToChatCompletions(responsesResp, originalModel), usage, nil +} + +func (s *GeminiMessagesCompatService) handleChatCompletionsStreamingResponseFromGemini( + c *gin.Context, + resp *http.Response, + startTime time.Time, + originalModel string, + isOAuth bool, + includeUsage bool, +) (*geminiStreamResult, error) { + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + } + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("X-Accel-Buffering", "no") + c.Writer.WriteHeader(http.StatusOK) + + flusher, ok := c.Writer.(http.Flusher) + if !ok { + return nil, errors.New("streaming not supported") + } + + anthState := apicompat.NewAnthropicEventToResponsesState() + anthState.Model = originalModel + ccState := apicompat.NewResponsesEventToChatState() + ccState.Model = originalModel + ccState.IncludeUsage = includeUsage + + var usage ClaudeUsage + var firstTokenMs *int + firstChunk := true + + writeChatChunk := func(chunk apicompat.ChatCompletionsChunk) bool { + sse, err := apicompat.ChatChunkToSSE(chunk) + if err != nil { + return false + } + if _, err := io.WriteString(c.Writer, sse); err != nil { + return true + } + return false + } + + emitAnthropicEvent := func(evt *apicompat.AnthropicStreamEvent) bool { + responsesEvents := apicompat.AnthropicEventToResponsesEvents(evt, anthState) + for _, resEvt := range responsesEvents { + chunks := apicompat.ResponsesEventToChatChunks(&resEvt, ccState) + for _, chunk := range chunks { + if disconnected := writeChatChunk(chunk); disconnected { + return true + } + } + } + flusher.Flush() + return false + } + + messageID := "msg_" + randomHex(12) + if emitAnthropicEvent(&apicompat.AnthropicStreamEvent{ + Type: "message_start", + Message: &apicompat.AnthropicResponse{ + ID: messageID, + Type: "message", + Role: "assistant", + Model: originalModel, + Content: []apicompat.AnthropicContentBlock{}, + Usage: apicompat.AnthropicUsage{}, + }, + }) { + return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil + } + + finishReason := "" + sawToolUse := false + nextBlockIndex := 0 + openBlockIndex := -1 + openBlockType := "" + seenText := "" + openToolIndex := -1 + openToolName := "" + seenToolJSON := "" + + closeOpenBlock := func() bool { + if openBlockIndex < 0 { + return false + } + disconnected := emitAnthropicEvent(&apicompat.AnthropicStreamEvent{Type: "content_block_stop"}) + openBlockIndex = -1 + openBlockType = "" + return disconnected + } + closeOpenTool := func() bool { + if openToolIndex < 0 { + return false + } + disconnected := emitAnthropicEvent(&apicompat.AnthropicStreamEvent{Type: "content_block_stop"}) + openToolIndex = -1 + openToolName = "" + seenToolJSON = "" + return disconnected + } + + reader := bufio.NewReader(resp.Body) + for { + line, err := reader.ReadString('\n') + if len(line) > 0 { + trimmed := strings.TrimRight(line, "\r\n") + if strings.HasPrefix(trimmed, "data:") { + payload := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:")) + if payload != "" && payload != "[DONE]" { + rawBytes := []byte(payload) + if isOAuth { + if innerBytes, uwErr := unwrapGeminiResponse(rawBytes); uwErr == nil { + rawBytes = innerBytes + } + } + + var geminiResp map[string]any + if err := json.Unmarshal(rawBytes, &geminiResp); err == nil { + if firstChunk { + firstChunk = false + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + if fr := extractGeminiFinishReason(geminiResp); fr != "" { + finishReason = fr + } + if u := extractGeminiUsage(rawBytes); u != nil { + usage = *u + } + + for _, part := range extractGeminiParts(geminiResp) { + if text, ok := part["text"].(string); ok && text != "" { + if openToolIndex >= 0 { + if closeOpenTool() { + return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil + } + } + delta, newSeen := computeGeminiTextDelta(seenText, text) + seenText = newSeen + if delta == "" { + continue + } + if openBlockType != "text" { + if closeOpenBlock() { + return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil + } + idx := nextBlockIndex + nextBlockIndex++ + openBlockIndex = idx + openBlockType = "text" + if emitAnthropicEvent(&apicompat.AnthropicStreamEvent{ + Type: "content_block_start", + Index: &idx, + ContentBlock: &apicompat.AnthropicContentBlock{ + Type: "text", + Text: "", + }, + }) { + return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil + } + } + if emitAnthropicEvent(&apicompat.AnthropicStreamEvent{ + Type: "content_block_delta", + Delta: &apicompat.AnthropicDelta{ + Type: "text_delta", + Text: delta, + }, + }) { + return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil + } + continue + } + + if fc, ok := part["functionCall"].(map[string]any); ok && fc != nil { + name, _ := fc["name"].(string) + if strings.TrimSpace(name) == "" { + name = "tool" + } + if closeOpenBlock() { + return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil + } + if openToolIndex >= 0 && openToolName != name { + if closeOpenTool() { + return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil + } + } + if openToolIndex < 0 { + idx := nextBlockIndex + nextBlockIndex++ + openToolIndex = idx + openToolName = name + sawToolUse = true + if emitAnthropicEvent(&apicompat.AnthropicStreamEvent{ + Type: "content_block_start", + Index: &idx, + ContentBlock: &apicompat.AnthropicContentBlock{ + Type: "tool_use", + ID: "toolu_" + randomHex(8), + Name: name, + Input: json.RawMessage(`{}`), + }, + }) { + return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil + } + } + + argsJSONText := "{}" + switch v := fc["args"].(type) { + case nil: + case string: + if strings.TrimSpace(v) != "" { + argsJSONText = v + } + default: + if b, err := json.Marshal(v); err == nil && len(b) > 0 { + argsJSONText = string(b) + } + } + delta, newSeen := computeGeminiTextDelta(seenToolJSON, argsJSONText) + seenToolJSON = newSeen + if delta != "" { + if emitAnthropicEvent(&apicompat.AnthropicStreamEvent{ + Type: "content_block_delta", + Delta: &apicompat.AnthropicDelta{ + Type: "input_json_delta", + PartialJSON: delta, + }, + }) { + return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil + } + } + } + } + } + } + } + } + + if errors.Is(err, io.EOF) { + break + } + if err != nil { + return nil, fmt.Errorf("stream read error: %w", err) + } + } + + if closeOpenBlock() { + return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil + } + if closeOpenTool() { + return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil + } + + stopReason := mapGeminiFinishReasonToClaudeStopReason(finishReason) + if sawToolUse { + stopReason = "tool_use" + } + anthState.InputTokens = usage.InputTokens + anthState.CacheReadInputTokens = usage.CacheReadInputTokens + if emitAnthropicEvent(&apicompat.AnthropicStreamEvent{ + Type: "message_delta", + Delta: &apicompat.AnthropicDelta{ + Type: "message_delta", + StopReason: stopReason, + }, + Usage: &apicompat.AnthropicUsage{ + InputTokens: usage.InputTokens, + OutputTokens: usage.OutputTokens, + CacheReadInputTokens: usage.CacheReadInputTokens, + }, + }) { + return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil + } + if emitAnthropicEvent(&apicompat.AnthropicStreamEvent{Type: "message_stop"}) { + return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil + } + + for _, resEvt := range apicompat.FinalizeAnthropicResponsesStream(anthState) { + chunks := apicompat.ResponsesEventToChatChunks(&resEvt, ccState) + for _, chunk := range chunks { + if disconnected := writeChatChunk(chunk); disconnected { + return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil + } + } + } + for _, chunk := range apicompat.FinalizeResponsesChatStream(ccState) { + if disconnected := writeChatChunk(chunk); disconnected { + return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil + } + } + + _, _ = io.WriteString(c.Writer, "data: [DONE]\n\n") + flusher.Flush() + + return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil +} + +func (s *GeminiMessagesCompatService) writeGeminiChatCompletionsMappedError( + c *gin.Context, + account *Account, + upstreamStatus int, + upstreamRequestID string, + body []byte, +) error { + upstreamMsg := sanitizeUpstreamErrorMessage(strings.TrimSpace(extractUpstreamErrorMessage(body))) + setOpsUpstreamError(c, upstreamStatus, upstreamMsg, "") + if account != nil { + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: upstreamStatus, + UpstreamRequestID: upstreamRequestID, + Kind: "http_error", + Message: upstreamMsg, + }) + } + + if status, errType, errMsg, matched := applyErrorPassthroughRule( + c, + PlatformGemini, + upstreamStatus, + body, + http.StatusBadGateway, + "upstream_error", + "Upstream request failed", + ); matched { + return s.writeChatCompletionsError(c, status, errType, errMsg) + } + + statusCode := http.StatusBadGateway + errType := "upstream_error" + errMsg := "Upstream request failed" + if mapped := mapGeminiErrorBodyToClaudeError(body); mapped != nil { + if mapped.Type != "" { + errType = mapped.Type + } + if mapped.Message != "" { + errMsg = mapped.Message + } + if mapped.StatusCode > 0 { + statusCode = mapped.StatusCode + } + } + + switch upstreamStatus { + case http.StatusBadRequest: + if statusCode == http.StatusBadGateway { + statusCode = http.StatusBadRequest + } + if errType == "upstream_error" { + errType = "invalid_request_error" + } + if errMsg == "Upstream request failed" { + errMsg = "Invalid request" + } + case http.StatusNotFound: + statusCode = http.StatusNotFound + if errType == "upstream_error" { + errType = "not_found_error" + } + if errMsg == "Upstream request failed" { + errMsg = "Resource not found" + } + case http.StatusTooManyRequests: + statusCode = http.StatusTooManyRequests + if errType == "upstream_error" { + errType = "rate_limit_error" + } + if errMsg == "Upstream request failed" { + errMsg = "Upstream rate limit exceeded, please retry later" + } + case 529: + statusCode = http.StatusServiceUnavailable + if errType == "upstream_error" { + errType = "overloaded_error" + } + if errMsg == "Upstream request failed" { + errMsg = "Upstream service overloaded, please retry later" + } + } + + if upstreamMsg != "" && errMsg == "Upstream request failed" { + errMsg = upstreamMsg + } + return s.writeChatCompletionsError(c, statusCode, errType, errMsg) +} + +func (s *GeminiMessagesCompatService) writeChatCompletionsError(c *gin.Context, status int, errType, message string) error { + c.JSON(status, gin.H{ + "error": gin.H{ + "type": errType, + "message": message, + }, + }) + return fmt.Errorf("%s", message) +} diff --git a/backend/internal/service/gemini_messages_compat_service_test.go b/backend/internal/service/gemini_messages_compat_service_test.go index c2adf45d..e1b0ba3e 100644 --- a/backend/internal/service/gemini_messages_compat_service_test.go +++ b/backend/internal/service/gemini_messages_compat_service_test.go @@ -1,6 +1,7 @@ package service import ( + "bytes" "context" "encoding/json" "fmt" @@ -41,6 +42,128 @@ func (s *geminiCompatHTTPUpstreamStub) DoWithTLS(req *http.Request, proxyURL str return s.Do(req, proxyURL, accountID, accountConcurrency) } +func TestGeminiForwardAsChatCompletions_OAuthRoutesToGeminiAndReturnsChatFormat(t *testing.T) { + gin.SetMode(gin.TestMode) + + upstreamBody := `data: {"response":{"candidates":[{"content":{"parts":[{"text":"hello from gemini"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":7,"candidatesTokenCount":3}}}` + "\n\n" + + "data: [DONE]\n\n" + httpStub := &geminiCompatHTTPUpstreamStub{ + response: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }, + } + svc := &GeminiMessagesCompatService{ + tokenProvider: &GeminiTokenProvider{}, + httpUpstream: httpStub, + cfg: &config.Config{}, + } + account := &Account{ + ID: 101, + Platform: PlatformGemini, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "ya29.test-token", + "project_id": "project-1", + }, + Concurrency: 1, + } + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + body := []byte(`{"model":"gemini-2.5-flash","messages":[{"role":"user","content":"hi"}]}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)) + + result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "gemini-2.5-flash", result.Model) + require.Equal(t, 7, result.Usage.InputTokens) + require.Equal(t, 3, result.Usage.OutputTokens) + + require.NotNil(t, httpStub.lastReq) + require.Contains(t, httpStub.lastReq.URL.String(), "/v1internal:streamGenerateContent?alt=sse") + require.Equal(t, "Bearer ya29.test-token", httpStub.lastReq.Header.Get("Authorization")) + require.Empty(t, httpStub.lastReq.Header.Get("x-api-key")) + require.Empty(t, httpStub.lastReq.Header.Get("anthropic-version")) + + var sent map[string]any + sentBody, err := io.ReadAll(httpStub.lastReq.Body) + require.NoError(t, err) + require.NoError(t, json.Unmarshal(sentBody, &sent)) + require.Equal(t, "gemini-2.5-flash", sent["model"]) + require.Equal(t, "project-1", sent["project"]) + require.Contains(t, fmt.Sprint(sent["request"]), "hi") + + var got map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got)) + require.Equal(t, "chat.completion", got["object"]) + require.Equal(t, "gemini-2.5-flash", got["model"]) + choices := got["choices"].([]any) + message := choices[0].(map[string]any)["message"].(map[string]any) + require.Equal(t, "assistant", message["role"]) + require.Equal(t, "hello from gemini", message["content"]) + usage := got["usage"].(map[string]any) + require.Equal(t, float64(7), usage["prompt_tokens"]) + require.Equal(t, float64(3), usage["completion_tokens"]) + require.Equal(t, float64(10), usage["total_tokens"]) +} + +func TestGeminiForwardAsChatCompletions_StreamsOpenAIChunksFromGeminiSSE(t *testing.T) { + gin.SetMode(gin.TestMode) + + upstreamBody := `data: {"candidates":[{"content":{"parts":[{"text":"hel"}]}}],"usageMetadata":{"promptTokenCount":2,"candidatesTokenCount":1}}` + "\n\n" + + `data: {"candidates":[{"content":{"parts":[{"text":"hello"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":2,"candidatesTokenCount":2}}` + "\n\n" + + "data: [DONE]\n\n" + httpStub := &geminiCompatHTTPUpstreamStub{ + response: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }, + } + svc := &GeminiMessagesCompatService{ + httpUpstream: httpStub, + cfg: &config.Config{}, + } + account := &Account{ + ID: 102, + Platform: PlatformGemini, + Type: AccountTypeAPIKey, + Credentials: map[string]any{ + "api_key": "gemini-api-key", + }, + Concurrency: 1, + } + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + body := []byte(`{"model":"gemini-2.5-flash","stream":true,"stream_options":{"include_usage":true},"messages":[{"role":"user","content":"hi"}]}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)) + + result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, http.StatusOK, rec.Code) + require.True(t, result.Stream) + require.Equal(t, 2, result.Usage.InputTokens) + require.Equal(t, 2, result.Usage.OutputTokens) + + require.NotNil(t, httpStub.lastReq) + require.Contains(t, httpStub.lastReq.URL.String(), "/v1beta/models/gemini-2.5-flash:streamGenerateContent?alt=sse") + require.Equal(t, "gemini-api-key", httpStub.lastReq.Header.Get("x-goog-api-key")) + + out := rec.Body.String() + require.Contains(t, out, `"object":"chat.completion.chunk"`) + require.Contains(t, out, `"role":"assistant"`) + require.Contains(t, out, `"content":"hel"`) + require.Contains(t, out, `"content":"lo"`) + require.Contains(t, out, `"usage":{"prompt_tokens":2,"completion_tokens":2,"total_tokens":4}`) + require.Contains(t, out, "data: [DONE]") +} + // TestConvertClaudeToolsToGeminiTools_CustomType 测试custom类型工具转换 func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) { tests := []struct {