package service import ( "bufio" "bytes" "context" "encoding/json" "errors" "fmt" "io" "net/http" "strings" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" "github.com/gin-gonic/gin" "github.com/tidwall/gjson" "github.com/tidwall/sjson" "go.uber.org/zap" ) // openaiCCRawAllowedHeaders 是 CC 直转路径专用的客户端 header 透传白名单。 // // **关键**:不能复用 openaiAllowedHeaders——后者含 Codex 客户端专属 header // (originator / session_id / x-codex-turn-state / x-codex-turn-metadata / conversation_id), // 这些在 ChatGPT OAuth 上游是必需的,但透传给 DeepSeek/Kimi/GLM 等第三方 // OpenAI 兼容上游会造成: // - 完全忽略(多数友好厂商)——隐性污染上游统计 // - 400 "unknown parameter"(严格上游)——可见错误 // // 这里仅放行通用 HTTP header;content-type / authorization / accept 由上下文 // 显式设置,不依赖透传。 // // 参见决策记录: // pensieve/short-term/maxims/dont-reuse-shared-headers-whitelist-across-different-upstream-trust-domains var openaiCCRawAllowedHeaders = map[string]bool{ "accept-language": true, "user-agent": true, } // forwardAsRawChatCompletions 直转客户端的 Chat Completions 请求到上游 // `{base_url}/v1/chat/completions`,**不**做 CC↔Responses 协议转换。 // // 适用场景:account.platform=openai && account.type=apikey && 上游已被探测确认 // 不支持 /v1/responses 端点(如 DeepSeek/Kimi/GLM/Qwen 等第三方 OpenAI 兼容上游)。 // // 与 ForwardAsChatCompletions 的关键差异: // // - 不调用 apicompat.ChatCompletionsToResponses,body 仅做模型 ID 改写 // - 上游 URL 拼到 /v1/chat/completions 而非 /v1/responses // - 流式响应 SSE 直接透传给客户端(上游 chunk 已是 CC 格式) // - 非流式响应 JSON 直接透传,仅按需提取 usage // - 不应用 codex OAuth transform(APIKey 路径无 OAuth) // - 不注入 prompt_cache_key(OAuth 专属机制) // // 调用入口:openai_gateway_chat_completions.go::ForwardAsChatCompletions // 在函数顶部按 openai_compat.ShouldUseResponsesAPI 分流。 func (s *OpenAIGatewayService) forwardAsRawChatCompletions( ctx context.Context, c *gin.Context, account *Account, body []byte, defaultMappedModel string, ) (*OpenAIForwardResult, error) { startTime := time.Now() // 1. Parse minimal fields needed for routing/billing originalModel := gjson.GetBytes(body, "model").String() if originalModel == "" { writeChatCompletionsError(c, http.StatusBadRequest, "invalid_request_error", "model is required") return nil, fmt.Errorf("missing model in request") } clientStream := gjson.GetBytes(body, "stream").Bool() // 1b. Extract reasoning effort and service tier from the raw body before any transformation. reasoningEffort := extractOpenAIReasoningEffortFromBody(body, originalModel) serviceTier := extractOpenAIServiceTierFromBody(body) // 2. Resolve model mapping (same as ForwardAsChatCompletions) billingModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel) upstreamModel := normalizeOpenAIModelForUpstream(account, billingModel) // 3. Rewrite model in body (no protocol conversion) upstreamBody := body if upstreamModel != originalModel { upstreamBody = ReplaceModelInBody(body, upstreamModel) } // 4. Apply OpenAI fast policy on the CC body updatedBody, policyErr := s.applyOpenAIFastPolicyToBody(ctx, account, upstreamModel, upstreamBody) if policyErr != nil { var blocked *OpenAIFastBlockedError if errors.As(policyErr, &blocked) { MarkOpsClientBusinessLimited(c, OpsClientBusinessLimitedReasonLocalPolicyDenied) writeChatCompletionsError(c, http.StatusForbidden, "permission_error", blocked.Message) } return nil, policyErr } upstreamBody = updatedBody if clientStream { var usageErr error upstreamBody, usageErr = ensureOpenAIChatStreamUsage(upstreamBody) if usageErr != nil { return nil, fmt.Errorf("enable stream usage: %w", usageErr) } } logger.L().Debug("openai chat_completions raw: forwarding without protocol conversion", zap.Int64("account_id", account.ID), zap.String("original_model", originalModel), zap.String("billing_model", billingModel), zap.String("upstream_model", upstreamModel), zap.Bool("stream", clientStream), ) // 5. Build upstream request apiKey := account.GetOpenAIApiKey() if apiKey == "" { return nil, fmt.Errorf("account %d missing api_key", account.ID) } baseURL := account.GetOpenAIBaseURL() if baseURL == "" { baseURL = "https://api.openai.com" } validatedURL, err := s.validateUpstreamBaseURL(baseURL) if err != nil { return nil, fmt.Errorf("invalid base_url: %w", err) } targetURL := buildOpenAIChatCompletionsURL(validatedURL) upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx) upstreamReq, err := http.NewRequestWithContext(upstreamCtx, http.MethodPost, targetURL, bytes.NewReader(upstreamBody)) releaseUpstreamCtx() if err != nil { return nil, fmt.Errorf("build upstream request: %w", err) } upstreamReq = upstreamReq.WithContext(WithHTTPUpstreamProfile(upstreamReq.Context(), HTTPUpstreamProfileOpenAI)) upstreamReq.Header.Set("Content-Type", "application/json") upstreamReq.Header.Set("Authorization", "Bearer "+apiKey) if clientStream { upstreamReq.Header.Set("Accept", "text/event-stream") } else { upstreamReq.Header.Set("Accept", "application/json") } // 透传白名单中的客户端 header。详见 openaiCCRawAllowedHeaders 的设计说明。 for key, values := range c.Request.Header { lowerKey := strings.ToLower(key) if openaiCCRawAllowedHeaders[lowerKey] { for _, v := range values { upstreamReq.Header.Add(key, v) } } } customUA := account.GetOpenAIUserAgent() if customUA != "" { upstreamReq.Header.Set("user-agent", customUA) } // 6. Send request proxyURL := "" if account.Proxy != nil { proxyURL = account.Proxy.URL() } resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) if err != nil { safeErr := sanitizeUpstreamErrorMessage(err.Error()) setOpsUpstreamError(c, 0, safeErr, "") appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ Platform: account.Platform, AccountID: account.ID, AccountName: account.Name, UpstreamStatusCode: 0, Kind: "request_error", Message: safeErr, }) writeChatCompletionsError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed") return nil, fmt.Errorf("upstream request failed: %s", safeErr) } defer func() { _ = resp.Body.Close() }() // 7. Handle error response with failover if resp.StatusCode >= 400 { respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) _ = resp.Body.Close() resp.Body = io.NopCloser(bytes.NewReader(respBody)) upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) { upstreamDetail := "" if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes if maxBytes <= 0 { maxBytes = 2048 } upstreamDetail = truncateString(string(respBody), maxBytes) } appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ Platform: account.Platform, AccountID: account.ID, AccountName: account.Name, UpstreamStatusCode: resp.StatusCode, UpstreamRequestID: resp.Header.Get("x-request-id"), Kind: "failover", Message: upstreamMsg, Detail: upstreamDetail, }) s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody, upstreamModel) return nil, &UpstreamFailoverError{ StatusCode: resp.StatusCode, ResponseBody: respBody, RetryableOnSameAccount: account.IsPoolMode() && (account.IsPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)), } } return s.handleChatCompletionsErrorResponse(resp, c, account, billingModel) } // 8. Forward response if clientStream { return s.streamRawChatCompletions(c, resp, account, originalModel, billingModel, upstreamModel, reasoningEffort, serviceTier, startTime, len(body)) } return s.bufferRawChatCompletions(c, resp, originalModel, billingModel, upstreamModel, reasoningEffort, serviceTier, startTime) } // streamRawChatCompletions 透传上游 CC SSE 流到客户端,并提取 usage(包括 // 末尾 [DONE] 之前的 chunk 中的 usage 字段,按 OpenAI CC 协议)。 // // usage 字段仅在客户端请求 stream_options.include_usage=true 时出现于上游响应中。 // 网关会对上游强制打开 include_usage 以保证计费完整,并原样向下游透传 usage, // 让级联代理或下游计费系统也能拿到完整用量。 func (s *OpenAIGatewayService) streamRawChatCompletions( c *gin.Context, resp *http.Response, account *Account, originalModel string, billingModel string, upstreamModel string, reasoningEffort *string, serviceTier *string, startTime time.Time, requestBodyLen int, ) (*OpenAIForwardResult, error) { requestID := resp.Header.Get("x-request-id") headersWritten := false writeStreamHeaders := func() { if headersWritten { return } headersWritten = true if s.responseHeaderFilter != nil { responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) } c.Writer.Header().Set("Content-Type", "text/event-stream") c.Writer.Header().Set("Cache-Control", "no-cache") c.Writer.Header().Set("Connection", "keep-alive") c.Writer.Header().Set("X-Accel-Buffering", "no") c.Writer.WriteHeader(http.StatusOK) } scanner := bufio.NewScanner(resp.Body) maxLineSize := defaultMaxLineSize if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { maxLineSize = s.cfg.Gateway.MaxLineSize } scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) var usage OpenAIUsage var firstTokenMs *int clientDisconnected := false clientOutputStarted := false pendingLines := make([]string, 0, 8) refusalDetector := newOpenAIChatSilentRefusalDetector(requestBodyLen) writeLine := func(line string) { if clientDisconnected { return } if !clientOutputStarted && !refusalDetector.ShouldReleaseClientOutput() { pendingLines = append(pendingLines, line) return } if !clientOutputStarted { writeStreamHeaders() for _, pending := range pendingLines { if _, werr := c.Writer.WriteString(pending + "\n"); werr != nil { clientDisconnected = true logger.L().Debug("openai chat_completions raw: client disconnected, continuing to drain upstream for billing", zap.Error(werr), zap.String("request_id", requestID), ) return } } pendingLines = pendingLines[:0] clientOutputStarted = true } if _, werr := c.Writer.WriteString(line + "\n"); werr != nil { clientDisconnected = true logger.L().Debug("openai chat_completions raw: client disconnected, continuing to drain upstream for billing", zap.Error(werr), zap.String("request_id", requestID), ) } } for scanner.Scan() { line := scanner.Text() refusalDetector.ObserveSSELine(line) if payload, ok := extractOpenAISSEDataLine(line); ok { trimmedPayload := strings.TrimSpace(payload) if trimmedPayload != "[DONE]" { usageOnlyChunk := isOpenAIChatUsageOnlyStreamChunk(payload) if u := extractCCStreamUsage(payload); u != nil { usage = *u } if firstTokenMs == nil && !usageOnlyChunk { elapsed := int(time.Since(startTime).Milliseconds()) firstTokenMs = &elapsed } } } writeLine(line) if line == "" { if !clientDisconnected && clientOutputStarted { c.Writer.Flush() } continue } if !clientDisconnected && clientOutputStarted { c.Writer.Flush() } } if err := scanner.Err(); err != nil { if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) { logger.L().Warn("openai chat_completions raw: stream read error", zap.Error(err), zap.String("request_id", requestID), ) } } else if !clientDisconnected && !clientOutputStarted { if refusalDetector.IsSilentRefusal() { return nil, newOpenAISilentRefusalFailoverError(c, account, requestID) } if len(pendingLines) > 0 { writeStreamHeaders() for _, pending := range pendingLines { if _, werr := c.Writer.WriteString(pending + "\n"); werr != nil { clientDisconnected = true logger.L().Debug("openai chat_completions raw: client disconnected during final flush", zap.Error(werr), zap.String("request_id", requestID), ) break } } if !clientDisconnected { c.Writer.Flush() clientOutputStarted = true } } } return &OpenAIForwardResult{ RequestID: requestID, Usage: usage, Model: originalModel, BillingModel: billingModel, UpstreamModel: upstreamModel, ReasoningEffort: reasoningEffort, ServiceTier: serviceTier, Stream: true, Duration: time.Since(startTime), FirstTokenMs: firstTokenMs, }, nil } // ensureOpenAIChatStreamUsage 确保 raw Chat Completions 流式请求会让上游返回 usage。 // usage 也会继续向下游透传,支持级联代理和下游计费系统。 func ensureOpenAIChatStreamUsage(body []byte) ([]byte, error) { updated, err := sjson.SetBytes(body, "stream_options.include_usage", true) if err != nil { return body, err } return updated, nil } func isOpenAIChatUsageOnlyStreamChunk(payload string) bool { if strings.TrimSpace(payload) == "" { return false } if !gjson.Get(payload, "usage").Exists() { return false } choices := gjson.Get(payload, "choices") return choices.Exists() && choices.IsArray() && len(choices.Array()) == 0 } // extractCCStreamUsage 从单个 CC 流式 chunk 的 payload 中提取 usage 字段。 // CC 协议中 usage 仅出现在末尾 chunk(且仅当 include_usage 生效时), // 但上游可能在多个 chunk 中重复——总是用最新值。 func extractCCStreamUsage(payload string) *OpenAIUsage { usageResult := gjson.Get(payload, "usage") if !usageResult.Exists() || !usageResult.IsObject() { return nil } u := OpenAIUsage{ InputTokens: int(gjson.Get(payload, "usage.prompt_tokens").Int()), OutputTokens: int(gjson.Get(payload, "usage.completion_tokens").Int()), } if cached := gjson.Get(payload, "usage.prompt_tokens_details.cached_tokens"); cached.Exists() { u.CacheReadInputTokens = int(cached.Int()) } return &u } // bufferRawChatCompletions 透传上游 CC 非流式 JSON 响应。 func (s *OpenAIGatewayService) bufferRawChatCompletions( c *gin.Context, resp *http.Response, originalModel string, billingModel string, upstreamModel string, reasoningEffort *string, serviceTier *string, startTime time.Time, ) (*OpenAIForwardResult, error) { requestID := resp.Header.Get("x-request-id") respBody, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError) if err != nil { if !errors.Is(err, ErrUpstreamResponseBodyTooLarge) { writeChatCompletionsError(c, http.StatusBadGateway, "api_error", "Failed to read upstream response") } return nil, fmt.Errorf("read upstream body: %w", err) } var ccResp apicompat.ChatCompletionsResponse var usage OpenAIUsage if err := json.Unmarshal(respBody, &ccResp); err == nil && ccResp.Usage != nil { usage = OpenAIUsage{ InputTokens: ccResp.Usage.PromptTokens, OutputTokens: ccResp.Usage.CompletionTokens, } if ccResp.Usage.PromptTokensDetails != nil { usage.CacheReadInputTokens = ccResp.Usage.PromptTokensDetails.CachedTokens } } if s.responseHeaderFilter != nil { responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) } if ct := resp.Header.Get("Content-Type"); ct != "" { c.Writer.Header().Set("Content-Type", ct) } else { c.Writer.Header().Set("Content-Type", "application/json") } c.Writer.WriteHeader(http.StatusOK) _, _ = c.Writer.Write(respBody) return &OpenAIForwardResult{ RequestID: requestID, Usage: usage, Model: originalModel, BillingModel: billingModel, UpstreamModel: upstreamModel, ReasoningEffort: reasoningEffort, ServiceTier: serviceTier, Stream: false, Duration: time.Since(startTime), }, nil } // buildOpenAIChatCompletionsURL 拼接上游 Chat Completions 端点 URL。 // // - base 已是 /chat/completions:原样返回 // - base 以 /v1 结尾:追加 /chat/completions // - base 以其他版本段结尾(如 /v4):追加 /chat/completions // - 其他情况:追加 /v1/chat/completions // // 与 buildOpenAIResponsesURL 是姐妹函数。 func buildOpenAIChatCompletionsURL(base string) string { return buildOpenAIEndpointURL(base, "/v1/chat/completions") }