sub2api/backend/internal/service/openai_gateway_chat_completions_raw.go
Wesley Liddick bbe847ed3e
Merge pull request #2805 from StarryKira/feat/configurable-pool-retry-status-codes
feat(account): configurable pool-mode same-account retry status codes
2026-05-27 22:09:55 +08:00

490 lines
17 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"go.uber.org/zap"
)
// openaiCCRawAllowedHeaders is the client-header passthrough allowlist used
// exclusively by the raw Chat Completions (CC) forwarding path.
//
// **Important**: do not reuse openaiAllowedHeaders here. That list contains
// headers specific to the Codex client
// (originator / session_id / x-codex-turn-state / x-codex-turn-metadata / conversation_id).
// Those are required by the ChatGPT OAuth upstream, but passing them through to
// third-party OpenAI-compatible upstreams such as DeepSeek/Kimi/GLM results in:
//   - the headers being silently ignored (most lenient vendors), which quietly
//     pollutes upstream statistics
//   - a 400 "unknown parameter" response (strict upstreams), i.e. a visible error
//
// Only generic HTTP headers are allowed here. content-type / authorization /
// accept are set explicitly by the forwarding code and do not rely on
// passthrough.
//
// See the decision record:
// pensieve/short-term/maxims/dont-reuse-shared-headers-whitelist-across-different-upstream-trust-domains
//
// Keys are lower-case: the forwarding loop lower-cases each incoming header
// name before looking it up in this map.
var openaiCCRawAllowedHeaders = map[string]bool{
"accept-language": true,
"user-agent": true,
}
// forwardAsRawChatCompletions forwards the client's Chat Completions (CC)
// request straight to the upstream `{base_url}/v1/chat/completions` endpoint
// WITHOUT any CC↔Responses protocol conversion.
//
// Intended for: account.platform=openai && account.type=apikey where the
// upstream has been probed and confirmed not to support the /v1/responses
// endpoint (third-party OpenAI-compatible upstreams such as
// DeepSeek/Kimi/GLM/Qwen).
//
// Key differences from ForwardAsChatCompletions:
//
//   - apicompat.ChatCompletionsToResponses is not called; the body is not
//     converted. It only receives the model ID rewrite, the fast-policy pass
//     and, for streaming requests, stream_options.include_usage=true
//   - the upstream URL targets /v1/chat/completions instead of /v1/responses
//   - streaming SSE responses are passed through to the client as-is (upstream
//     chunks are already in CC format)
//   - non-streaming JSON responses are passed through as-is; usage is only
//     extracted for accounting
//   - the codex OAuth transform is not applied (the API-key path has no OAuth)
//   - prompt_cache_key is not injected (an OAuth-only mechanism)
//
// Entry point: openai_gateway_chat_completions.go::ForwardAsChatCompletions
// dispatches here at the top of the function based on
// openai_compat.ShouldUseResponsesAPI.
//
// Error behaviour visible in this function:
//
//   - missing "model": writes a 400 CC error to the client and returns an error
//   - fast policy blocked (*OpenAIFastBlockedError): writes a 403 CC error and
//     returns the policy error; other policy errors are returned without
//     writing a response here
//   - stream-usage patch failure, missing api_key, invalid base_url, request
//     construction failure: return an error without writing a response here
//   - transport failure: records an ops upstream error, writes a 502 CC error
//     and returns an error
//   - upstream status >= 400: returns *UpstreamFailoverError when
//     shouldFailoverOpenAIUpstreamResponse says so, otherwise delegates to
//     handleChatCompletionsErrorResponse
func (s *OpenAIGatewayService) forwardAsRawChatCompletions(
ctx context.Context,
c *gin.Context,
account *Account,
body []byte,
defaultMappedModel string,
) (*OpenAIForwardResult, error) {
// startTime anchors both the total duration and the first-token latency
// reported by the response handlers.
startTime := time.Now()
// 1. Parse minimal fields needed for routing/billing
originalModel := gjson.GetBytes(body, "model").String()
if originalModel == "" {
writeChatCompletionsError(c, http.StatusBadRequest, "invalid_request_error", "model is required")
return nil, fmt.Errorf("missing model in request")
}
clientStream := gjson.GetBytes(body, "stream").Bool()
// 1b. Extract reasoning effort and service tier from the raw body before any transformation.
reasoningEffort := extractOpenAIReasoningEffortFromBody(body, originalModel)
serviceTier := extractOpenAIServiceTierFromBody(body)
// 2. Resolve model mapping (same as ForwardAsChatCompletions)
billingModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel)
upstreamModel := normalizeOpenAIModelForUpstream(account, billingModel)
// 3. Rewrite model in body (no protocol conversion)
// The body is only rewritten when the upstream model name actually differs
// from what the client sent.
upstreamBody := body
if upstreamModel != originalModel {
upstreamBody = ReplaceModelInBody(body, upstreamModel)
}
// 4. Apply OpenAI fast policy on the CC body
updatedBody, policyErr := s.applyOpenAIFastPolicyToBody(ctx, account, upstreamModel, upstreamBody)
if policyErr != nil {
// Only a policy block produces a client-facing 403 here; any other policy
// error is returned to the caller without a response being written.
var blocked *OpenAIFastBlockedError
if errors.As(policyErr, &blocked) {
MarkOpsClientBusinessLimited(c, OpsClientBusinessLimitedReasonLocalPolicyDenied)
writeChatCompletionsError(c, http.StatusForbidden, "permission_error", blocked.Message)
}
return nil, policyErr
}
upstreamBody = updatedBody
// For streaming requests, force stream_options.include_usage on the upstream
// body so that usage can be extracted from the SSE stream.
if clientStream {
var usageErr error
upstreamBody, usageErr = ensureOpenAIChatStreamUsage(upstreamBody)
if usageErr != nil {
return nil, fmt.Errorf("enable stream usage: %w", usageErr)
}
}
logger.L().Debug("openai chat_completions raw: forwarding without protocol conversion",
zap.Int64("account_id", account.ID),
zap.String("original_model", originalModel),
zap.String("billing_model", billingModel),
zap.String("upstream_model", upstreamModel),
zap.Bool("stream", clientStream),
)
// 5. Build upstream request
apiKey := account.GetOpenAIApiKey()
if apiKey == "" {
return nil, fmt.Errorf("account %d missing api_key", account.ID)
}
// Fall back to the official OpenAI endpoint when the account has no base URL.
baseURL := account.GetOpenAIBaseURL()
if baseURL == "" {
baseURL = "https://api.openai.com"
}
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, fmt.Errorf("invalid base_url: %w", err)
}
targetURL := buildOpenAIChatCompletionsURL(validatedURL)
// NOTE(review): releaseUpstreamCtx is invoked right after the request is
// constructed and before it is sent. This presumably relies on the release
// function not cancelling upstreamCtx — confirm against detachUpstreamContext.
upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
upstreamReq, err := http.NewRequestWithContext(upstreamCtx, http.MethodPost, targetURL, bytes.NewReader(upstreamBody))
releaseUpstreamCtx()
if err != nil {
return nil, fmt.Errorf("build upstream request: %w", err)
}
upstreamReq = upstreamReq.WithContext(WithHTTPUpstreamProfile(upstreamReq.Context(), HTTPUpstreamProfileOpenAI))
upstreamReq.Header.Set("Content-Type", "application/json")
upstreamReq.Header.Set("Authorization", "Bearer "+apiKey)
if clientStream {
upstreamReq.Header.Set("Accept", "text/event-stream")
} else {
upstreamReq.Header.Set("Accept", "application/json")
}
// Pass through only the client headers on the allowlist; see the design
// notes on openaiCCRawAllowedHeaders.
for key, values := range c.Request.Header {
lowerKey := strings.ToLower(key)
if openaiCCRawAllowedHeaders[lowerKey] {
for _, v := range values {
upstreamReq.Header.Add(key, v)
}
}
}
// An account-level User-Agent, when configured, replaces whatever the client sent.
customUA := account.GetOpenAIUserAgent()
if customUA != "" {
upstreamReq.Header.Set("user-agent", customUA)
}
// 6. Send request
proxyURL := ""
if account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
// Transport-level failure: record a sanitized message for ops (status code 0)
// and answer the client with a generic 502.
safeErr := sanitizeUpstreamErrorMessage(err.Error())
setOpsUpstreamError(c, 0, safeErr, "")
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
writeChatCompletionsError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed")
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
}
defer func() { _ = resp.Body.Close() }()
// 7. Handle error response with failover
if resp.StatusCode >= 400 {
// Read at most 2 MiB of the error body (read errors are ignored, best
// effort), close the real body, then re-wrap the bytes so later handlers
// can read the body again.
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) {
// The raw upstream body is attached to the ops event only when enabled in
// config, truncated to the configured size (2048 when unset or <= 0).
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(respBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody, upstreamModel)
// Same-account retry is only offered for pool-mode accounts, and only when
// the status is configured as retryable or the error is classified as a
// transient processing error.
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
RetryableOnSameAccount: account.IsPoolMode() && (account.IsPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)),
}
}
return s.handleChatCompletionsErrorResponse(resp, c, account, billingModel)
}
// 8. Forward response
// len(body) is the size of the original client body, not of the rewritten
// upstream body; it is handed to the streaming handler's refusal detector.
if clientStream {
return s.streamRawChatCompletions(c, resp, account, originalModel, billingModel, upstreamModel, reasoningEffort, serviceTier, startTime, len(body))
}
return s.bufferRawChatCompletions(c, resp, originalModel, billingModel, upstreamModel, reasoningEffort, serviceTier, startTime)
}
// streamRawChatCompletions relays the upstream Chat Completions SSE stream to
// the client line by line and collects usage along the way (per the OpenAI CC
// protocol, usage arrives in a chunk shortly before the trailing [DONE]).
//
// Upstream only emits usage when stream_options.include_usage=true was
// requested. The gateway forces that option on the upstream request so billing
// stays complete, and the usage chunk is relayed downstream unchanged so
// cascaded proxies or downstream billing systems can read it as well.
//
// Output gating: until the silent-refusal detector releases client output,
// lines are held back in memory and no headers are written. Once released,
// the headers and all held lines are sent, followed by every later line.
// If the upstream ends cleanly before release, the detector decides between
// a silent-refusal failover error and flushing whatever was held.
//
// A failed client write marks the client as disconnected; the upstream is
// still read to the end so usage can be captured for billing.
//
// NOTE(review): when the scanner stops with an error before output has been
// released, the held lines are dropped and a result with a nil error is
// returned — confirm this is the intended behaviour.
func (s *OpenAIGatewayService) streamRawChatCompletions(
	c *gin.Context,
	resp *http.Response,
	account *Account,
	originalModel string,
	billingModel string,
	upstreamModel string,
	reasoningEffort *string,
	serviceTier *string,
	startTime time.Time,
	requestBodyLen int,
) (*OpenAIForwardResult, error) {
	const (
		drainingMsg   = "openai chat_completions raw: client disconnected, continuing to drain upstream for billing"
		finalFlushMsg = "openai chat_completions raw: client disconnected during final flush"
	)

	requestID := resp.Header.Get("x-request-id")

	// Response headers are emitted lazily, exactly once, right before the
	// first byte of body goes out.
	headersSent := false
	sendHeaders := func() {
		if headersSent {
			return
		}
		headersSent = true
		if s.responseHeaderFilter != nil {
			responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
		}
		out := c.Writer.Header()
		out.Set("Content-Type", "text/event-stream")
		out.Set("Cache-Control", "no-cache")
		out.Set("Connection", "keep-alive")
		out.Set("X-Accel-Buffering", "no")
		c.Writer.WriteHeader(http.StatusOK)
	}

	scanner := bufio.NewScanner(resp.Body)
	maxLineSize := defaultMaxLineSize
	if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
		maxLineSize = s.cfg.Gateway.MaxLineSize
	}
	scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)

	var (
		usage         OpenAIUsage
		firstTokenMs  *int
		clientGone    bool
		outputStarted bool
	)
	held := make([]string, 0, 8)
	detector := newOpenAIChatSilentRefusalDetector(requestBodyLen)

	// send writes one line plus its newline. On failure it flags the client as
	// gone, logs failureMsg at debug level and reports false.
	send := func(text, failureMsg string) bool {
		if _, werr := c.Writer.WriteString(text + "\n"); werr != nil {
			clientGone = true
			logger.L().Debug(failureMsg,
				zap.Error(werr),
				zap.String("request_id", requestID),
			)
			return false
		}
		return true
	}

	// forward either holds the line back (output not yet released) or writes
	// it, flushing the held backlog first on the transition to "started".
	forward := func(line string) {
		if clientGone {
			return
		}
		if !outputStarted {
			if !detector.ShouldReleaseClientOutput() {
				held = append(held, line)
				return
			}
			sendHeaders()
			for _, earlier := range held {
				if !send(earlier, drainingMsg) {
					return
				}
			}
			held = held[:0]
			outputStarted = true
		}
		send(line, drainingMsg)
	}

	for scanner.Scan() {
		line := scanner.Text()
		detector.ObserveSSELine(line)

		if payload, isData := extractOpenAISSEDataLine(line); isData && strings.TrimSpace(payload) != "[DONE]" {
			usageOnly := isOpenAIChatUsageOnlyStreamChunk(payload)
			// Upstream may repeat usage; the latest value wins.
			if latest := extractCCStreamUsage(payload); latest != nil {
				usage = *latest
			}
			// The first data chunk that is not a usage-only chunk marks first-token latency.
			if firstTokenMs == nil && !usageOnly {
				elapsed := int(time.Since(startTime).Milliseconds())
				firstTokenMs = &elapsed
			}
		}

		forward(line)
		// Flush after every line (blank separator lines included) once output
		// is flowing and the client is still connected.
		if outputStarted && !clientGone {
			c.Writer.Flush()
		}
	}

	switch err := scanner.Err(); {
	case err != nil:
		// Context cancellation/deadline is expected noise; anything else is logged.
		if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
			logger.L().Warn("openai chat_completions raw: stream read error",
				zap.Error(err),
				zap.String("request_id", requestID),
			)
		}
	case !clientGone && !outputStarted:
		// Upstream finished cleanly while output was still being held back.
		if detector.IsSilentRefusal() {
			return nil, newOpenAISilentRefusalFailoverError(c, account, requestID)
		}
		if len(held) > 0 {
			sendHeaders()
			for _, earlier := range held {
				if !send(earlier, finalFlushMsg) {
					break
				}
			}
			if !clientGone {
				c.Writer.Flush()
				outputStarted = true
			}
		}
	}

	return &OpenAIForwardResult{
		RequestID:       requestID,
		Usage:           usage,
		Model:           originalModel,
		BillingModel:    billingModel,
		UpstreamModel:   upstreamModel,
		ReasoningEffort: reasoningEffort,
		ServiceTier:     serviceTier,
		Stream:          true,
		Duration:        time.Since(startTime),
		FirstTokenMs:    firstTokenMs,
	}, nil
}
// ensureOpenAIChatStreamUsage sets stream_options.include_usage=true on a raw
// Chat Completions streaming request body so the upstream reports usage.
// The usage chunk is also relayed downstream, which supports cascaded proxies
// and downstream billing systems.
//
// On success the patched body is returned. If the JSON patch fails, the
// original body is returned unchanged together with the error.
func ensureOpenAIChatStreamUsage(body []byte) ([]byte, error) {
	patched, err := sjson.SetBytes(body, "stream_options.include_usage", true)
	if err == nil {
		return patched, nil
	}
	return body, err
}
// isOpenAIChatUsageOnlyStreamChunk reports whether an SSE data payload is a
// chunk that carries usage but no completion content: it must have a "usage"
// field and a "choices" field that is an empty array. Blank payloads, and
// payloads with missing, non-array or non-empty "choices", yield false.
func isOpenAIChatUsageOnlyStreamChunk(payload string) bool {
	if strings.TrimSpace(payload) == "" || !gjson.Get(payload, "usage").Exists() {
		return false
	}
	choices := gjson.Get(payload, "choices")
	if !choices.Exists() || !choices.IsArray() {
		return false
	}
	return len(choices.Array()) == 0
}
// extractCCStreamUsage reads the usage object out of a single Chat Completions
// streaming chunk payload.
//
// In the CC protocol usage appears only in a trailing chunk (and only when
// include_usage is in effect), but an upstream may repeat it across several
// chunks, so callers should always keep the most recent value.
//
// Returns nil when "usage" is absent or is not a JSON object. Otherwise
// prompt_tokens and completion_tokens map to InputTokens and OutputTokens
// (missing fields read as 0), and prompt_tokens_details.cached_tokens, when
// present, maps to CacheReadInputTokens.
func extractCCStreamUsage(payload string) *OpenAIUsage {
	if node := gjson.Get(payload, "usage"); !(node.Exists() && node.IsObject()) {
		return nil
	}

	promptTokens := gjson.Get(payload, "usage.prompt_tokens").Int()
	completionTokens := gjson.Get(payload, "usage.completion_tokens").Int()
	parsed := &OpenAIUsage{
		InputTokens:  int(promptTokens),
		OutputTokens: int(completionTokens),
	}

	cachedTokens := gjson.Get(payload, "usage.prompt_tokens_details.cached_tokens")
	if cachedTokens.Exists() {
		parsed.CacheReadInputTokens = int(cachedTokens.Int())
	}
	return parsed
}
// bufferRawChatCompletions relays a non-streaming upstream Chat Completions
// JSON response to the client unchanged.
//
// The body is read in full via ReadUpstreamResponseBody. On a read failure a
// 502 CC error is written here unless the failure is
// ErrUpstreamResponseBodyTooLarge, and the wrapped error is returned.
//
// Usage is extracted on a best-effort basis: if the body does not decode as a
// ChatCompletionsResponse or carries no usage, the result reports zero usage
// and the body is still relayed. The client always receives status 200 with
// the upstream Content-Type (falling back to application/json).
func (s *OpenAIGatewayService) bufferRawChatCompletions(
	c *gin.Context,
	resp *http.Response,
	originalModel string,
	billingModel string,
	upstreamModel string,
	reasoningEffort *string,
	serviceTier *string,
	startTime time.Time,
) (*OpenAIForwardResult, error) {
	requestID := resp.Header.Get("x-request-id")

	payload, readErr := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError)
	if readErr != nil {
		if !errors.Is(readErr, ErrUpstreamResponseBodyTooLarge) {
			writeChatCompletionsError(c, http.StatusBadGateway, "api_error", "Failed to read upstream response")
		}
		return nil, fmt.Errorf("read upstream body: %w", readErr)
	}

	// Decode only to pick up usage; the bytes sent to the client are the
	// upstream bytes, not a re-encoding.
	var (
		usage   OpenAIUsage
		decoded apicompat.ChatCompletionsResponse
	)
	if decodeErr := json.Unmarshal(payload, &decoded); decodeErr == nil && decoded.Usage != nil {
		usage = OpenAIUsage{
			InputTokens:  decoded.Usage.PromptTokens,
			OutputTokens: decoded.Usage.CompletionTokens,
		}
		if details := decoded.Usage.PromptTokensDetails; details != nil {
			usage.CacheReadInputTokens = details.CachedTokens
		}
	}

	if s.responseHeaderFilter != nil {
		responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
	}
	contentType := resp.Header.Get("Content-Type")
	if contentType == "" {
		contentType = "application/json"
	}
	c.Writer.Header().Set("Content-Type", contentType)
	c.Writer.WriteHeader(http.StatusOK)
	_, _ = c.Writer.Write(payload)

	result := &OpenAIForwardResult{
		RequestID:       requestID,
		Usage:           usage,
		Model:           originalModel,
		BillingModel:    billingModel,
		UpstreamModel:   upstreamModel,
		ReasoningEffort: reasoningEffort,
		ServiceTier:     serviceTier,
		Stream:          false,
		Duration:        time.Since(startTime),
	}
	return result, nil
}
// buildOpenAIChatCompletionsURL builds the upstream Chat Completions endpoint
// URL from a base URL by delegating to buildOpenAIEndpointURL with the
// "/v1/chat/completions" endpoint path.
//
// Resolution rules as recorded in the original notes (implemented by
// buildOpenAIEndpointURL, which is not shown here — verify there):
//
//   - base already ends in /chat/completions: returned as-is
//   - base ends in /v1: /chat/completions is appended
//   - base ends in another version segment (e.g. /v4): /chat/completions is appended
//   - anything else: /v1/chat/completions is appended
//
// Sibling of buildOpenAIResponsesURL.
func buildOpenAIChatCompletionsURL(base string) string {
	const chatCompletionsPath = "/v1/chat/completions"
	return buildOpenAIEndpointURL(base, chatCompletionsPath)
}