diff --git a/backend/internal/pkg/apicompat/chatcompletions_responses_bridge.go b/backend/internal/pkg/apicompat/chatcompletions_responses_bridge.go new file mode 100644 index 00000000..8fb82ef4 --- /dev/null +++ b/backend/internal/pkg/apicompat/chatcompletions_responses_bridge.go @@ -0,0 +1,719 @@ +package apicompat + +import ( + "encoding/json" + "fmt" + "strings" + "time" +) + +// ResponsesToChatCompletionsRequest converts a Responses API request into a +// Chat Completions request for upstreams that only implement +// /v1/chat/completions. +func ResponsesToChatCompletionsRequest(req *ResponsesRequest) (*ChatCompletionsRequest, error) { + if req == nil { + return nil, fmt.Errorf("responses request is nil") + } + + messages, err := responsesInputToChatMessages(req.Instructions, req.Input) + if err != nil { + return nil, err + } + + out := &ChatCompletionsRequest{ + Model: req.Model, + Messages: messages, + MaxCompletionTokens: req.MaxOutputTokens, + Temperature: req.Temperature, + TopP: req.TopP, + Stream: req.Stream, + ServiceTier: req.ServiceTier, + } + if req.Reasoning != nil { + out.ReasoningEffort = req.Reasoning.Effort + } + if len(req.Tools) > 0 { + out.Tools = responsesToolsToChatTools(req.Tools) + } + if len(req.ToolChoice) > 0 { + out.ToolChoice = responsesToolChoiceToChatToolChoice(req.ToolChoice) + } + + return out, nil +} + +func responsesInputToChatMessages(instructions string, inputRaw json.RawMessage) ([]ChatMessage, error) { + var messages []ChatMessage + if strings.TrimSpace(instructions) != "" { + content, _ := json.Marshal(instructions) + messages = append(messages, ChatMessage{ + Role: "system", + Content: content, + }) + } + + inputRaw = bytesTrimSpace(inputRaw) + if len(inputRaw) == 0 || string(inputRaw) == "null" { + return messages, nil + } + + var inputText string + if err := json.Unmarshal(inputRaw, &inputText); err == nil { + content, _ := json.Marshal(inputText) + messages = append(messages, ChatMessage{ + Role: "user", + Content: content, + }) + return messages, nil + } + + var rawItems []json.RawMessage + if err := json.Unmarshal(inputRaw, &rawItems); err != nil { + return nil, fmt.Errorf("parse responses input: %w", err) + } + + for _, raw := range rawItems { + raw = bytesTrimSpace(raw) + if len(raw) == 0 || string(raw) == "null" { + continue + } + + var item map[string]json.RawMessage + if err := json.Unmarshal(raw, &item); err != nil { + var text string + if textErr := json.Unmarshal(raw, &text); textErr == nil { + content, _ := json.Marshal(text) + messages = append(messages, ChatMessage{Role: "user", Content: content}) + continue + } + return nil, fmt.Errorf("parse responses input item: %w", err) + } + + role := rawString(item["role"]) + itemType := rawString(item["type"]) + switch itemType { + case "function_call": + arguments := rawString(item["arguments"]) + if strings.TrimSpace(arguments) == "" { + arguments = "{}" + } + messages = append(messages, ChatMessage{ + Role: "assistant", + ToolCalls: []ChatToolCall{{ + ID: rawString(item["call_id"]), + Type: "function", + Function: ChatFunctionCall{ + Name: rawString(item["name"]), + Arguments: arguments, + }, + }}, + }) + continue + case "function_call_output": + content, _ := json.Marshal(rawString(item["output"])) + messages = append(messages, ChatMessage{ + Role: "tool", + ToolCallID: rawString(item["call_id"]), + Content: content, + }) + continue + case "input_text", "text": + content, _ := json.Marshal(rawString(item["text"])) + messages = append(messages, ChatMessage{Role: "user", Content: content}) + continue + case "input_image": + content, err := chatContentFromSingleResponsesPart(itemType, item) + if err != nil { + return nil, err + } + messages = append(messages, ChatMessage{Role: "user", Content: content}) + continue + } + + if role == "" { + role = "user" + } + content := item["content"] + if len(bytesTrimSpace(content)) == 0 { + if text := rawString(item["text"]); text != "" { + content, _ = json.Marshal(text) + } + } + chatContent, err := responsesContentToChatContent(content, role) + if err != nil { + return nil, err + } + messages = append(messages, ChatMessage{ + Role: role, + Content: chatContent, + }) + } + + return messages, nil +} + +func responsesContentToChatContent(raw json.RawMessage, role string) (json.RawMessage, error) { + raw = bytesTrimSpace(raw) + if len(raw) == 0 || string(raw) == "null" { + empty, _ := json.Marshal("") + return empty, nil + } + + var text string + if err := json.Unmarshal(raw, &text); err == nil { + return raw, nil + } + + var rawParts []json.RawMessage + if err := json.Unmarshal(raw, &rawParts); err == nil { + return responsesContentPartsToChatContent(rawParts, role) + } + + var obj map[string]json.RawMessage + if err := json.Unmarshal(raw, &obj); err == nil { + return chatContentFromSingleResponsesPart(rawString(obj["type"]), obj) + } + + return raw, nil +} + +func responsesContentPartsToChatContent(rawParts []json.RawMessage, role string) (json.RawMessage, error) { + var textParts []string + var chatParts []ChatContentPart + hasNonText := false + + for _, rawPart := range rawParts { + var part map[string]json.RawMessage + if err := json.Unmarshal(rawPart, &part); err != nil { + continue + } + partType := rawString(part["type"]) + switch partType { + case "input_text", "output_text", "text", "": + text := rawString(part["text"]) + if text == "" { + continue + } + textParts = append(textParts, text) + chatParts = append(chatParts, ChatContentPart{Type: "text", Text: text}) + case "input_image", "image_url": + imageURL := rawString(part["image_url"]) + if imageURL == "" { + imageURL = rawNestedString(part["image_url"], "url") + } + if imageURL == "" { + continue + } + hasNonText = true + chatParts = append(chatParts, ChatContentPart{ + Type: "image_url", + ImageURL: &ChatImageURL{URL: imageURL}, + }) + } + } + + if !hasNonText { + joined, _ := json.Marshal(strings.Join(textParts, "\n\n")) + return joined, nil + } + if role != "user" { + joined, _ := json.Marshal(strings.Join(textParts, "\n\n")) + return joined, nil + } + if len(chatParts) == 0 { + empty, _ := json.Marshal("") + return empty, nil + } + return json.Marshal(chatParts) +} + +func chatContentFromSingleResponsesPart(partType string, part map[string]json.RawMessage) (json.RawMessage, error) { + switch partType { + case "input_image", "image_url": + imageURL := rawString(part["image_url"]) + if imageURL == "" { + imageURL = rawNestedString(part["image_url"], "url") + } + return json.Marshal([]ChatContentPart{{ + Type: "image_url", + ImageURL: &ChatImageURL{URL: imageURL}, + }}) + default: + return json.Marshal(rawString(part["text"])) + } +} + +func responsesToolsToChatTools(tools []ResponsesTool) []ChatTool { + out := make([]ChatTool, 0, len(tools)) + for _, tool := range tools { + if tool.Type != "function" { + continue + } + out = append(out, ChatTool{ + Type: "function", + Function: &ChatFunction{ + Name: tool.Name, + Description: tool.Description, + Parameters: tool.Parameters, + Strict: tool.Strict, + }, + }) + } + return out +} + +func responsesToolChoiceToChatToolChoice(raw json.RawMessage) json.RawMessage { + var choice map[string]json.RawMessage + if err := json.Unmarshal(raw, &choice); err != nil { + return raw + } + if rawString(choice["type"]) != "function" { + return raw + } + name := rawString(choice["name"]) + if name == "" { + name = rawNestedString(choice["function"], "name") + } + if name == "" { + return raw + } + out, err := json.Marshal(map[string]any{ + "type": "function", + "function": map[string]string{ + "name": name, + }, + }) + if err != nil { + return raw + } + return out +} + +// ChatCompletionsResponseToResponses converts a non-streaming Chat Completions +// response into a Responses API response. +func ChatCompletionsResponseToResponses(resp *ChatCompletionsResponse, model string) *ResponsesResponse { + id := "" + if resp != nil { + id = resp.ID + } + if id == "" { + id = generateResponsesID() + } + + out := &ResponsesResponse{ + ID: id, + Object: "response", + Model: model, + Status: "completed", + } + if resp == nil { + out.Output = []ResponsesOutput{emptyResponsesMessageOutput()} + return out + } + if out.Model == "" { + out.Model = resp.Model + } + + if len(resp.Choices) > 0 { + choice := resp.Choices[0] + out.Output = chatMessageToResponsesOutput(choice.Message) + if choice.FinishReason == "length" { + out.Status = "incomplete" + out.IncompleteDetails = &ResponsesIncompleteDetails{Reason: "max_output_tokens"} + } + } + if len(out.Output) == 0 { + out.Output = []ResponsesOutput{emptyResponsesMessageOutput()} + } + if resp.Usage != nil { + out.Usage = ChatUsageToResponsesUsage(resp.Usage) + } + return out +} + +func chatMessageToResponsesOutput(message ChatMessage) []ResponsesOutput { + var outputs []ResponsesOutput + if message.ReasoningContent != "" { + outputs = append(outputs, ResponsesOutput{ + Type: "reasoning", + ID: generateItemID(), + Summary: []ResponsesSummary{{ + Type: "summary_text", + Text: message.ReasoningContent, + }}, + }) + } + + text := chatMessageContentText(message.Content) + if text != "" || len(message.ToolCalls) == 0 { + outputs = append(outputs, ResponsesOutput{ + Type: "message", + ID: generateItemID(), + Role: "assistant", + Content: []ResponsesContentPart{{ + Type: "output_text", + Text: text, + }}, + Status: "completed", + }) + } + + for _, toolCall := range message.ToolCalls { + arguments := toolCall.Function.Arguments + if strings.TrimSpace(arguments) == "" { + arguments = "{}" + } + outputs = append(outputs, ResponsesOutput{ + Type: "function_call", + ID: generateItemID(), + CallID: toolCall.ID, + Name: toolCall.Function.Name, + Arguments: arguments, + Status: "completed", + }) + } + + return outputs +} + +func emptyResponsesMessageOutput() ResponsesOutput { + return ResponsesOutput{ + Type: "message", + ID: generateItemID(), + Role: "assistant", + Content: []ResponsesContentPart{{Type: "output_text", Text: ""}}, + Status: "completed", + } +} + +func chatMessageContentText(raw json.RawMessage) string { + raw = bytesTrimSpace(raw) + if len(raw) == 0 || string(raw) == "null" { + return "" + } + var text string + if err := json.Unmarshal(raw, &text); err == nil { + return text + } + var parts []ChatContentPart + if err := json.Unmarshal(raw, &parts); err == nil { + var texts []string + for _, part := range parts { + if part.Type == "text" && part.Text != "" { + texts = append(texts, part.Text) + } + } + return strings.Join(texts, "\n\n") + } + return "" +} + +// ChatUsageToResponsesUsage converts Chat Completions token usage to Responses +// usage shape. +func ChatUsageToResponsesUsage(usage *ChatUsage) *ResponsesUsage { + if usage == nil { + return nil + } + out := &ResponsesUsage{ + InputTokens: usage.PromptTokens, + OutputTokens: usage.CompletionTokens, + TotalTokens: usage.TotalTokens, + } + if out.TotalTokens == 0 { + out.TotalTokens = out.InputTokens + out.OutputTokens + } + if usage.PromptTokensDetails != nil && usage.PromptTokensDetails.CachedTokens > 0 { + out.InputTokensDetails = &ResponsesInputTokensDetails{ + CachedTokens: usage.PromptTokensDetails.CachedTokens, + } + } + return out +} + +// ChatCompletionsToResponsesStreamState tracks state while converting Chat +// Completions SSE chunks into Responses SSE events. +type ChatCompletionsToResponsesStreamState struct { + ResponseID string + Model string + Created int64 + SequenceNumber int + CreatedSent bool + CompletedSent bool + + MessageItemID string + Text strings.Builder + Reasoning strings.Builder + ToolCalls map[int]*ChatToolCall + + FinishReason string + Usage *ResponsesUsage +} + +// NewChatCompletionsToResponsesStreamState returns an initialized stream state. +func NewChatCompletionsToResponsesStreamState(model string) *ChatCompletionsToResponsesStreamState { + return &ChatCompletionsToResponsesStreamState{ + ResponseID: generateResponsesID(), + Model: model, + Created: time.Now().Unix(), + ToolCalls: make(map[int]*ChatToolCall), + } +} + +// ChatCompletionsChunkToResponsesEvents converts one Chat Completions stream +// chunk into zero or more Responses stream events. +func ChatCompletionsChunkToResponsesEvents( + chunk *ChatCompletionsChunk, + state *ChatCompletionsToResponsesStreamState, +) []ResponsesStreamEvent { + if chunk == nil || state == nil { + return nil + } + if chunk.ID != "" { + state.ResponseID = chunk.ID + } + if state.Model == "" && chunk.Model != "" { + state.Model = chunk.Model + } + if chunk.Usage != nil { + state.Usage = ChatUsageToResponsesUsage(chunk.Usage) + } + + var events []ResponsesStreamEvent + events = append(events, ensureChatToResponsesCreated(state)...) + + for _, choice := range chunk.Choices { + if choice.Delta.Content != nil { + events = append(events, ensureChatToResponsesMessageItem(state)...) + _, _ = state.Text.WriteString(*choice.Delta.Content) + events = append(events, chatToResponsesEvent(state, "response.output_text.delta", &ResponsesStreamEvent{ + OutputIndex: 0, + ContentIndex: 0, + Delta: *choice.Delta.Content, + ItemID: state.MessageItemID, + })) + } + if choice.Delta.ReasoningContent != nil { + _, _ = state.Reasoning.WriteString(*choice.Delta.ReasoningContent) + events = append(events, chatToResponsesEvent(state, "response.reasoning_summary_text.delta", &ResponsesStreamEvent{ + OutputIndex: 0, + SummaryIndex: 0, + Delta: *choice.Delta.ReasoningContent, + })) + } + for _, toolCall := range choice.Delta.ToolCalls { + idx := 0 + if toolCall.Index != nil { + idx = *toolCall.Index + } + stored, ok := state.ToolCalls[idx] + if !ok { + copyCall := toolCall + if copyCall.ID == "" { + copyCall.ID = generateItemID() + } + copyCall.Type = "function" + state.ToolCalls[idx] = ©Call + stored = ©Call + events = append(events, chatToResponsesEvent(state, "response.output_item.added", &ResponsesStreamEvent{ + OutputIndex: idx + 1, + Item: &ResponsesOutput{ + Type: "function_call", + ID: generateItemID(), + CallID: stored.ID, + Name: stored.Function.Name, + Status: "in_progress", + }, + })) + } else { + if toolCall.ID != "" { + stored.ID = toolCall.ID + } + if toolCall.Function.Name != "" { + stored.Function.Name = toolCall.Function.Name + } + } + if toolCall.Function.Arguments != "" { + stored.Function.Arguments += toolCall.Function.Arguments + events = append(events, chatToResponsesEvent(state, "response.function_call_arguments.delta", &ResponsesStreamEvent{ + OutputIndex: idx + 1, + Delta: toolCall.Function.Arguments, + CallID: stored.ID, + Name: stored.Function.Name, + })) + } + } + if choice.FinishReason != nil && *choice.FinishReason != "" { + state.FinishReason = *choice.FinishReason + } + } + + return events +} + +// FinalizeChatCompletionsResponsesStream emits terminal Responses events. +func FinalizeChatCompletionsResponsesStream(state *ChatCompletionsToResponsesStreamState) []ResponsesStreamEvent { + if state == nil || state.CompletedSent { + return nil + } + var events []ResponsesStreamEvent + events = append(events, ensureChatToResponsesCreated(state)...) + if state.MessageItemID != "" { + events = append(events, chatToResponsesEvent(state, "response.output_text.done", &ResponsesStreamEvent{ + OutputIndex: 0, + ContentIndex: 0, + Text: state.Text.String(), + ItemID: state.MessageItemID, + })) + events = append(events, chatToResponsesEvent(state, "response.output_item.done", &ResponsesStreamEvent{ + OutputIndex: 0, + Item: &ResponsesOutput{ + Type: "message", + ID: state.MessageItemID, + Role: "assistant", + Status: "completed", + }, + })) + } + + status := "completed" + var incompleteDetails *ResponsesIncompleteDetails + if state.FinishReason == "length" { + status = "incomplete" + incompleteDetails = &ResponsesIncompleteDetails{Reason: "max_output_tokens"} + } + + state.CompletedSent = true + events = append(events, chatToResponsesEvent(state, "response.completed", &ResponsesStreamEvent{ + Response: &ResponsesResponse{ + ID: state.ResponseID, + Object: "response", + Model: state.Model, + Status: status, + Output: state.chatOutput(), + Usage: state.Usage, + IncompleteDetails: incompleteDetails, + }, + })) + return events +} + +func ensureChatToResponsesCreated(state *ChatCompletionsToResponsesStreamState) []ResponsesStreamEvent { + if state.CreatedSent { + return nil + } + state.CreatedSent = true + return []ResponsesStreamEvent{chatToResponsesEvent(state, "response.created", &ResponsesStreamEvent{ + Response: &ResponsesResponse{ + ID: state.ResponseID, + Object: "response", + Model: state.Model, + Status: "in_progress", + Output: []ResponsesOutput{}, + }, + })} +} + +func ensureChatToResponsesMessageItem(state *ChatCompletionsToResponsesStreamState) []ResponsesStreamEvent { + if state.MessageItemID != "" { + return nil + } + state.MessageItemID = generateItemID() + return []ResponsesStreamEvent{chatToResponsesEvent(state, "response.output_item.added", &ResponsesStreamEvent{ + OutputIndex: 0, + Item: &ResponsesOutput{ + Type: "message", + ID: state.MessageItemID, + Role: "assistant", + Status: "in_progress", + }, + })} +} + +func (state *ChatCompletionsToResponsesStreamState) chatOutput() []ResponsesOutput { + var outputs []ResponsesOutput + if state.Reasoning.Len() > 0 { + outputs = append(outputs, ResponsesOutput{ + Type: "reasoning", + ID: generateItemID(), + Summary: []ResponsesSummary{{ + Type: "summary_text", + Text: state.Reasoning.String(), + }}, + }) + } + if state.MessageItemID != "" || len(state.ToolCalls) == 0 { + outputs = append(outputs, ResponsesOutput{ + Type: "message", + ID: nonEmpty(state.MessageItemID, generateItemID()), + Role: "assistant", + Content: []ResponsesContentPart{{ + Type: "output_text", + Text: state.Text.String(), + }}, + Status: "completed", + }) + } + for i := 0; i < len(state.ToolCalls); i++ { + toolCall, ok := state.ToolCalls[i] + if !ok || toolCall == nil { + continue + } + arguments := toolCall.Function.Arguments + if strings.TrimSpace(arguments) == "" { + arguments = "{}" + } + outputs = append(outputs, ResponsesOutput{ + Type: "function_call", + ID: generateItemID(), + CallID: toolCall.ID, + Name: toolCall.Function.Name, + Arguments: arguments, + Status: "completed", + }) + } + return outputs +} + +func chatToResponsesEvent( + state *ChatCompletionsToResponsesStreamState, + eventType string, + template *ResponsesStreamEvent, +) ResponsesStreamEvent { + seq := state.SequenceNumber + state.SequenceNumber++ + evt := *template + evt.Type = eventType + evt.SequenceNumber = seq + return evt +} + +func rawString(raw json.RawMessage) string { + raw = bytesTrimSpace(raw) + if len(raw) == 0 || string(raw) == "null" { + return "" + } + var s string + if err := json.Unmarshal(raw, &s); err == nil { + return s + } + return "" +} + +func rawNestedString(raw json.RawMessage, key string) string { + var obj map[string]json.RawMessage + if err := json.Unmarshal(raw, &obj); err != nil { + return "" + } + return rawString(obj[key]) +} + +func bytesTrimSpace(raw json.RawMessage) json.RawMessage { + return json.RawMessage(strings.TrimSpace(string(raw))) +} + +func nonEmpty(value, fallback string) string { + if value != "" { + return value + } + return fallback +} diff --git a/backend/internal/service/openai_gateway_responses_chat_fallback.go b/backend/internal/service/openai_gateway_responses_chat_fallback.go new file mode 100644 index 00000000..1d28a9c2 --- /dev/null +++ b/backend/internal/service/openai_gateway_responses_chat_fallback.go @@ -0,0 +1,428 @@ +package service + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +// forwardResponsesViaRawChatCompletions serves /v1/responses clients through an +// upstream that only supports /v1/chat/completions. +func (s *OpenAIGatewayService) forwardResponsesViaRawChatCompletions( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, +) (*OpenAIForwardResult, error) { + startTime := time.Now() + + var responsesReq apicompat.ResponsesRequest + if err := json.Unmarshal(body, &responsesReq); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": gin.H{ + "type": "invalid_request_error", + "message": "Failed to parse request body", + }, + }) + return nil, fmt.Errorf("parse responses request: %w", err) + } + originalModel := strings.TrimSpace(responsesReq.Model) + if originalModel == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "error": gin.H{ + "type": "invalid_request_error", + "message": "model is required", + }, + }) + return nil, fmt.Errorf("missing model in request") + } + + clientStream := responsesReq.Stream + reasoningEffort := extractOpenAIReasoningEffortFromBody(body, originalModel) + serviceTier := extractOpenAIServiceTierFromBody(body) + + chatReq, err := apicompat.ResponsesToChatCompletionsRequest(&responsesReq) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": gin.H{ + "type": "invalid_request_error", + "message": err.Error(), + }, + }) + return nil, fmt.Errorf("convert responses to chat completions: %w", err) + } + + billingModel := resolveOpenAIForwardModel(account, originalModel, "") + upstreamModel := normalizeOpenAIModelForUpstream(account, billingModel) + chatReq.Model = upstreamModel + if clientStream { + chatReq.StreamOptions = &apicompat.ChatStreamOptions{IncludeUsage: true} + } + + chatBody, err := json.Marshal(chatReq) + if err != nil { + return nil, fmt.Errorf("marshal chat completions fallback request: %w", err) + } + chatBody, err = s.applyOpenAIFastPolicyToBody(ctx, account, upstreamModel, chatBody) + if err != nil { + var blocked *OpenAIFastBlockedError + if errors.As(err, &blocked) { + writeOpenAIFastPolicyBlockedResponse(c, blocked) + } + return nil, err + } + if serviceTier == nil { + serviceTier = extractOpenAIServiceTierFromBody(chatBody) + } + + logger.L().Debug("openai responses: forwarding via raw chat completions", + zap.Int64("account_id", account.ID), + zap.String("original_model", originalModel), + zap.String("billing_model", billingModel), + zap.String("upstream_model", upstreamModel), + zap.Bool("stream", clientStream), + ) + + apiKey := account.GetOpenAIApiKey() + if apiKey == "" { + return nil, fmt.Errorf("account %d missing api_key", account.ID) + } + baseURL := account.GetOpenAIBaseURL() + if baseURL == "" { + baseURL = "https://api.openai.com" + } + validatedURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return nil, fmt.Errorf("invalid base_url: %w", err) + } + targetURL := buildOpenAIChatCompletionsURL(validatedURL) + + upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx) + upstreamReq, err := http.NewRequestWithContext(upstreamCtx, http.MethodPost, targetURL, bytes.NewReader(chatBody)) + releaseUpstreamCtx() + if err != nil { + return nil, fmt.Errorf("build upstream request: %w", err) + } + upstreamReq.Header.Set("Content-Type", "application/json") + upstreamReq.Header.Set("Authorization", "Bearer "+apiKey) + if clientStream { + upstreamReq.Header.Set("Accept", "text/event-stream") + } else { + upstreamReq.Header.Set("Accept", "application/json") + } + for key, values := range c.Request.Header { + lowerKey := strings.ToLower(key) + if openaiCCRawAllowedHeaders[lowerKey] { + for _, v := range values { + upstreamReq.Header.Add(key, v) + } + } + } + if customUA := account.GetOpenAIUserAgent(); customUA != "" { + upstreamReq.Header.Set("user-agent", customUA) + } + + proxyURL := "" + if account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + if err != nil { + safeErr := sanitizeUpstreamErrorMessage(err.Error()) + setOpsUpstreamError(c, 0, safeErr, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + Kind: "request_error", + Message: safeErr, + }) + c.JSON(http.StatusBadGateway, gin.H{ + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream request failed", + }, + }) + return nil, fmt.Errorf("upstream request failed: %s", safeErr) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode >= 400 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) { + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(respBody), maxBytes) + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "failover", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + if s.rateLimitService != nil { + s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + } + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: respBody, + RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)), + } + } + return s.handleErrorResponse(ctx, resp, c, account, chatBody) + } + + if clientStream { + return s.streamChatCompletionsAsResponses(c, resp, originalModel, billingModel, upstreamModel, reasoningEffort, serviceTier, startTime) + } + return s.bufferChatCompletionsAsResponses(c, resp, originalModel, billingModel, upstreamModel, reasoningEffort, serviceTier, startTime) +} + +func (s *OpenAIGatewayService) bufferChatCompletionsAsResponses( + c *gin.Context, + resp *http.Response, + originalModel string, + billingModel string, + upstreamModel string, + reasoningEffort *string, + serviceTier *string, + startTime time.Time, +) (*OpenAIForwardResult, error) { + requestID := resp.Header.Get("x-request-id") + respBody, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError) + if err != nil { + if !errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + c.JSON(http.StatusBadGateway, gin.H{ + "error": gin.H{ + "type": "api_error", + "message": "Failed to read upstream response", + }, + }) + } + return nil, fmt.Errorf("read upstream body: %w", err) + } + + var ccResp apicompat.ChatCompletionsResponse + if err := json.Unmarshal(respBody, &ccResp); err != nil { + c.JSON(http.StatusBadGateway, gin.H{ + "error": gin.H{ + "type": "api_error", + "message": "Failed to parse upstream response", + }, + }) + return nil, fmt.Errorf("parse chat completions response: %w", err) + } + responsesResp := apicompat.ChatCompletionsResponseToResponses(&ccResp, originalModel) + + usage := OpenAIUsage{} + if parsed, ok := extractOpenAIUsageFromJSONBytes(respBody); ok { + usage = parsed + } + + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + } + c.JSON(http.StatusOK, responsesResp) + + return &OpenAIForwardResult{ + RequestID: requestID, + Usage: usage, + Model: originalModel, + BillingModel: billingModel, + UpstreamModel: upstreamModel, + ReasoningEffort: reasoningEffort, + ServiceTier: serviceTier, + Stream: false, + Duration: time.Since(startTime), + }, nil +} + +func (s *OpenAIGatewayService) streamChatCompletionsAsResponses( + c *gin.Context, + resp *http.Response, + originalModel string, + billingModel string, + upstreamModel string, + reasoningEffort *string, + serviceTier *string, + startTime time.Time, +) (*OpenAIForwardResult, error) { + requestID := resp.Header.Get("x-request-id") + headersWritten := false + writeStreamHeaders := func() { + if headersWritten { + return + } + headersWritten = true + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + } + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("X-Accel-Buffering", "no") + c.Writer.WriteHeader(http.StatusOK) + } + + state := apicompat.NewChatCompletionsToResponsesStreamState(originalModel) + var usage OpenAIUsage + var firstTokenMs *int + clientDisconnected := false + sawDone := false + + writeEvents := func(events []apicompat.ResponsesStreamEvent) { + if clientDisconnected || len(events) == 0 { + return + } + writeStreamHeaders() + for _, event := range events { + sse, err := apicompat.ResponsesEventToSSE(event) + if err != nil { + logger.L().Warn("openai responses chat fallback: failed to marshal stream event", + zap.Error(err), + zap.String("request_id", requestID), + ) + continue + } + if _, err := fmt.Fprint(c.Writer, sse); err != nil { + clientDisconnected = true + logger.L().Debug("openai responses chat fallback: client disconnected, continuing to drain upstream for billing", + zap.Error(err), + zap.String("request_id", requestID), + ) + return + } + } + c.Writer.Flush() + } + + scanner := bufio.NewScanner(resp.Body) + maxLineSize := defaultMaxLineSize + if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.cfg.Gateway.MaxLineSize + } + scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) + + for scanner.Scan() { + line := scanner.Text() + payload, ok := extractOpenAISSEDataLine(line) + if !ok { + continue + } + payload = strings.TrimSpace(payload) + if payload == "" { + continue + } + if payload == "[DONE]" { + sawDone = true + break + } + + if u := extractCCStreamUsage(payload); u != nil { + usage = *u + } + + var chunk apicompat.ChatCompletionsChunk + if err := json.Unmarshal([]byte(payload), &chunk); err != nil { + logger.L().Warn("openai responses chat fallback: failed to parse chat stream chunk", + zap.Error(err), + zap.String("request_id", requestID), + ) + continue + } + if firstTokenMs == nil && !isOpenAIChatUsageOnlyStreamChunk(payload) && chatChunkStartsResponsesOutput(&chunk) { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + writeEvents(apicompat.ChatCompletionsChunkToResponsesEvents(&chunk, state)) + } + + if err := scanner.Err(); err != nil { + if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) { + logger.L().Warn("openai responses chat fallback: stream read error", + zap.Error(err), + zap.String("request_id", requestID), + ) + } + return &OpenAIForwardResult{ + RequestID: requestID, + Usage: usage, + Model: originalModel, + BillingModel: billingModel, + UpstreamModel: upstreamModel, + ReasoningEffort: reasoningEffort, + ServiceTier: serviceTier, + Stream: true, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + }, fmt.Errorf("stream usage incomplete: %w", err) + } + + writeEvents(apicompat.FinalizeChatCompletionsResponsesStream(state)) + if !clientDisconnected { + writeStreamHeaders() + if _, err := fmt.Fprint(c.Writer, "data: [DONE]\n\n"); err != nil { + clientDisconnected = true + } + if !clientDisconnected { + c.Writer.Flush() + } + } + if !sawDone { + logger.L().Debug("openai responses chat fallback: upstream stream ended without done sentinel", + zap.String("request_id", requestID), + ) + } + + return &OpenAIForwardResult{ + RequestID: requestID, + Usage: usage, + Model: originalModel, + BillingModel: billingModel, + UpstreamModel: upstreamModel, + ReasoningEffort: reasoningEffort, + ServiceTier: serviceTier, + Stream: true, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + }, nil +} + +func chatChunkStartsResponsesOutput(chunk *apicompat.ChatCompletionsChunk) bool { + if chunk == nil { + return false + } + for _, choice := range chunk.Choices { + if choice.Delta.Content != nil || choice.Delta.ReasoningContent != nil || len(choice.Delta.ToolCalls) > 0 { + return true + } + } + return false +} diff --git a/backend/internal/service/openai_gateway_responses_chat_fallback_test.go b/backend/internal/service/openai_gateway_responses_chat_fallback_test.go new file mode 100644 index 00000000..78df2202 --- /dev/null +++ b/backend/internal/service/openai_gateway_responses_chat_fallback_test.go @@ -0,0 +1,145 @@ +//go:build unit + +package service + +import ( + "bytes" + "context" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/openai_compat" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestForwardResponses_ForceChatCompletionsRoutesNonStreamingToChatCompletions(t *testing.T) { + gin.SetMode(gin.TestMode) + + body := []byte(`{"model":"gpt-5.4","input":"hello","stream":false}`) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid_resp_chat_json"}}, + Body: io.NopCloser(strings.NewReader( + `{"id":"chatcmpl_json","object":"chat.completion","model":"gpt-5.4","choices":[{"index":0,"message":{"role":"assistant","content":"ok"},"finish_reason":"stop"}],"usage":{"prompt_tokens":3,"completion_tokens":2,"total_tokens":5,"prompt_tokens_details":{"cached_tokens":1}}}`, + )), + }} + svc := &OpenAIGatewayService{ + cfg: rawChatCompletionsTestConfig(), + httpUpstream: upstream, + } + + result, err := svc.Forward(context.Background(), c, forceChatResponsesFallbackAccount(), body) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "http://upstream.example/v1/chat/completions", upstream.lastReq.URL.String()) + require.Equal(t, "hello", gjson.GetBytes(upstream.lastBody, "messages.0.content").String()) + require.False(t, gjson.GetBytes(upstream.lastBody, "input").Exists()) + require.Equal(t, "response", gjson.Get(rec.Body.String(), "object").String()) + require.Equal(t, "ok", gjson.Get(rec.Body.String(), "output.0.content.0.text").String()) + require.Equal(t, 3, result.Usage.InputTokens) + require.Equal(t, 2, result.Usage.OutputTokens) + require.Equal(t, 1, result.Usage.CacheReadInputTokens) + require.False(t, result.Stream) +} + +func TestForwardResponses_ForceChatCompletionsRoutesStreamingToChatCompletions(t *testing.T) { + gin.SetMode(gin.TestMode) + + body := []byte(`{"model":"gpt-5.4","input":"hello","stream":true}`) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := strings.Join([]string{ + `data: {"id":"chatcmpl_stream","object":"chat.completion.chunk","model":"gpt-5.4","choices":[{"index":0,"delta":{"role":"assistant"},"finish_reason":null}]}`, + "", + `data: {"id":"chatcmpl_stream","object":"chat.completion.chunk","model":"gpt-5.4","choices":[{"index":0,"delta":{"content":"he"},"finish_reason":null}]}`, + "", + `data: {"id":"chatcmpl_stream","object":"chat.completion.chunk","model":"gpt-5.4","choices":[{"index":0,"delta":{"content":"llo"},"finish_reason":null}]}`, + "", + `data: {"id":"chatcmpl_stream","object":"chat.completion.chunk","model":"gpt-5.4","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`, + "", + `data: {"id":"chatcmpl_stream","object":"chat.completion.chunk","model":"gpt-5.4","choices":[],"usage":{"prompt_tokens":4,"completion_tokens":3,"total_tokens":7}}`, + "", + "data: [DONE]", + "", + }, "\n") + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_resp_chat_stream"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }} + svc := &OpenAIGatewayService{ + cfg: rawChatCompletionsTestConfig(), + httpUpstream: upstream, + } + + result, err := svc.Forward(context.Background(), c, forceChatResponsesFallbackAccount(), body) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "http://upstream.example/v1/chat/completions", upstream.lastReq.URL.String()) + require.True(t, gjson.GetBytes(upstream.lastBody, "stream_options.include_usage").Bool()) + require.Contains(t, rec.Body.String(), "event: response.output_text.delta") + require.Contains(t, rec.Body.String(), `"delta":"he"`) + require.Contains(t, rec.Body.String(), "event: response.completed") + require.Contains(t, rec.Body.String(), `"input_tokens":4`) + require.Contains(t, rec.Body.String(), "data: [DONE]") + require.Equal(t, 4, result.Usage.InputTokens) + require.Equal(t, 3, result.Usage.OutputTokens) + require.True(t, result.Stream) + require.NotNil(t, result.FirstTokenMs) +} + +func TestForwardResponses_AutoSupportedAccountStillUsesResponsesEndpoint(t *testing.T) { + gin.SetMode(gin.TestMode) + + body := []byte(`{"model":"gpt-5.4","input":"hello","stream":false}`) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid_resp_native"}}, + Body: io.NopCloser(strings.NewReader( + `{"id":"resp_native","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","role":"assistant","content":[{"type":"output_text","text":"ok"}],"status":"completed"}],"usage":{"input_tokens":5,"output_tokens":2,"total_tokens":7}}`, + )), + }} + svc := &OpenAIGatewayService{ + cfg: rawChatCompletionsTestConfig(), + httpUpstream: upstream, + } + account := rawChatCompletionsTestAccount() + account.Extra = map[string]any{ + openai_compat.ExtraKeyResponsesMode: string(openai_compat.ResponsesSupportModeAuto), + openai_compat.ExtraKeyResponsesSupported: true, + } + + result, err := svc.Forward(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "http://upstream.example/v1/responses", upstream.lastReq.URL.String()) + require.True(t, gjson.GetBytes(upstream.lastBody, "input").Exists()) + require.False(t, gjson.GetBytes(upstream.lastBody, "messages").Exists()) + require.Equal(t, "ok", gjson.Get(rec.Body.String(), "output.0.content.0.text").String()) +} + +func forceChatResponsesFallbackAccount() *Account { + account := rawChatCompletionsTestAccount() + account.Extra = map[string]any{ + openai_compat.ExtraKeyResponsesMode: string(openai_compat.ResponsesSupportModeForceChatCompletions), + } + return account +} diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index cfaf5bff..8554378c 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -24,6 +24,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/Wei-Shaw/sub2api/internal/pkg/openai_compat" "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" "github.com/cespare/xxhash/v2" @@ -2018,6 +2019,11 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco originalBody := body reqModel, reqStream, promptCacheKey := extractOpenAIRequestMetaFromBody(body) originalModel := reqModel + + if account.Type == AccountTypeAPIKey && !openai_compat.ShouldUseResponsesAPI(account.Extra) { + return s.forwardResponsesViaRawChatCompletions(ctx, c, account, body) + } + compatMessagesBridge := isOpenAICompatMessagesBridgeBody(body) setOpenAICompatMessagesBridgeContext(c, compatMessagesBridge)