Merge pull request #2451 from wucm667/codex/issue-2237-gemini-chat-completions

fix(gateway): 修复 Gemini 组 Chat Completions 路由
This commit is contained in:
Wesley Liddick 2026-05-19 14:47:52 +08:00 committed by GitHub
commit 8a4ee578cb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 1198 additions and 5 deletions

View File

@ -18,6 +18,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
@ -950,8 +951,8 @@ func (h *GatewayHandler) Models(c *gin.Context) {
platform = forcedPlatform
}
// Get available models from account configurations (without platform filter)
availableModels := h.gatewayService.GetAvailableModels(c.Request.Context(), groupID, "")
// Get available models from account configurations for the selected group platform.
availableModels := h.gatewayService.GetAvailableModels(c.Request.Context(), groupID, platform)
if len(availableModels) > 0 {
// Build model list from whitelist
@ -972,7 +973,7 @@ func (h *GatewayHandler) Models(c *gin.Context) {
}
// Fallback to default models
if platform == "openai" {
if platform == service.PlatformOpenAI {
c.JSON(http.StatusOK, gin.H{
"object": "list",
"data": openai.DefaultModels,
@ -980,6 +981,14 @@ func (h *GatewayHandler) Models(c *gin.Context) {
return
}
if platform == service.PlatformGemini {
c.JSON(http.StatusOK, gin.H{
"object": "list",
"data": geminicli.DefaultModels,
})
return
}
c.JSON(http.StatusOK, gin.H{
"object": "list",
"data": claude.DefaultModels,

View File

@ -161,12 +161,23 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
APIKeyID: apiKey.ID,
}
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
groupPlatform := ""
if apiKey.Group != nil {
groupPlatform = apiKey.Group.Platform
}
selectionSessionHash := sessionHash
if groupPlatform == service.PlatformGemini && selectionSessionHash != "" {
selectionSessionHash = "gemini:" + selectionSessionHash
}
// 3. Account selection + failover loop
fs := NewFailoverState(h.maxAccountSwitches, false)
if groupPlatform == service.PlatformGemini {
fs = NewFailoverState(h.maxAccountSwitchesGemini, false)
}
for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "", int64(0))
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, selectionSessionHash, reqModel, fs.FailedAccountIDs, "", int64(0))
if err != nil {
if len(fs.FailedAccountIDs) == 0 {
markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
@ -215,13 +226,33 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
}
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
if groupPlatform == service.PlatformGemini && account.Platform != service.PlatformGemini {
if accountReleaseFunc != nil {
accountReleaseFunc()
}
fs.FailedAccountIDs[account.ID] = struct{}{}
continue
}
// 5. Forward request
writerSizeBeforeForward := c.Writer.Size()
forwardBody := body
if channelMapping.Mapped {
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
}
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, parsedReq)
var result *service.ForwardResult
if account.Platform == service.PlatformGemini {
if h.geminiCompatService == nil {
h.chatCompletionsErrorResponse(c, http.StatusBadGateway, "upstream_error", "Gemini compatibility service is not configured")
if accountReleaseFunc != nil {
accountReleaseFunc()
}
return
}
result, err = h.geminiCompatService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody)
} else {
result, err = h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, parsedReq)
}
if accountReleaseFunc != nil {
accountReleaseFunc()

View File

@ -0,0 +1,136 @@
package handler
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// gatewayModelsAccountRepoStub is a test double passed wherever a
// service.AccountRepository is expected. It embeds service.AccountRepository
// so only the methods a test cares about need to be defined on the stub.
type gatewayModelsAccountRepoStub struct {
service.AccountRepository
// byGroup maps a group ID to the accounts ListSchedulableByGroupID returns for it.
byGroup map[int64][]service.Account
}
// gatewayModelsResponseForTest is the subset of the models-list response body
// that the tests decode: the "object" marker and the "data" entries.
type gatewayModelsResponseForTest struct {
Object string `json:"object"`
Data []gatewayModelItemForTest `json:"data"`
}
// gatewayModelItemForTest captures only the "id" field of one entry in the
// models-list response; all other fields are ignored when decoding.
type gatewayModelItemForTest struct {
ID string `json:"id"`
}
// ListSchedulableByGroupID returns a copy of the accounts registered for
// groupID in the stub. Unknown groups yield (nil, nil); known groups always
// yield a non-nil slice so callers cannot mutate the stub's own data.
func (s *gatewayModelsAccountRepoStub) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]service.Account, error) {
	registered, found := s.byGroup[groupID]
	if !found {
		return nil, nil
	}
	cloned := append(make([]service.Account, 0, len(registered)), registered...)
	return cloned, nil
}
// newGatewayModelsHandlerForTest builds a GatewayHandler whose gatewayService
// is backed by repo. Every other NewGatewayService dependency is passed as
// nil, and no other GatewayHandler field is set.
func newGatewayModelsHandlerForTest(repo service.AccountRepository) *GatewayHandler {
return &GatewayHandler{
gatewayService: service.NewGatewayService(
repo,
// All remaining constructor arguments are positional nils; the count
// must match the NewGatewayService signature exactly.
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
),
}
}
// TestGatewayModels_GeminiGroupFallsBackToGeminiModels calls Models for an API
// key bound to a Gemini group whose only account carries no model mapping.
// The response must list "gemini-2.5-flash" and must not list
// "claude-sonnet-4-6".
func TestGatewayModels_GeminiGroupFallsBackToGeminiModels(t *testing.T) {
	gin.SetMode(gin.TestMode)

	groupID := int64(20)
	repo := &gatewayModelsAccountRepoStub{
		byGroup: map[int64][]service.Account{
			groupID: {{ID: 1, Platform: service.PlatformGemini}},
		},
	}
	handler := newGatewayModelsHandlerForTest(repo)

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil)
	apiKey := &service.APIKey{
		Group: &service.Group{ID: groupID, Platform: service.PlatformGemini},
	}
	ginCtx.Set(string(middleware2.ContextKeyAPIKey), apiKey)

	handler.Models(ginCtx)

	require.Equal(t, http.StatusOK, recorder.Code)

	var payload gatewayModelsResponseForTest
	require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload))
	require.Equal(t, "list", payload.Object)

	ids := modelIDsForTest(payload.Data)
	require.Contains(t, ids, "gemini-2.5-flash")
	require.NotContains(t, ids, "claude-sonnet-4-6")
}
// TestGatewayModels_GeminiGroupFiltersMappedModelsByPlatform puts one
// Anthropic and one Gemini account, each with its own model_mapping, into the
// same group. For an API key bound to that group with the Gemini platform,
// Models must return exactly the Gemini account's mapped model.
func TestGatewayModels_GeminiGroupFiltersMappedModelsByPlatform(t *testing.T) {
	gin.SetMode(gin.TestMode)

	groupID := int64(21)

	// selfMapping builds credentials whose model_mapping maps model to itself.
	selfMapping := func(model string) map[string]any {
		return map[string]any{
			"model_mapping": map[string]any{model: model},
		}
	}
	repo := &gatewayModelsAccountRepoStub{
		byGroup: map[int64][]service.Account{
			groupID: {
				{ID: 1, Platform: service.PlatformAnthropic, Credentials: selfMapping("claude-sonnet-4-6")},
				{ID: 2, Platform: service.PlatformGemini, Credentials: selfMapping("gemini-2.5-flash")},
			},
		},
	}
	handler := newGatewayModelsHandlerForTest(repo)

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil)
	apiKey := &service.APIKey{
		Group: &service.Group{ID: groupID, Platform: service.PlatformGemini},
	}
	ginCtx.Set(string(middleware2.ContextKeyAPIKey), apiKey)

	handler.Models(ginCtx)

	require.Equal(t, http.StatusOK, recorder.Code)

	var payload gatewayModelsResponseForTest
	require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload))
	require.Equal(t, []string{"gemini-2.5-flash"}, modelIDsForTest(payload.Data))
}
// modelIDsForTest projects a list of decoded model entries onto their IDs,
// preserving order. The result is always non-nil, even for empty input.
func modelIDsForTest(models []gatewayModelItemForTest) []string {
	ids := make([]string, len(models))
	for i := range models {
		ids[i] = models[i].ID
	}
	return ids
}

View File

@ -0,0 +1,888 @@
package service
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/gin-gonic/gin"
)
// ForwardAsChatCompletions handles an OpenAI Chat Completions request on a
// Gemini account. The request body is converted Chat Completions -> Responses
// -> Anthropic and then forwarded via forwardClaudeBodyAsChatCompletions, so
// the client still receives a Chat Completions response.
//
// Parse and conversion failures are answered with a 400
// "invalid_request_error" and returned as an error. A failure to marshal the
// converted request is returned without writing a response.
func (s *GeminiMessagesCompatService) ForwardAsChatCompletions(
	ctx context.Context,
	c *gin.Context,
	account *Account,
	body []byte,
) (*ForwardResult, error) {
	began := time.Now()

	// reject writes a 400 invalid_request_error and returns it as the call's error.
	reject := func(msg string) (*ForwardResult, error) {
		return nil, s.writeChatCompletionsError(c, http.StatusBadRequest, "invalid_request_error", msg)
	}

	var chatReq apicompat.ChatCompletionsRequest
	if err := json.Unmarshal(body, &chatReq); err != nil {
		return reject("Failed to parse request body")
	}
	if strings.TrimSpace(chatReq.Model) == "" {
		return reject("model is required")
	}

	// Capture the client-visible settings before the request is converted.
	requestedModel := chatReq.Model
	wantStream := chatReq.Stream
	wantUsage := false
	if opts := chatReq.StreamOptions; opts != nil {
		wantUsage = opts.IncludeUsage
	}

	asResponses, err := apicompat.ChatCompletionsToResponses(&chatReq)
	if err != nil {
		return reject(err.Error())
	}
	asAnthropic, err := apicompat.ResponsesToAnthropicRequest(asResponses)
	if err != nil {
		return reject(err.Error())
	}
	asAnthropic.Stream = wantStream

	encoded, err := json.Marshal(asAnthropic)
	if err != nil {
		return nil, fmt.Errorf("marshal chat completions compat request: %w", err)
	}

	return s.forwardClaudeBodyAsChatCompletions(
		ctx, c, account, encoded,
		requestedModel, wantStream, wantUsage,
		began, body,
	)
}
// forwardClaudeBodyAsChatCompletions forwards an Anthropic-format request body
// to a Gemini upstream and writes the reply to the client in Chat Completions
// format.
//
// claudeBody is converted to a Gemini generateContent payload, sent with up to
// geminiMaxRetries attempts, and the upstream reply is then handed to the
// streaming, collected-SSE, or non-streaming response path.
//
// Parameters:
//   - originalModel: model name passed to the response converters and reported
//     as ForwardResult.Model.
//   - clientStream: whether the client asked for a streamed response.
//   - includeUsage: passed through to the streaming response handler.
//   - startTime: reference point for ForwardResult.Duration and first-token latency.
//   - originalChatBody: the original Chat Completions body; only used here to
//     extract the reasoning effort.
//
// Returns an *UpstreamFailoverError when shouldFailoverGeminiUpstreamError
// accepts the upstream status. Errors produced through
// writeChatCompletionsError have already been written to the client.
func (s *GeminiMessagesCompatService) forwardClaudeBodyAsChatCompletions(
ctx context.Context,
c *gin.Context,
account *Account,
claudeBody []byte,
originalModel string,
clientStream bool,
includeUsage bool,
startTime time.Time,
originalChatBody []byte,
) (*ForwardResult, error) {
// Only the model name is needed here; the full body is converted below.
var req struct {
Model string `json:"model"`
Stream bool `json:"stream"`
}
if err := json.Unmarshal(claudeBody, &req); err != nil {
return nil, s.writeChatCompletionsError(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
}
if strings.TrimSpace(req.Model) == "" {
return nil, s.writeChatCompletionsError(c, http.StatusBadRequest, "invalid_request_error", "model is required")
}
// Account-level model mapping is applied only for API-key and
// service-account accounts; other account types forward the requested model.
mappedModel := req.Model
if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
mappedModel = account.GetMappedModel(req.Model)
}
geminiReq, err := convertClaudeMessagesToGeminiGenerateContent(claudeBody)
if err != nil {
return nil, s.writeChatCompletionsError(c, http.StatusBadRequest, "invalid_request_error", err.Error())
}
geminiReq = ensureGeminiFunctionCallThoughtSignatures(geminiReq)
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
// OAuth accounts with a project_id always call the upstream in streaming
// mode; for non-streaming clients the SSE is collected further below.
useUpstreamStream := clientStream
if account.Type == AccountTypeOAuth && !clientStream && strings.TrimSpace(account.GetCredential("project_id")) != "" {
useUpstreamStream = true
}
buildReq, requestIDHeader := s.buildGeminiChatCompletionsUpstreamRequestFunc(
account,
mappedModel,
geminiReq,
clientStream,
useUpstreamStream,
)
var resp *http.Response
// Retry loop: the request is rebuilt on every attempt. Each iteration ends
// in a return, a continue (retry), or a break with resp set.
for attempt := 1; attempt <= geminiMaxRetries; attempt++ {
upstreamReq, idHeader, err := buildReq(ctx)
if err != nil {
// Context cancellation/deadline errors are returned without writing a response.
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return nil, err
}
return nil, s.writeChatCompletionsError(c, http.StatusBadGateway, "upstream_error", err.Error())
}
requestIDHeader = idHeader
if c != nil {
c.Set(OpsUpstreamRequestBodyKey, string(geminiReq))
}
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
// Transport-level failure: record it and retry with backoff until the
// attempts are exhausted.
safeErr := sanitizeUpstreamErrorMessage(err.Error())
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
if attempt < geminiMaxRetries {
logger.LegacyPrintf("service.gemini_chat_completions", "Gemini account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, geminiMaxRetries, err)
sleepGeminiBackoff(attempt)
continue
}
setOpsUpstreamError(c, 0, safeErr, "")
return nil, s.writeChatCompletionsError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries: "+safeErr)
}
// When the error policy matches, stop retrying and continue with the
// response it returned; otherwise keep going with that same response.
if matched, rebuilt := s.checkErrorPolicyInLoop(ctx, account, resp); matched {
resp = rebuilt
break
} else {
resp = rebuilt
}
if resp.StatusCode >= 400 && s.shouldRetryGeminiUpstreamError(account, resp.StatusCode) {
// Buffer at most 2 MiB of the error body so it can be inspected and, if
// needed, replayed through a rebuilt response.
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
// A 403 caused by insufficient scope is not retried; it is passed on to
// the error handling after the loop with its body restored.
if resp.StatusCode == http.StatusForbidden && isGeminiInsufficientScope(resp.Header, respBody) {
resp = &http.Response{
StatusCode: resp.StatusCode,
Header: resp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(respBody)),
}
break
}
// NOTE(review): on the final attempt a 429 reaches handleGeminiUpstreamError
// here and again in the >= 400 branch after the loop — confirm that the
// handler tolerates repeated calls.
if resp.StatusCode == http.StatusTooManyRequests {
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
}
if attempt < geminiMaxRetries {
upstreamReqID := resp.Header.Get(requestIDHeader)
if upstreamReqID == "" {
upstreamReqID = resp.Header.Get("x-goog-request-id")
}
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: upstreamReqID,
Kind: "retry",
Message: upstreamMsg,
})
logger.LegacyPrintf("service.gemini_chat_completions", "Gemini account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, geminiMaxRetries)
sleepGeminiBackoff(attempt)
continue
}
// Retries exhausted: restore the consumed body so the error path below
// can read it again.
resp = &http.Response{
StatusCode: resp.StatusCode,
Header: resp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(respBody)),
}
break
}
break
}
defer func() { _ = resp.Body.Close() }()
// Prefer the header named by the request builder, then fall back to
// x-goog-request-id.
requestID := resp.Header.Get(requestIDHeader)
if requestID == "" {
requestID = resp.Header.Get("x-goog-request-id")
}
// NOTE(review): c is nil-checked before c.Set above but used unconditionally
// here and below — confirm callers never pass a nil context.
if requestID != "" {
c.Header("x-request-id", requestID)
}
reasoningEffort := extractCCReasoningEffortFromBody(originalChatBody)
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
evBody := unwrapIfNeeded(account.Type == AccountTypeOAuth, respBody)
// Failover-worthy statuses are returned to the caller without writing a
// response, so another account can be tried.
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
upstreamMsg := sanitizeUpstreamErrorMessage(strings.TrimSpace(extractUpstreamErrorMessage(evBody)))
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: requestID,
Kind: "failover",
Message: upstreamMsg,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: evBody}
}
return nil, s.writeGeminiChatCompletionsMappedError(c, account, resp.StatusCode, requestID, evBody)
}
var usage *ClaudeUsage
var firstTokenMs *int
// Three response paths: stream to the client, collect an upstream SSE into
// a single JSON reply, or convert a plain JSON reply.
if clientStream {
streamRes, err := s.handleChatCompletionsStreamingResponseFromGemini(c, resp, startTime, originalModel, account.Type == AccountTypeOAuth, includeUsage)
if err != nil {
return nil, err
}
usage = streamRes.usage
firstTokenMs = streamRes.firstTokenMs
} else if useUpstreamStream {
collected, usageObj, err := collectGeminiSSE(resp.Body, account.Type == AccountTypeOAuth)
if err != nil {
return nil, s.writeChatCompletionsError(c, http.StatusBadGateway, "upstream_error", "Failed to read upstream stream")
}
collectedBytes, _ := json.Marshal(collected)
chatResp, usageObj2, err := geminiResponseToChatCompletions(collected, originalModel, collectedBytes, usageObj)
if err != nil {
return nil, s.writeChatCompletionsError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response")
}
c.JSON(http.StatusOK, chatResp)
usage = usageObj2
} else {
usageResp, err := s.handleChatCompletionsNonStreamingResponseFromGemini(c, resp, originalModel, account.Type == AccountTypeOAuth)
if err != nil {
return nil, err
}
usage = usageResp
}
// Guarantee a non-nil usage so it can be dereferenced in the result.
if usage == nil {
usage = &ClaudeUsage{}
}
// imageCount is 1 whenever originalModel is an image generation model.
imageCount := 0
imageSize := s.extractImageSize(claudeBody)
if isImageGenerationModel(originalModel) {
imageCount = 1
}
return &ForwardResult{
RequestID: requestID,
Usage: *usage,
Model: originalModel,
UpstreamModel: mappedModel,
Stream: clientStream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
ReasoningEffort: reasoningEffort,
ImageCount: imageCount,
ImageSize: imageSize,
ClientDisconnect: false,
}, nil
}
// buildGeminiChatCompletionsUpstreamRequestFunc returns a factory that builds
// the upstream Gemini HTTP request for account, together with the name of the
// response header expected to carry the upstream request ID ("x-request-id"
// for every account type).
//
// The factory is meant to be called once per attempt, so credentials and
// tokens are resolved again on every call:
//   - API key accounts call the v1beta REST endpoint with x-goog-api-key;
//     streaming follows clientStream.
//   - OAuth accounts with a project_id call the v1internal endpoint with the
//     request wrapped in a {model, project, request} envelope; without a
//     project_id they call the v1beta REST endpoint with a bearer token.
//     Streaming follows useUpstreamStream in both cases.
//   - Service accounts call the URL produced by buildVertexGeminiURL with a
//     bearer token; streaming follows clientStream.
//   - Any other account type yields a factory that always fails.
func (s *GeminiMessagesCompatService) buildGeminiChatCompletionsUpstreamRequestFunc(
	account *Account,
	mappedModel string,
	geminiReq []byte,
	clientStream bool,
	useUpstreamStream bool,
) (func(context.Context) (*http.Request, string, error), string) {
	const requestIDHeader = "x-request-id"

	// actionFor returns the Gemini method name for a streaming or unary call.
	actionFor := func(stream bool) string {
		if stream {
			return "streamGenerateContent"
		}
		return "generateContent"
	}

	// newJSONPost creates a POST request with a JSON content type.
	newJSONPost := func(ctx context.Context, target string, payload []byte) (*http.Request, error) {
		req, err := http.NewRequestWithContext(ctx, http.MethodPost, target, bytes.NewReader(payload))
		if err != nil {
			return nil, err
		}
		req.Header.Set("Content-Type", "application/json")
		return req, nil
	}

	switch account.Type {
	case AccountTypeAPIKey:
		return func(ctx context.Context) (*http.Request, string, error) {
			apiKey := account.GetCredential("api_key")
			if strings.TrimSpace(apiKey) == "" {
				return nil, "", errors.New("gemini api_key not configured")
			}
			baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL)
			normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
			if err != nil {
				return nil, "", err
			}
			fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, actionFor(clientStream))
			if clientStream {
				fullURL += "?alt=sse"
			}
			upstreamReq, err := newJSONPost(ctx, fullURL, normalizeGeminiRequestForAIStudio(geminiReq))
			if err != nil {
				return nil, "", err
			}
			upstreamReq.Header.Set("x-goog-api-key", apiKey)
			return upstreamReq, requestIDHeader, nil
		}, requestIDHeader

	case AccountTypeOAuth:
		return func(ctx context.Context) (*http.Request, string, error) {
			if s.tokenProvider == nil {
				return nil, "", errors.New("gemini token provider not configured")
			}
			accessToken, err := s.tokenProvider.GetAccessToken(ctx, account)
			if err != nil {
				return nil, "", err
			}
			projectID := strings.TrimSpace(account.GetCredential("project_id"))
			action := actionFor(useUpstreamStream)

			if projectID != "" {
				// Project-bound OAuth accounts use the v1internal endpoint, which
				// takes the model and project in the body rather than the URL.
				baseURL, err := s.validateUpstreamBaseURL(geminicli.GeminiCliBaseURL)
				if err != nil {
					return nil, "", err
				}
				fullURL := fmt.Sprintf("%s/v1internal:%s", strings.TrimRight(baseURL, "/"), action)
				if useUpstreamStream {
					fullURL += "?alt=sse"
				}
				var inner any
				if err := json.Unmarshal(geminiReq, &inner); err != nil {
					return nil, "", fmt.Errorf("failed to parse gemini request: %w", err)
				}
				wrappedBytes, err := json.Marshal(map[string]any{
					"model":   mappedModel,
					"project": projectID,
					"request": inner,
				})
				if err != nil {
					// Previously discarded; surface it instead of sending a
					// request with an empty body.
					return nil, "", fmt.Errorf("failed to wrap gemini request: %w", err)
				}
				upstreamReq, err := newJSONPost(ctx, fullURL, wrappedBytes)
				if err != nil {
					return nil, "", err
				}
				upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
				upstreamReq.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent)
				return upstreamReq, requestIDHeader, nil
			}

			baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL)
			normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
			if err != nil {
				return nil, "", err
			}
			fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, action)
			if useUpstreamStream {
				fullURL += "?alt=sse"
			}
			upstreamReq, err := newJSONPost(ctx, fullURL, normalizeGeminiRequestForAIStudio(geminiReq))
			if err != nil {
				return nil, "", err
			}
			upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
			return upstreamReq, requestIDHeader, nil
		}, requestIDHeader

	case AccountTypeServiceAccount:
		return func(ctx context.Context) (*http.Request, string, error) {
			if s.tokenProvider == nil {
				return nil, "", errors.New("gemini token provider not configured")
			}
			accessToken, err := s.tokenProvider.GetAccessToken(ctx, account)
			if err != nil {
				return nil, "", err
			}
			fullURL, err := buildVertexGeminiURL(account.VertexProjectID(), account.VertexLocation(mappedModel), mappedModel, actionFor(clientStream), clientStream)
			if err != nil {
				return nil, "", err
			}
			upstreamReq, err := newJSONPost(ctx, fullURL, normalizeGeminiRequestForAIStudio(geminiReq))
			if err != nil {
				return nil, "", err
			}
			upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
			return upstreamReq, requestIDHeader, nil
		}, requestIDHeader

	default:
		return func(context.Context) (*http.Request, string, error) {
			return nil, "", fmt.Errorf("unsupported account type: %s", account.Type)
		}, requestIDHeader
	}
}
// handleChatCompletionsNonStreamingResponseFromGemini reads a complete Gemini
// JSON reply, converts it to a Chat Completions response, and writes it to the
// client with status 200 after copying the filtered upstream headers.
//
// For OAuth accounts the body is unwrapped first; if unwrapping fails the raw
// body is used as-is. A body that cannot be parsed or converted is answered
// with a 502 "upstream_error". A read failure is returned unchanged.
func (s *GeminiMessagesCompatService) handleChatCompletionsNonStreamingResponseFromGemini(
	c *gin.Context,
	resp *http.Response,
	originalModel string,
	isOAuth bool,
) (*ClaudeUsage, error) {
	// parseFailure reports an unusable upstream body to the client.
	parseFailure := func() (*ClaudeUsage, error) {
		return nil, s.writeChatCompletionsError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response")
	}

	payload, readErr := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError)
	if readErr != nil {
		return nil, readErr
	}

	if isOAuth {
		inner, unwrapErr := unwrapGeminiResponse(payload)
		if unwrapErr == nil {
			payload = inner
		}
	}

	var decoded map[string]any
	if jsonErr := json.Unmarshal(payload, &decoded); jsonErr != nil {
		return parseFailure()
	}

	converted, usage, convErr := geminiResponseToChatCompletions(decoded, originalModel, payload, nil)
	if convErr != nil {
		return parseFailure()
	}

	responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
	c.JSON(http.StatusOK, converted)
	return usage, nil
}
// geminiResponseToChatCompletions converts a decoded Gemini response into a
// Chat Completions response by going Gemini -> Claude message -> Anthropic
// response -> Responses -> Chat Completions.
//
// usageOverride replaces the usage derived from the response, both in the
// returned value and in the intermediate message's "usage" map, but only when
// at least one of its input, output, or cache-read token counts is positive.
func geminiResponseToChatCompletions(
	geminiResp map[string]any,
	originalModel string,
	rawData []byte,
	usageOverride *ClaudeUsage,
) (*apicompat.ChatCompletionsResponse, *ClaudeUsage, error) {
	claudeMsg, usage := convertGeminiToClaudeMessage(geminiResp, originalModel, rawData)

	overrideHasTokens := usageOverride != nil &&
		(usageOverride.InputTokens > 0 ||
			usageOverride.OutputTokens > 0 ||
			usageOverride.CacheReadInputTokens > 0)
	if overrideHasTokens {
		usage = usageOverride
		if reported, isMap := claudeMsg["usage"].(map[string]any); isMap {
			reported["input_tokens"] = usageOverride.InputTokens
			reported["output_tokens"] = usageOverride.OutputTokens
			reported["cache_read_input_tokens"] = usageOverride.CacheReadInputTokens
		}
	}

	// Round-trip through JSON to turn the generic map into the typed response.
	encoded, err := json.Marshal(claudeMsg)
	if err != nil {
		return nil, nil, err
	}
	typed := new(apicompat.AnthropicResponse)
	if err := json.Unmarshal(encoded, typed); err != nil {
		return nil, nil, err
	}

	asResponses := apicompat.AnthropicToResponsesResponse(typed)
	chatResp := apicompat.ResponsesToChatCompletions(asResponses, originalModel)
	return chatResp, usage, nil
}
// handleChatCompletionsStreamingResponseFromGemini relays a Gemini SSE stream
// to the client as Chat Completions SSE chunks.
//
// Each Gemini payload is turned into synthetic Anthropic stream events
// (message_start, content_block_start/delta/stop, message_delta,
// message_stop), which are then converted Anthropic -> Responses -> Chat
// Completions chunks via the apicompat state machines.
//
// Returns the usage seen so far and the time to the first parsed payload.
// When writing to the client fails, the partial result is returned with a nil
// error. An upstream read error other than io.EOF is returned as an error.
func (s *GeminiMessagesCompatService) handleChatCompletionsStreamingResponseFromGemini(
c *gin.Context,
resp *http.Response,
startTime time.Time,
originalModel string,
isOAuth bool,
includeUsage bool,
) (*geminiStreamResult, error) {
if s.responseHeaderFilter != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
}
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.WriteHeader(http.StatusOK)
flusher, ok := c.Writer.(http.Flusher)
if !ok {
return nil, errors.New("streaming not supported")
}
// Two converter states are chained: Anthropic events -> Responses events,
// then Responses events -> Chat Completions chunks.
anthState := apicompat.NewAnthropicEventToResponsesState()
anthState.Model = originalModel
ccState := apicompat.NewResponsesEventToChatState()
ccState.Model = originalModel
ccState.IncludeUsage = includeUsage
var usage ClaudeUsage
var firstTokenMs *int
firstChunk := true
// writeChatChunk reports true only when the write to the client fails.
// A chunk that cannot be serialized is skipped and does not end the stream.
writeChatChunk := func(chunk apicompat.ChatCompletionsChunk) bool {
sse, err := apicompat.ChatChunkToSSE(chunk)
if err != nil {
return false
}
if _, err := io.WriteString(c.Writer, sse); err != nil {
return true
}
return false
}
// emitAnthropicEvent pushes one synthetic Anthropic event through both
// converters, writes the resulting chunks, and flushes. It returns true when
// a client write failed.
emitAnthropicEvent := func(evt *apicompat.AnthropicStreamEvent) bool {
responsesEvents := apicompat.AnthropicEventToResponsesEvents(evt, anthState)
for _, resEvt := range responsesEvents {
chunks := apicompat.ResponsesEventToChatChunks(&resEvt, ccState)
for _, chunk := range chunks {
if disconnected := writeChatChunk(chunk); disconnected {
return true
}
}
}
flusher.Flush()
return false
}
messageID := "msg_" + randomHex(12)
if emitAnthropicEvent(&apicompat.AnthropicStreamEvent{
Type: "message_start",
Message: &apicompat.AnthropicResponse{
ID: messageID,
Type: "message",
Role: "assistant",
Model: originalModel,
Content: []apicompat.AnthropicContentBlock{},
Usage: apicompat.AnthropicUsage{},
},
}) {
return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil
}
// Stream state. Text blocks and tool-use blocks are tracked separately:
// openBlock* covers the current text block, openTool* the current tool call.
// seenText / seenToolJSON hold what has already been forwarded so that only
// new content is emitted.
finishReason := ""
sawToolUse := false
nextBlockIndex := 0
openBlockIndex := -1
openBlockType := ""
seenText := ""
openToolIndex := -1
openToolName := ""
seenToolJSON := ""
// closeOpenBlock ends the current text block, if any; returns true on disconnect.
closeOpenBlock := func() bool {
if openBlockIndex < 0 {
return false
}
disconnected := emitAnthropicEvent(&apicompat.AnthropicStreamEvent{Type: "content_block_stop"})
openBlockIndex = -1
openBlockType = ""
return disconnected
}
// closeOpenTool ends the current tool-use block, if any, and resets the
// accumulated tool arguments; returns true on disconnect.
closeOpenTool := func() bool {
if openToolIndex < 0 {
return false
}
disconnected := emitAnthropicEvent(&apicompat.AnthropicStreamEvent{Type: "content_block_stop"})
openToolIndex = -1
openToolName = ""
seenToolJSON = ""
return disconnected
}
reader := bufio.NewReader(resp.Body)
for {
// The line is processed before err is inspected, so a final line without
// a trailing newline is still handled.
line, err := reader.ReadString('\n')
if len(line) > 0 {
trimmed := strings.TrimRight(line, "\r\n")
// Only SSE "data:" lines carry payloads; everything else is ignored.
if strings.HasPrefix(trimmed, "data:") {
payload := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:"))
if payload != "" && payload != "[DONE]" {
rawBytes := []byte(payload)
// For OAuth accounts the payload is unwrapped; if unwrapping fails the
// payload is used as-is.
if isOAuth {
if innerBytes, uwErr := unwrapGeminiResponse(rawBytes); uwErr == nil {
rawBytes = innerBytes
}
}
var geminiResp map[string]any
// Payloads that are not valid JSON are skipped silently.
if err := json.Unmarshal(rawBytes, &geminiResp); err == nil {
// First-token latency is measured at the first parsed payload.
if firstChunk {
firstChunk = false
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
if fr := extractGeminiFinishReason(geminiResp); fr != "" {
finishReason = fr
}
// The most recent usage report replaces any earlier one.
if u := extractGeminiUsage(rawBytes); u != nil {
usage = *u
}
for _, part := range extractGeminiParts(geminiResp) {
if text, ok := part["text"].(string); ok && text != "" {
// Text after a tool call closes the tool block first.
if openToolIndex >= 0 {
if closeOpenTool() {
return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil
}
}
// presumably returns the not-yet-forwarded part of text plus the new
// accumulated value — verify against computeGeminiTextDelta.
delta, newSeen := computeGeminiTextDelta(seenText, text)
seenText = newSeen
if delta == "" {
continue
}
// Open a text block lazily, only once there is something to emit.
if openBlockType != "text" {
if closeOpenBlock() {
return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil
}
idx := nextBlockIndex
nextBlockIndex++
openBlockIndex = idx
openBlockType = "text"
if emitAnthropicEvent(&apicompat.AnthropicStreamEvent{
Type: "content_block_start",
Index: &idx,
ContentBlock: &apicompat.AnthropicContentBlock{
Type: "text",
Text: "",
},
}) {
return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil
}
}
if emitAnthropicEvent(&apicompat.AnthropicStreamEvent{
Type: "content_block_delta",
Delta: &apicompat.AnthropicDelta{
Type: "text_delta",
Text: delta,
},
}) {
return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil
}
continue
}
if fc, ok := part["functionCall"].(map[string]any); ok && fc != nil {
// A missing or blank function name falls back to "tool".
name, _ := fc["name"].(string)
if strings.TrimSpace(name) == "" {
name = "tool"
}
if closeOpenBlock() {
return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil
}
// A call with a different name starts a new tool block; a call with the
// same name continues the one that is already open.
if openToolIndex >= 0 && openToolName != name {
if closeOpenTool() {
return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil
}
}
if openToolIndex < 0 {
idx := nextBlockIndex
nextBlockIndex++
openToolIndex = idx
openToolName = name
sawToolUse = true
if emitAnthropicEvent(&apicompat.AnthropicStreamEvent{
Type: "content_block_start",
Index: &idx,
ContentBlock: &apicompat.AnthropicContentBlock{
Type: "tool_use",
ID: "toolu_" + randomHex(8),
Name: name,
Input: json.RawMessage(`{}`),
},
}) {
return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil
}
}
// args may be absent, a JSON string, or a decoded value; all forms are
// normalized to JSON text, defaulting to "{}".
argsJSONText := "{}"
switch v := fc["args"].(type) {
case nil:
case string:
if strings.TrimSpace(v) != "" {
argsJSONText = v
}
default:
if b, err := json.Marshal(v); err == nil && len(b) > 0 {
argsJSONText = string(b)
}
}
delta, newSeen := computeGeminiTextDelta(seenToolJSON, argsJSONText)
seenToolJSON = newSeen
if delta != "" {
if emitAnthropicEvent(&apicompat.AnthropicStreamEvent{
Type: "content_block_delta",
Delta: &apicompat.AnthropicDelta{
Type: "input_json_delta",
PartialJSON: delta,
},
}) {
return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil
}
}
}
}
}
}
}
}
if errors.Is(err, io.EOF) {
break
}
if err != nil {
return nil, fmt.Errorf("stream read error: %w", err)
}
}
// Upstream finished: close whatever block is still open, then emit the
// terminal events.
if closeOpenBlock() {
return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil
}
if closeOpenTool() {
return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil
}
// Any tool call overrides the mapped finish reason with "tool_use".
stopReason := mapGeminiFinishReasonToClaudeStopReason(finishReason)
if sawToolUse {
stopReason = "tool_use"
}
anthState.InputTokens = usage.InputTokens
anthState.CacheReadInputTokens = usage.CacheReadInputTokens
if emitAnthropicEvent(&apicompat.AnthropicStreamEvent{
Type: "message_delta",
Delta: &apicompat.AnthropicDelta{
Type: "message_delta",
StopReason: stopReason,
},
Usage: &apicompat.AnthropicUsage{
InputTokens: usage.InputTokens,
OutputTokens: usage.OutputTokens,
CacheReadInputTokens: usage.CacheReadInputTokens,
},
}) {
return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil
}
if emitAnthropicEvent(&apicompat.AnthropicStreamEvent{Type: "message_stop"}) {
return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil
}
// Drain whatever both converters still hold before the [DONE] sentinel.
for _, resEvt := range apicompat.FinalizeAnthropicResponsesStream(anthState) {
chunks := apicompat.ResponsesEventToChatChunks(&resEvt, ccState)
for _, chunk := range chunks {
if disconnected := writeChatChunk(chunk); disconnected {
return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil
}
}
}
for _, chunk := range apicompat.FinalizeResponsesChatStream(ccState) {
if disconnected := writeChatChunk(chunk); disconnected {
return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil
}
}
// The write error is deliberately ignored: the stream is over either way.
_, _ = io.WriteString(c.Writer, "data: [DONE]\n\n")
flusher.Flush()
return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil
}
// writeGeminiChatCompletionsMappedError translates an upstream Gemini error
// into a Chat Completions error response and returns it as an error.
//
// The upstream failure is first recorded for ops (plus an "http_error" event
// when account is non-nil). A matching error passthrough rule decides the
// response outright. Otherwise the response starts as a generic 502
// "upstream_error", is refined by the error mapped from the body, then by the
// upstream status (400, 404, 429, 529), and finally falls back to the
// sanitized upstream message if the message is still the generic one.
func (s *GeminiMessagesCompatService) writeGeminiChatCompletionsMappedError(
	c *gin.Context,
	account *Account,
	upstreamStatus int,
	upstreamRequestID string,
	body []byte,
) error {
	const (
		genericType = "upstream_error"
		genericMsg  = "Upstream request failed"
	)

	upstreamMsg := sanitizeUpstreamErrorMessage(strings.TrimSpace(extractUpstreamErrorMessage(body)))
	setOpsUpstreamError(c, upstreamStatus, upstreamMsg, "")
	if account != nil {
		appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
			Platform:           account.Platform,
			AccountID:          account.ID,
			AccountName:        account.Name,
			UpstreamStatusCode: upstreamStatus,
			UpstreamRequestID:  upstreamRequestID,
			Kind:               "http_error",
			Message:            upstreamMsg,
		})
	}

	ruleStatus, ruleType, ruleMsg, ruleMatched := applyErrorPassthroughRule(
		c,
		PlatformGemini,
		upstreamStatus,
		body,
		http.StatusBadGateway,
		genericType,
		genericMsg,
	)
	if ruleMatched {
		return s.writeChatCompletionsError(c, ruleStatus, ruleType, ruleMsg)
	}

	statusCode, errType, errMsg := http.StatusBadGateway, genericType, genericMsg
	if mapped := mapGeminiErrorBodyToClaudeError(body); mapped != nil {
		if mapped.Type != "" {
			errType = mapped.Type
		}
		if mapped.Message != "" {
			errMsg = mapped.Message
		}
		if mapped.StatusCode > 0 {
			statusCode = mapped.StatusCode
		}
	}

	// refine swaps in a status-specific type and message, but only where the
	// generic placeholders are still in place.
	refine := func(specificType, specificMsg string) {
		if errType == genericType {
			errType = specificType
		}
		if errMsg == genericMsg {
			errMsg = specificMsg
		}
	}

	switch upstreamStatus {
	case http.StatusBadRequest:
		// Unlike the other cases, a status chosen by the body mapping wins here.
		if statusCode == http.StatusBadGateway {
			statusCode = http.StatusBadRequest
		}
		refine("invalid_request_error", "Invalid request")
	case http.StatusNotFound:
		statusCode = http.StatusNotFound
		refine("not_found_error", "Resource not found")
	case http.StatusTooManyRequests:
		statusCode = http.StatusTooManyRequests
		refine("rate_limit_error", "Upstream rate limit exceeded, please retry later")
	case 529:
		statusCode = http.StatusServiceUnavailable
		refine("overloaded_error", "Upstream service overloaded, please retry later")
	}

	if upstreamMsg != "" && errMsg == genericMsg {
		errMsg = upstreamMsg
	}
	return s.writeChatCompletionsError(c, statusCode, errType, errMsg)
}
// writeChatCompletionsError writes a JSON body of the shape
// {"error": {"type": errType, "message": message}} to the client with the
// given HTTP status, then returns an error whose text is that same message so
// the caller can propagate the failure.
func (s *GeminiMessagesCompatService) writeChatCompletionsError(c *gin.Context, status int, errType, message string) error {
	detail := gin.H{
		"type":    errType,
		"message": message,
	}
	c.JSON(status, gin.H{"error": detail})

	// The "%s" verb keeps any '%' inside message from being treated as a
	// formatting directive.
	return fmt.Errorf("%s", message)
}

View File

@ -1,6 +1,7 @@
package service
import (
"bytes"
"context"
"encoding/json"
"fmt"
@ -41,6 +42,134 @@ func (s *geminiCompatHTTPUpstreamStub) DoWithTLS(req *http.Request, proxyURL str
return s.Do(req, proxyURL, accountID, accountConcurrency)
}
// TestGeminiForwardAsChatCompletions_OAuthRoutesToGeminiAndReturnsChatFormat
// calls ForwardAsChatCompletions with an OAuth Gemini account against a stubbed
// upstream that replies with an SSE stream. It checks three things: the
// returned result, the request that reached the upstream stub, and the
// non-streaming chat.completion JSON written to the client.
func TestGeminiForwardAsChatCompletions_OAuthRoutesToGeminiAndReturnsChatFormat(t *testing.T) {
	gin.SetMode(gin.TestMode)

	// --- Arrange -----------------------------------------------------------

	// Canned upstream reply: one event wrapped in a "response" object,
	// followed by the [DONE] sentinel.
	const geminiEvent = `data: {"response":{"candidates":[{"content":{"parts":[{"text":"hello from gemini"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":7,"candidatesTokenCount":3}}}`
	ssePayload := geminiEvent + "\n\n" + "data: [DONE]\n\n"

	cannedResp := &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"text/event-stream"}},
		Body:       io.NopCloser(strings.NewReader(ssePayload)),
	}
	upstream := &geminiCompatHTTPUpstreamStub{response: cannedResp}

	compat := &GeminiMessagesCompatService{
		tokenProvider: &GeminiTokenProvider{},
		httpUpstream:  upstream,
		cfg:           &config.Config{},
	}

	oauthAccount := &Account{
		ID:       101,
		Platform: PlatformGemini,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "ya29.test-token",
			"project_id":   "project-1",
		},
		Concurrency: 1,
	}

	reqBody := []byte(`{"model":"gemini-2.5-flash","messages":[{"role":"user","content":"hi"}]}`)
	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(reqBody))

	// --- Act ---------------------------------------------------------------

	result, err := compat.ForwardAsChatCompletions(context.Background(), ginCtx, oauthAccount, reqBody)

	// --- Assert: returned result -------------------------------------------

	require.NoError(t, err)
	require.NotNil(t, result)
	require.Equal(t, http.StatusOK, recorder.Code)
	require.Equal(t, "gemini-2.5-flash", result.Model)
	require.Equal(t, 7, result.Usage.InputTokens)
	require.Equal(t, 3, result.Usage.OutputTokens)

	// --- Assert: request captured by the upstream stub ---------------------

	outgoing := upstream.lastReq
	require.NotNil(t, outgoing)
	require.Contains(t, outgoing.URL.String(), "/v1internal:streamGenerateContent?alt=sse")
	require.Equal(t, "Bearer ya29.test-token", outgoing.Header.Get("Authorization"))
	// These two headers must not be present on the Gemini-bound request.
	for _, absent := range []string{"x-api-key", "anthropic-version"} {
		require.Empty(t, outgoing.Header.Get(absent))
	}

	rawSent, readErr := io.ReadAll(outgoing.Body)
	require.NoError(t, readErr)
	var sent map[string]any
	require.NoError(t, json.Unmarshal(rawSent, &sent))
	require.Equal(t, "gemini-2.5-flash", sent["model"])
	require.Equal(t, "project-1", sent["project"])
	require.Contains(t, fmt.Sprint(sent["request"]), "hi")

	// --- Assert: chat.completion body written to the client ----------------

	var decoded map[string]any
	require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &decoded))
	require.Equal(t, "chat.completion", decoded["object"])
	require.Equal(t, "gemini-2.5-flash", decoded["model"])

	choiceList, isList := decoded["choices"].([]any)
	require.True(t, isList)
	require.NotEmpty(t, choiceList)
	firstChoice, isObject := choiceList[0].(map[string]any)
	require.True(t, isObject)
	assistantMsg, hasMessage := firstChoice["message"].(map[string]any)
	require.True(t, hasMessage)
	require.Equal(t, "assistant", assistantMsg["role"])
	require.Equal(t, "hello from gemini", assistantMsg["content"])

	usageObj, hasUsage := decoded["usage"].(map[string]any)
	require.True(t, hasUsage)
	// encoding/json decodes JSON numbers into float64 when the target is any.
	wantUsage := []struct {
		key   string
		value float64
	}{
		{"prompt_tokens", 7},
		{"completion_tokens", 3},
		{"total_tokens", 10},
	}
	for _, w := range wantUsage {
		require.Equal(t, w.value, usageObj[w.key])
	}
}
// TestGeminiForwardAsChatCompletions_StreamsOpenAIChunksFromGeminiSSE calls
// ForwardAsChatCompletions with an API-key Gemini account and a request that
// sets "stream": true. It checks the request that reached the upstream stub
// and that the recorded output contains chat.completion.chunk events, a usage
// object, and the closing [DONE] marker.
func TestGeminiForwardAsChatCompletions_StreamsOpenAIChunksFromGeminiSSE(t *testing.T) {
	gin.SetMode(gin.TestMode)

	// --- Arrange -----------------------------------------------------------

	// Two upstream events carrying the texts "hel" and "hello", then [DONE].
	// Every frame is terminated by a blank line.
	frames := []string{
		`data: {"candidates":[{"content":{"parts":[{"text":"hel"}]}}],"usageMetadata":{"promptTokenCount":2,"candidatesTokenCount":1}}`,
		`data: {"candidates":[{"content":{"parts":[{"text":"hello"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":2,"candidatesTokenCount":2}}`,
		"data: [DONE]",
	}
	ssePayload := strings.Join(frames, "\n\n") + "\n\n"

	cannedResp := &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"text/event-stream"}},
		Body:       io.NopCloser(strings.NewReader(ssePayload)),
	}
	upstream := &geminiCompatHTTPUpstreamStub{response: cannedResp}

	compat := &GeminiMessagesCompatService{
		httpUpstream: upstream,
		cfg:          &config.Config{},
	}

	apiKeyAccount := &Account{
		ID:          102,
		Platform:    PlatformGemini,
		Type:        AccountTypeAPIKey,
		Credentials: map[string]any{"api_key": "gemini-api-key"},
		Concurrency: 1,
	}

	reqBody := []byte(`{"model":"gemini-2.5-flash","stream":true,"stream_options":{"include_usage":true},"messages":[{"role":"user","content":"hi"}]}`)
	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(reqBody))

	// --- Act ---------------------------------------------------------------

	result, err := compat.ForwardAsChatCompletions(context.Background(), ginCtx, apiKeyAccount, reqBody)

	// --- Assert: returned result -------------------------------------------

	require.NoError(t, err)
	require.NotNil(t, result)
	require.Equal(t, http.StatusOK, recorder.Code)
	require.True(t, result.Stream)
	require.Equal(t, 2, result.Usage.InputTokens)
	require.Equal(t, 2, result.Usage.OutputTokens)

	// --- Assert: request captured by the upstream stub ---------------------

	outgoing := upstream.lastReq
	require.NotNil(t, outgoing)
	require.Contains(t, outgoing.URL.String(), "/v1beta/models/gemini-2.5-flash:streamGenerateContent?alt=sse")
	require.Equal(t, "gemini-api-key", outgoing.Header.Get("x-goog-api-key"))

	// --- Assert: streamed output written to the client ---------------------

	// The expected contents are "hel" and "lo" although the upstream sent
	// "hel" and "hello"; presumably the service emits only the newly added
	// suffix of each upstream text (confirm against the implementation).
	streamed := recorder.Body.String()
	wantFragments := []string{
		`"object":"chat.completion.chunk"`,
		`"role":"assistant"`,
		`"content":"hel"`,
		`"content":"lo"`,
		`"usage":{"prompt_tokens":2,"completion_tokens":2,"total_tokens":4}`,
		"data: [DONE]",
	}
	for _, fragment := range wantFragments {
		require.Contains(t, streamed, fragment)
	}
}
// TestConvertClaudeToolsToGeminiTools_CustomType 测试custom类型工具转换
func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) {
tests := []struct {