sub2api/backend/internal/service/openai_gateway_responses_chat_fallback.go
Wesley Liddick bbe847ed3e
Merge pull request #2805 from StarryKira/feat/configurable-pool-retry-status-codes
feat(account): configurable pool-mode same-account retry status codes
2026-05-27 22:09:55 +08:00

428 lines
13 KiB
Go

package service
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// forwardResponsesViaRawChatCompletions serves /v1/responses clients through an
// upstream that only supports /v1/chat/completions.
//
// Flow:
//  1. Decode body as an apicompat.ResponsesRequest and convert it with
//     apicompat.ResponsesToChatCompletionsRequest.
//  2. Resolve the billing model (resolveOpenAIForwardModel) and the model name
//     actually sent upstream (normalizeOpenAIModelForUpstream).
//  3. Apply the account's fast policy to the marshalled chat body.
//  4. POST the body to the account's chat completions URL, using the account
//     API key, base URL (default https://api.openai.com), proxy and optional
//     custom User-Agent.
//  5. Translate the reply back to Responses format. It is streamed as SSE when
//     the client asked for "stream", otherwise it is buffered into one JSON body.
//
// Error handling (as visible in this function):
//   - Unparseable JSON, an empty model, or a failed conversion writes a 400
//     invalid_request_error to c and returns an error.
//   - A fast-policy *OpenAIFastBlockedError is rendered with
//     writeOpenAIFastPolicyBlockedResponse. Any other fast-policy error, a
//     marshal failure, a missing api_key, or an invalid base_url returns an
//     error without writing to c.
//     NOTE(review): presumably the caller renders these — confirm.
//   - A transport error from httpUpstream.Do records an ops upstream error,
//     writes a 502 upstream_error to c and returns an error.
//   - An upstream status >= 400 that shouldFailoverOpenAIUpstreamResponse
//     accepts returns *UpstreamFailoverError and does not write to c. Other
//     >= 400 statuses are delegated to handleErrorResponse.
func (s *OpenAIGatewayService) forwardResponsesViaRawChatCompletions(
ctx context.Context,
c *gin.Context,
account *Account,
body []byte,
) (*OpenAIForwardResult, error) {
startTime := time.Now()
var responsesReq apicompat.ResponsesRequest
if err := json.Unmarshal(body, &responsesReq); err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
"type": "invalid_request_error",
"message": "Failed to parse request body",
},
})
return nil, fmt.Errorf("parse responses request: %w", err)
}
originalModel := strings.TrimSpace(responsesReq.Model)
if originalModel == "" {
c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
"type": "invalid_request_error",
"message": "model is required",
},
})
return nil, fmt.Errorf("missing model in request")
}
clientStream := responsesReq.Stream
// Reasoning effort and service tier are read from the client's original
// Responses body so they can be reported on the forward result.
reasoningEffort := extractOpenAIReasoningEffortFromBody(body, originalModel)
serviceTier := extractOpenAIServiceTierFromBody(body)
chatReq, err := apicompat.ResponsesToChatCompletionsRequest(&responsesReq)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
"type": "invalid_request_error",
"message": err.Error(),
},
})
return nil, fmt.Errorf("convert responses to chat completions: %w", err)
}
// billingModel is what the result reports for billing. upstreamModel is the
// name actually placed in the outgoing chat request.
billingModel := resolveOpenAIForwardModel(account, originalModel, "")
upstreamModel := normalizeOpenAIModelForUpstream(account, billingModel)
chatReq.Model = upstreamModel
if clientStream {
// Ask the upstream to include usage in the stream; the streaming path reads
// it via extractCCStreamUsage.
chatReq.StreamOptions = &apicompat.ChatStreamOptions{IncludeUsage: true}
}
chatBody, err := json.Marshal(chatReq)
if err != nil {
return nil, fmt.Errorf("marshal chat completions fallback request: %w", err)
}
chatBody, err = s.applyOpenAIFastPolicyToBody(ctx, account, upstreamModel, chatBody)
if err != nil {
var blocked *OpenAIFastBlockedError
if errors.As(err, &blocked) {
writeOpenAIFastPolicyBlockedResponse(c, blocked)
}
return nil, err
}
// If the client did not specify a tier, fall back to whatever the
// policy-adjusted chat body carries.
if serviceTier == nil {
serviceTier = extractOpenAIServiceTierFromBody(chatBody)
}
logger.L().Debug("openai responses: forwarding via raw chat completions",
zap.Int64("account_id", account.ID),
zap.String("original_model", originalModel),
zap.String("billing_model", billingModel),
zap.String("upstream_model", upstreamModel),
zap.Bool("stream", clientStream),
)
apiKey := account.GetOpenAIApiKey()
if apiKey == "" {
return nil, fmt.Errorf("account %d missing api_key", account.ID)
}
baseURL := account.GetOpenAIBaseURL()
if baseURL == "" {
baseURL = "https://api.openai.com"
}
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, fmt.Errorf("invalid base_url: %w", err)
}
targetURL := buildOpenAIChatCompletionsURL(validatedURL)
// NOTE(review): releaseUpstreamCtx runs right after the request is built and
// before httpUpstream.Do sends it. This is only safe if the release func does
// not cancel upstreamCtx. Confirm against detachUpstreamContext.
upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
upstreamReq, err := http.NewRequestWithContext(upstreamCtx, http.MethodPost, targetURL, bytes.NewReader(chatBody))
releaseUpstreamCtx()
if err != nil {
return nil, fmt.Errorf("build upstream request: %w", err)
}
upstreamReq = upstreamReq.WithContext(WithHTTPUpstreamProfile(upstreamReq.Context(), HTTPUpstreamProfileOpenAI))
upstreamReq.Header.Set("Content-Type", "application/json")
upstreamReq.Header.Set("Authorization", "Bearer "+apiKey)
if clientStream {
upstreamReq.Header.Set("Accept", "text/event-stream")
} else {
upstreamReq.Header.Set("Accept", "application/json")
}
// Forward only client headers in openaiCCRawAllowedHeaders (matched
// lowercased). Add appends, so any allowlisted header that was already Set
// above would end up with extra values. Order matters here.
for key, values := range c.Request.Header {
lowerKey := strings.ToLower(key)
if openaiCCRawAllowedHeaders[lowerKey] {
for _, v := range values {
upstreamReq.Header.Add(key, v)
}
}
}
// An account-level User-Agent replaces any forwarded user-agent values.
if customUA := account.GetOpenAIUserAgent(); customUA != "" {
upstreamReq.Header.Set("user-agent", customUA)
}
proxyURL := ""
if account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
safeErr := sanitizeUpstreamErrorMessage(err.Error())
setOpsUpstreamError(c, 0, safeErr, "")
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
c.JSON(http.StatusBadGateway, gin.H{
"error": gin.H{
"type": "upstream_error",
"message": "Upstream request failed",
},
})
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode >= 400 {
// Read at most 2 MiB of the error body. Then swap resp.Body for an
// in-memory reader so handleErrorResponse can read the same bytes again.
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) {
// The raw body is attached to the ops event only when
// Gateway.LogUpstreamErrorBody is on, truncated to the configured limit
// (2048 bytes when unset or non-positive).
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(respBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody, upstreamModel)
// Same-account retry is allowed only for pool-mode accounts, and only when
// the status is configured as retryable or looks like a transient
// processing error.
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
RetryableOnSameAccount: account.IsPoolMode() && (account.IsPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)),
}
}
return s.handleErrorResponse(ctx, resp, c, account, chatBody, billingModel)
}
// Success: translate the chat completions reply back into Responses format.
if clientStream {
return s.streamChatCompletionsAsResponses(c, resp, originalModel, billingModel, upstreamModel, reasoningEffort, serviceTier, startTime)
}
return s.bufferChatCompletionsAsResponses(c, resp, originalModel, billingModel, upstreamModel, reasoningEffort, serviceTier, startTime)
}
// bufferChatCompletionsAsResponses turns a non-streaming upstream Chat
// Completions reply into one Responses API JSON body for the client.
//
// It reads the upstream body with ReadUpstreamResponseBody and decodes it as an
// apicompat.ChatCompletionsResponse. The decoded reply is translated with
// apicompat.ChatCompletionsResponseToResponses and written with status 200.
// Upstream headers are copied through responseHeaderFilter first when a
// filter is configured. Usage comes from extractOpenAIUsageFromJSONBytes and
// stays zero when that helper reports no usage.
//
// Two failures write a 502 api_error to the client: a read error other than
// ErrUpstreamResponseBodyTooLarge, and a decode error. For the too-large case
// nothing is written here. Presumably ReadUpstreamResponseBody already
// responded via openAITooLargeError; verify. Every failure returns a nil result
// and a wrapped error.
func (s *OpenAIGatewayService) bufferChatCompletionsAsResponses(
	c *gin.Context,
	resp *http.Response,
	originalModel string,
	billingModel string,
	upstreamModel string,
	reasoningEffort *string,
	serviceTier *string,
	startTime time.Time,
) (*OpenAIForwardResult, error) {
	upstreamRequestID := resp.Header.Get("x-request-id")

	// replyBadGateway writes the 502 api_error envelope shared by both failure paths.
	replyBadGateway := func(message string) {
		c.JSON(http.StatusBadGateway, gin.H{
			"error": gin.H{
				"type":    "api_error",
				"message": message,
			},
		})
	}

	rawBody, readErr := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError)
	if readErr != nil {
		if !errors.Is(readErr, ErrUpstreamResponseBodyTooLarge) {
			replyBadGateway("Failed to read upstream response")
		}
		return nil, fmt.Errorf("read upstream body: %w", readErr)
	}

	var chatResp apicompat.ChatCompletionsResponse
	if decodeErr := json.Unmarshal(rawBody, &chatResp); decodeErr != nil {
		replyBadGateway("Failed to parse upstream response")
		return nil, fmt.Errorf("parse chat completions response: %w", decodeErr)
	}

	converted := apicompat.ChatCompletionsResponseToResponses(&chatResp, originalModel)

	var usage OpenAIUsage
	if parsedUsage, ok := extractOpenAIUsageFromJSONBytes(rawBody); ok {
		usage = parsedUsage
	}

	if filter := s.responseHeaderFilter; filter != nil {
		responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, filter)
	}
	c.JSON(http.StatusOK, converted)

	result := &OpenAIForwardResult{
		RequestID:       upstreamRequestID,
		Usage:           usage,
		Model:           originalModel,
		BillingModel:    billingModel,
		UpstreamModel:   upstreamModel,
		ReasoningEffort: reasoningEffort,
		ServiceTier:     serviceTier,
		Stream:          false,
		Duration:        time.Since(startTime),
	}
	return result, nil
}
// streamChatCompletionsAsResponses relays an upstream Chat Completions SSE
// stream to the client as Responses API SSE events.
//
// Each upstream "data:" payload is decoded as an apicompat.ChatCompletionsChunk.
// The chunk is translated through one per-request apicompat stream state, and
// the resulting events are written and flushed. Usage is taken from every
// payload where extractCCStreamUsage finds it, so the last one seen wins.
// FirstTokenMs is stamped at the first decoded chunk that is not usage-only and
// that chatChunkStartsResponsesOutput accepts.
//
// SSE response headers (status 200, text/event-stream) are committed lazily,
// just before the first write. After a client write fails, nothing more is
// written, but the upstream is still read to the end so usage is captured.
//
// Reading stops at the upstream "[DONE]" sentinel or at EOF. The translated
// stream is then finalized and the client receives "data: [DONE]". If the
// scanner reports an error (for example a line longer than the configured max
// line size), the partial result is returned with a "stream usage incomplete"
// error and no finalize events are written.
func (s *OpenAIGatewayService) streamChatCompletionsAsResponses(
	c *gin.Context,
	resp *http.Response,
	originalModel string,
	billingModel string,
	upstreamModel string,
	reasoningEffort *string,
	serviceTier *string,
	startTime time.Time,
) (*OpenAIForwardResult, error) {
	requestID := resp.Header.Get("x-request-id")

	var (
		headersSent  bool
		clientGone   bool
		sawDone      bool
		usage        OpenAIUsage
		firstTokenMs *int
	)

	// snapshot builds the forward result from the accounting gathered so far;
	// Duration is measured at the moment it is called.
	snapshot := func() *OpenAIForwardResult {
		return &OpenAIForwardResult{
			RequestID:       requestID,
			Usage:           usage,
			Model:           originalModel,
			BillingModel:    billingModel,
			UpstreamModel:   upstreamModel,
			ReasoningEffort: reasoningEffort,
			ServiceTier:     serviceTier,
			Stream:          true,
			Duration:        time.Since(startTime),
			FirstTokenMs:    firstTokenMs,
		}
	}

	// sendHeaders commits the SSE response headers at most once.
	sendHeaders := func() {
		if headersSent {
			return
		}
		headersSent = true
		if s.responseHeaderFilter != nil {
			responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
		}
		for _, kv := range [...][2]string{
			{"Content-Type", "text/event-stream"},
			{"Cache-Control", "no-cache"},
			{"Connection", "keep-alive"},
			{"X-Accel-Buffering", "no"},
		} {
			c.Writer.Header().Set(kv[0], kv[1])
		}
		c.Writer.WriteHeader(http.StatusOK)
	}

	// emit writes a batch of translated events and flushes. Once a write has
	// failed it does nothing, so the upstream can keep draining for billing.
	emit := func(events []apicompat.ResponsesStreamEvent) {
		if clientGone || len(events) == 0 {
			return
		}
		sendHeaders()
		for _, ev := range events {
			frame, marshalErr := apicompat.ResponsesEventToSSE(ev)
			if marshalErr != nil {
				logger.L().Warn("openai responses chat fallback: failed to marshal stream event",
					zap.Error(marshalErr),
					zap.String("request_id", requestID),
				)
				continue
			}
			if _, writeErr := fmt.Fprint(c.Writer, frame); writeErr != nil {
				clientGone = true
				logger.L().Debug("openai responses chat fallback: client disconnected, continuing to drain upstream for billing",
					zap.Error(writeErr),
					zap.String("request_id", requestID),
				)
				return
			}
		}
		c.Writer.Flush()
	}

	convState := apicompat.NewChatCompletionsToResponsesStreamState(originalModel)

	lineLimit := defaultMaxLineSize
	if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
		lineLimit = s.cfg.Gateway.MaxLineSize
	}
	scanner := bufio.NewScanner(resp.Body)
	scanner.Buffer(make([]byte, 0, 64*1024), lineLimit)

readLoop:
	for scanner.Scan() {
		payload, isData := extractOpenAISSEDataLine(scanner.Text())
		if !isData {
			continue
		}
		switch payload = strings.TrimSpace(payload); payload {
		case "":
			continue
		case "[DONE]":
			sawDone = true
			break readLoop
		}

		if streamUsage := extractCCStreamUsage(payload); streamUsage != nil {
			usage = *streamUsage
		}

		var chunk apicompat.ChatCompletionsChunk
		if decodeErr := json.Unmarshal([]byte(payload), &chunk); decodeErr != nil {
			logger.L().Warn("openai responses chat fallback: failed to parse chat stream chunk",
				zap.Error(decodeErr),
				zap.String("request_id", requestID),
			)
			continue
		}

		if firstTokenMs == nil && !isOpenAIChatUsageOnlyStreamChunk(payload) && chatChunkStartsResponsesOutput(&chunk) {
			elapsedMs := int(time.Since(startTime).Milliseconds())
			firstTokenMs = &elapsedMs
		}
		emit(apicompat.ChatCompletionsChunkToResponsesEvents(&chunk, convState))
	}

	if scanErr := scanner.Err(); scanErr != nil {
		// Cancellation and deadlines are expected ends of a stream; anything
		// else is worth a warning.
		if !errors.Is(scanErr, context.Canceled) && !errors.Is(scanErr, context.DeadlineExceeded) {
			logger.L().Warn("openai responses chat fallback: stream read error",
				zap.Error(scanErr),
				zap.String("request_id", requestID),
			)
		}
		return snapshot(), fmt.Errorf("stream usage incomplete: %w", scanErr)
	}

	emit(apicompat.FinalizeChatCompletionsResponsesStream(convState))
	if !clientGone {
		sendHeaders()
		if _, writeErr := fmt.Fprint(c.Writer, "data: [DONE]\n\n"); writeErr == nil {
			c.Writer.Flush()
		}
	}

	if !sawDone {
		logger.L().Debug("openai responses chat fallback: upstream stream ended without done sentinel",
			zap.String("request_id", requestID),
		)
	}
	return snapshot(), nil
}
// chatChunkStartsResponsesOutput reports whether chunk carries any output
// delta. That means at least one choice has a non-nil Content, a non-nil
// ReasoningContent, or one or more ToolCalls. A nil chunk, or a chunk with no
// choices, reports false. streamChatCompletionsAsResponses uses it to decide
// when to stamp time-to-first-token.
func chatChunkStartsResponsesOutput(chunk *apicompat.ChatCompletionsChunk) bool {
	if chunk == nil {
		return false
	}
	for i := range chunk.Choices {
		delta := chunk.Choices[i].Delta
		hasOutput := delta.Content != nil || delta.ReasoningContent != nil || len(delta.ToolCalls) > 0
		if hasOutput {
			return true
		}
	}
	return false
}