From 2c14efeaa00f788ab1efe84a2b4d20834f83ef27 Mon Sep 17 00:00:00 2001 From: wucm667 Date: Wed, 20 May 2026 11:28:28 +0800 Subject: [PATCH] =?UTF-8?q?fix(openai-images):=20=E4=BF=AE=E5=A4=8D?= =?UTF-8?q?=E5=9B=BE=E7=89=87=E7=94=9F=E6=88=90=20n=20=E5=8F=82=E6=95=B0?= =?UTF-8?q?=E9=80=8F=E4=BC=A0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../service/openai_images_responses.go | 20 +++++----- .../internal/service/openai_images_test.go | 37 ++++++++++++++----- 2 files changed, 38 insertions(+), 19 deletions(-) diff --git a/backend/internal/service/openai_images_responses.go b/backend/internal/service/openai_images_responses.go index c89c2aaf..56272c26 100644 --- a/backend/internal/service/openai_images_responses.go +++ b/backend/internal/service/openai_images_responses.go @@ -262,6 +262,9 @@ func buildOpenAIImagesResponsesRequest(parsed *OpenAIImagesRequest, toolModel st tool := []byte(`{"type":"image_generation","action":"","model":""}`) tool, _ = sjson.SetBytes(tool, "action", action) tool, _ = sjson.SetBytes(tool, "model", strings.TrimSpace(toolModel)) + if shouldPassOpenAIImagesN(toolModel, parsed.N) { + tool, _ = sjson.SetBytes(tool, "n", parsed.N) + } for _, field := range []struct { path string @@ -302,6 +305,13 @@ func buildOpenAIImagesResponsesRequest(parsed *OpenAIImagesRequest, toolModel st return req, nil } +func shouldPassOpenAIImagesN(model string, n int) bool { + if n <= 1 { + return false + } + return !strings.EqualFold(strings.TrimSpace(model), "dall-e-3") +} + func extractOpenAIImagesFromResponsesCompleted(payload []byte) ([]openAIResponsesImageResult, int64, []byte, openAIResponsesImageResult, error) { if gjson.GetBytes(payload, "type").String() != "response.completed" { return nil, 0, nil, openAIResponsesImageResult{}, fmt.Errorf("unexpected event type") @@ -957,16 +967,6 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth( account.Type, len(parsed.Uploads), ) - if parsed.N > 1 { - logger.LegacyPrintf( - "service.openai_gateway", - "[Warning] Codex /responses image tool requested n=%d; falling back to n=1 request_model=%s endpoint=%s", - parsed.N, - requestModel, - parsed.Endpoint, - ) - } - upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx) defer releaseUpstreamCtx() diff --git a/backend/internal/service/openai_images_test.go b/backend/internal/service/openai_images_test.go index 35789d21..d47c52ca 100644 --- a/backend/internal/service/openai_images_test.go +++ b/backend/internal/service/openai_images_test.go @@ -474,9 +474,9 @@ func findOpenAIImageTestSSEEvent(events []openAIImageTestSSEEvent, name string) return openAIImageTestSSEEvent{}, false } -func TestOpenAIGatewayServiceForwardImages_OAuthUsesResponsesAPI(t *testing.T) { +func TestOpenAIGatewayServiceForwardImages_OAuthPassesNAndReturnsAllImages(t *testing.T) { gin.SetMode(gin.TestMode) - body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","size":"1024x1024","quality":"high","n":2}`) + body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","size":"1024x1024","quality":"high","n":3}`) req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") @@ -497,7 +497,7 @@ func TestOpenAIGatewayServiceForwardImages_OAuthUsesResponsesAPI(t *testing.T) { "X-Request-Id": []string{"req_img_123"}, }, Body: io.NopCloser(strings.NewReader( - "data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000000,\"usage\":{\"input_tokens\":11,\"output_tokens\":22,\"input_tokens_details\":{\"cached_tokens\":3},\"output_tokens_details\":{\"image_tokens\":7}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[{\"type\":\"image_generation_call\",\"result\":\"aGVsbG8=\",\"revised_prompt\":\"draw a cat\",\"output_format\":\"png\",\"quality\":\"high\",\"size\":\"1024x1024\"}]}}\n\n" + + "data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000000,\"usage\":{\"input_tokens\":11,\"output_tokens\":22,\"input_tokens_details\":{\"cached_tokens\":3},\"output_tokens_details\":{\"image_tokens\":7}},\"tool_usage\":{\"image_gen\":{\"images\":3}},\"output\":[{\"type\":\"image_generation_call\",\"result\":\"aW1hZ2UtMQ==\",\"revised_prompt\":\"draw a cat 1\",\"output_format\":\"png\",\"quality\":\"high\",\"size\":\"1024x1024\"},{\"type\":\"image_generation_call\",\"result\":\"aW1hZ2UtMg==\",\"revised_prompt\":\"draw a cat 2\",\"output_format\":\"png\",\"quality\":\"high\",\"size\":\"1024x1024\"},{\"type\":\"image_generation_call\",\"result\":\"aW1hZ2UtMw==\",\"revised_prompt\":\"draw a cat 3\",\"output_format\":\"png\",\"quality\":\"high\",\"size\":\"1024x1024\"}]}}\n\n" + "data: [DONE]\n\n", )), }, @@ -520,7 +520,7 @@ func TestOpenAIGatewayServiceForwardImages_OAuthUsesResponsesAPI(t *testing.T) { require.NotNil(t, result) require.Equal(t, "gpt-image-2", result.Model) require.Equal(t, "gpt-image-2", result.UpstreamModel) - require.Equal(t, 1, result.ImageCount) + require.Equal(t, 3, result.ImageCount) require.Equal(t, 11, result.Usage.InputTokens) require.Equal(t, 22, result.Usage.OutputTokens) require.Equal(t, 7, result.Usage.ImageOutputTokens) @@ -540,13 +540,17 @@ func TestOpenAIGatewayServiceForwardImages_OAuthUsesResponsesAPI(t *testing.T) { require.Equal(t, "gpt-image-2", gjson.GetBytes(upstream.lastBody, "tools.0.model").String()) require.Equal(t, "1024x1024", gjson.GetBytes(upstream.lastBody, "tools.0.size").String()) require.Equal(t, "high", gjson.GetBytes(upstream.lastBody, "tools.0.quality").String()) - require.False(t, gjson.GetBytes(upstream.lastBody, "tools.0.n").Exists()) + require.Equal(t, int64(3), gjson.GetBytes(upstream.lastBody, "tools.0.n").Int()) require.Equal(t, "draw a cat", gjson.GetBytes(upstream.lastBody, "input.0.content.0.text").String()) require.Equal(t, http.StatusOK, rec.Code) require.Equal(t, "gpt-image-2", gjson.Get(rec.Body.String(), "model").String()) - require.Equal(t, "aGVsbG8=", gjson.Get(rec.Body.String(), "data.0.b64_json").String()) - require.Equal(t, "draw a cat", gjson.Get(rec.Body.String(), "data.0.revised_prompt").String()) + require.Len(t, gjson.Get(rec.Body.String(), "data").Array(), 3) + require.Equal(t, "aW1hZ2UtMQ==", gjson.Get(rec.Body.String(), "data.0.b64_json").String()) + require.Equal(t, "aW1hZ2UtMg==", gjson.Get(rec.Body.String(), "data.1.b64_json").String()) + require.Equal(t, "aW1hZ2UtMw==", gjson.Get(rec.Body.String(), "data.2.b64_json").String()) + require.Equal(t, "draw a cat 1", gjson.Get(rec.Body.String(), "data.0.revised_prompt").String()) + require.Equal(t, "draw a cat 3", gjson.Get(rec.Body.String(), "data.2.revised_prompt").String()) } func TestOpenAIGatewayServiceForwardImages_APIKeyGenerationUsesConfiguredV1BaseURL(t *testing.T) { @@ -1112,7 +1116,7 @@ func TestOpenAIGatewayServiceForwardImages_OAuthEditsStreamingTransformsEvents(t require.False(t, gjson.Get(completed.Data, "revised_prompt").Exists()) } -func TestBuildOpenAIImagesResponsesRequest_DowngradesMultipleImagesToSingle(t *testing.T) { +func TestBuildOpenAIImagesResponsesRequest_PassesThroughNForMultiImageModels(t *testing.T) { parsed := &OpenAIImagesRequest{ Endpoint: openAIImagesGenerationsEndpoint, Model: "gpt-image-2", @@ -1123,11 +1127,26 @@ func TestBuildOpenAIImagesResponsesRequest_DowngradesMultipleImagesToSingle(t *t body, err := buildOpenAIImagesResponsesRequest(parsed, "gpt-image-2") require.NoError(t, err) require.NotNil(t, body) - require.False(t, gjson.GetBytes(body, "tools.0.n").Exists()) + require.Equal(t, int64(2), gjson.GetBytes(body, "tools.0.n").Int()) require.Equal(t, "gpt-image-2", gjson.GetBytes(body, "tools.0.model").String()) require.Equal(t, "draw a cat", gjson.GetBytes(body, "input.0.content.0.text").String()) } +func TestBuildOpenAIImagesResponsesRequest_DoesNotPassNForDallE3(t *testing.T) { + parsed := &OpenAIImagesRequest{ + Endpoint: openAIImagesGenerationsEndpoint, + Model: "dall-e-3", + Prompt: "draw a cat", + N: 2, + } + + body, err := buildOpenAIImagesResponsesRequest(parsed, "dall-e-3") + require.NoError(t, err) + require.NotNil(t, body) + require.False(t, gjson.GetBytes(body, "tools.0.n").Exists()) + require.Equal(t, "dall-e-3", gjson.GetBytes(body, "tools.0.model").String()) +} + func TestBuildOpenAIImagesResponsesRequest_StripsInputFidelity(t *testing.T) { parsed := &OpenAIImagesRequest{ Endpoint: openAIImagesEditsEndpoint,