fix(openai-images): 修复图片生成 n 参数透传

This commit is contained in:
wucm667 2026-05-20 11:28:28 +08:00
parent 91da815993
commit 2c14efeaa0
2 changed files with 38 additions and 19 deletions

View File

@ -262,6 +262,9 @@ func buildOpenAIImagesResponsesRequest(parsed *OpenAIImagesRequest, toolModel st
tool := []byte(`{"type":"image_generation","action":"","model":""}`)
tool, _ = sjson.SetBytes(tool, "action", action)
tool, _ = sjson.SetBytes(tool, "model", strings.TrimSpace(toolModel))
if shouldPassOpenAIImagesN(toolModel, parsed.N) {
tool, _ = sjson.SetBytes(tool, "n", parsed.N)
}
for _, field := range []struct {
path string
@ -302,6 +305,13 @@ func buildOpenAIImagesResponsesRequest(parsed *OpenAIImagesRequest, toolModel st
return req, nil
}
func shouldPassOpenAIImagesN(model string, n int) bool {
if n <= 1 {
return false
}
return !strings.EqualFold(strings.TrimSpace(model), "dall-e-3")
}
func extractOpenAIImagesFromResponsesCompleted(payload []byte) ([]openAIResponsesImageResult, int64, []byte, openAIResponsesImageResult, error) {
if gjson.GetBytes(payload, "type").String() != "response.completed" {
return nil, 0, nil, openAIResponsesImageResult{}, fmt.Errorf("unexpected event type")
@ -957,16 +967,6 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth(
account.Type,
len(parsed.Uploads),
)
if parsed.N > 1 {
logger.LegacyPrintf(
"service.openai_gateway",
"[Warning] Codex /responses image tool requested n=%d; falling back to n=1 request_model=%s endpoint=%s",
parsed.N,
requestModel,
parsed.Endpoint,
)
}
upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
defer releaseUpstreamCtx()

View File

@ -474,9 +474,9 @@ func findOpenAIImageTestSSEEvent(events []openAIImageTestSSEEvent, name string)
return openAIImageTestSSEEvent{}, false
}
func TestOpenAIGatewayServiceForwardImages_OAuthUsesResponsesAPI(t *testing.T) {
func TestOpenAIGatewayServiceForwardImages_OAuthPassesNAndReturnsAllImages(t *testing.T) {
gin.SetMode(gin.TestMode)
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","size":"1024x1024","quality":"high","n":2}`)
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","size":"1024x1024","quality":"high","n":3}`)
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
@ -497,7 +497,7 @@ func TestOpenAIGatewayServiceForwardImages_OAuthUsesResponsesAPI(t *testing.T) {
"X-Request-Id": []string{"req_img_123"},
},
Body: io.NopCloser(strings.NewReader(
"data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000000,\"usage\":{\"input_tokens\":11,\"output_tokens\":22,\"input_tokens_details\":{\"cached_tokens\":3},\"output_tokens_details\":{\"image_tokens\":7}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[{\"type\":\"image_generation_call\",\"result\":\"aGVsbG8=\",\"revised_prompt\":\"draw a cat\",\"output_format\":\"png\",\"quality\":\"high\",\"size\":\"1024x1024\"}]}}\n\n" +
"data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000000,\"usage\":{\"input_tokens\":11,\"output_tokens\":22,\"input_tokens_details\":{\"cached_tokens\":3},\"output_tokens_details\":{\"image_tokens\":7}},\"tool_usage\":{\"image_gen\":{\"images\":3}},\"output\":[{\"type\":\"image_generation_call\",\"result\":\"aW1hZ2UtMQ==\",\"revised_prompt\":\"draw a cat 1\",\"output_format\":\"png\",\"quality\":\"high\",\"size\":\"1024x1024\"},{\"type\":\"image_generation_call\",\"result\":\"aW1hZ2UtMg==\",\"revised_prompt\":\"draw a cat 2\",\"output_format\":\"png\",\"quality\":\"high\",\"size\":\"1024x1024\"},{\"type\":\"image_generation_call\",\"result\":\"aW1hZ2UtMw==\",\"revised_prompt\":\"draw a cat 3\",\"output_format\":\"png\",\"quality\":\"high\",\"size\":\"1024x1024\"}]}}\n\n" +
"data: [DONE]\n\n",
)),
},
@ -520,7 +520,7 @@ func TestOpenAIGatewayServiceForwardImages_OAuthUsesResponsesAPI(t *testing.T) {
require.NotNil(t, result)
require.Equal(t, "gpt-image-2", result.Model)
require.Equal(t, "gpt-image-2", result.UpstreamModel)
require.Equal(t, 1, result.ImageCount)
require.Equal(t, 3, result.ImageCount)
require.Equal(t, 11, result.Usage.InputTokens)
require.Equal(t, 22, result.Usage.OutputTokens)
require.Equal(t, 7, result.Usage.ImageOutputTokens)
@ -540,13 +540,17 @@ func TestOpenAIGatewayServiceForwardImages_OAuthUsesResponsesAPI(t *testing.T) {
require.Equal(t, "gpt-image-2", gjson.GetBytes(upstream.lastBody, "tools.0.model").String())
require.Equal(t, "1024x1024", gjson.GetBytes(upstream.lastBody, "tools.0.size").String())
require.Equal(t, "high", gjson.GetBytes(upstream.lastBody, "tools.0.quality").String())
require.False(t, gjson.GetBytes(upstream.lastBody, "tools.0.n").Exists())
require.Equal(t, int64(3), gjson.GetBytes(upstream.lastBody, "tools.0.n").Int())
require.Equal(t, "draw a cat", gjson.GetBytes(upstream.lastBody, "input.0.content.0.text").String())
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "gpt-image-2", gjson.Get(rec.Body.String(), "model").String())
require.Equal(t, "aGVsbG8=", gjson.Get(rec.Body.String(), "data.0.b64_json").String())
require.Equal(t, "draw a cat", gjson.Get(rec.Body.String(), "data.0.revised_prompt").String())
require.Len(t, gjson.Get(rec.Body.String(), "data").Array(), 3)
require.Equal(t, "aW1hZ2UtMQ==", gjson.Get(rec.Body.String(), "data.0.b64_json").String())
require.Equal(t, "aW1hZ2UtMg==", gjson.Get(rec.Body.String(), "data.1.b64_json").String())
require.Equal(t, "aW1hZ2UtMw==", gjson.Get(rec.Body.String(), "data.2.b64_json").String())
require.Equal(t, "draw a cat 1", gjson.Get(rec.Body.String(), "data.0.revised_prompt").String())
require.Equal(t, "draw a cat 3", gjson.Get(rec.Body.String(), "data.2.revised_prompt").String())
}
func TestOpenAIGatewayServiceForwardImages_APIKeyGenerationUsesConfiguredV1BaseURL(t *testing.T) {
@ -1112,7 +1116,7 @@ func TestOpenAIGatewayServiceForwardImages_OAuthEditsStreamingTransformsEvents(t
require.False(t, gjson.Get(completed.Data, "revised_prompt").Exists())
}
func TestBuildOpenAIImagesResponsesRequest_DowngradesMultipleImagesToSingle(t *testing.T) {
func TestBuildOpenAIImagesResponsesRequest_PassesThroughNForMultiImageModels(t *testing.T) {
parsed := &OpenAIImagesRequest{
Endpoint: openAIImagesGenerationsEndpoint,
Model: "gpt-image-2",
@ -1123,11 +1127,26 @@ func TestBuildOpenAIImagesResponsesRequest_DowngradesMultipleImagesToSingle(t *t
body, err := buildOpenAIImagesResponsesRequest(parsed, "gpt-image-2")
require.NoError(t, err)
require.NotNil(t, body)
require.False(t, gjson.GetBytes(body, "tools.0.n").Exists())
require.Equal(t, int64(2), gjson.GetBytes(body, "tools.0.n").Int())
require.Equal(t, "gpt-image-2", gjson.GetBytes(body, "tools.0.model").String())
require.Equal(t, "draw a cat", gjson.GetBytes(body, "input.0.content.0.text").String())
}
func TestBuildOpenAIImagesResponsesRequest_DoesNotPassNForDallE3(t *testing.T) {
parsed := &OpenAIImagesRequest{
Endpoint: openAIImagesGenerationsEndpoint,
Model: "dall-e-3",
Prompt: "draw a cat",
N: 2,
}
body, err := buildOpenAIImagesResponsesRequest(parsed, "dall-e-3")
require.NoError(t, err)
require.NotNil(t, body)
require.False(t, gjson.GetBytes(body, "tools.0.n").Exists())
require.Equal(t, "dall-e-3", gjson.GetBytes(body, "tools.0.model").String())
}
func TestBuildOpenAIImagesResponsesRequest_StripsInputFidelity(t *testing.T) {
parsed := &OpenAIImagesRequest{
Endpoint: openAIImagesEditsEndpoint,