fix(openai-images): 修复图片生成 n 参数透传
This commit is contained in:
parent
91da815993
commit
2c14efeaa0
@ -262,6 +262,9 @@ func buildOpenAIImagesResponsesRequest(parsed *OpenAIImagesRequest, toolModel st
|
||||
tool := []byte(`{"type":"image_generation","action":"","model":""}`)
|
||||
tool, _ = sjson.SetBytes(tool, "action", action)
|
||||
tool, _ = sjson.SetBytes(tool, "model", strings.TrimSpace(toolModel))
|
||||
if shouldPassOpenAIImagesN(toolModel, parsed.N) {
|
||||
tool, _ = sjson.SetBytes(tool, "n", parsed.N)
|
||||
}
|
||||
|
||||
for _, field := range []struct {
|
||||
path string
|
||||
@ -302,6 +305,13 @@ func buildOpenAIImagesResponsesRequest(parsed *OpenAIImagesRequest, toolModel st
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func shouldPassOpenAIImagesN(model string, n int) bool {
|
||||
if n <= 1 {
|
||||
return false
|
||||
}
|
||||
return !strings.EqualFold(strings.TrimSpace(model), "dall-e-3")
|
||||
}
|
||||
|
||||
func extractOpenAIImagesFromResponsesCompleted(payload []byte) ([]openAIResponsesImageResult, int64, []byte, openAIResponsesImageResult, error) {
|
||||
if gjson.GetBytes(payload, "type").String() != "response.completed" {
|
||||
return nil, 0, nil, openAIResponsesImageResult{}, fmt.Errorf("unexpected event type")
|
||||
@ -957,16 +967,6 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth(
|
||||
account.Type,
|
||||
len(parsed.Uploads),
|
||||
)
|
||||
if parsed.N > 1 {
|
||||
logger.LegacyPrintf(
|
||||
"service.openai_gateway",
|
||||
"[Warning] Codex /responses image tool requested n=%d; falling back to n=1 request_model=%s endpoint=%s",
|
||||
parsed.N,
|
||||
requestModel,
|
||||
parsed.Endpoint,
|
||||
)
|
||||
}
|
||||
|
||||
upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
|
||||
defer releaseUpstreamCtx()
|
||||
|
||||
|
||||
@ -474,9 +474,9 @@ func findOpenAIImageTestSSEEvent(events []openAIImageTestSSEEvent, name string)
|
||||
return openAIImageTestSSEEvent{}, false
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceForwardImages_OAuthUsesResponsesAPI(t *testing.T) {
|
||||
func TestOpenAIGatewayServiceForwardImages_OAuthPassesNAndReturnsAllImages(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","size":"1024x1024","quality":"high","n":2}`)
|
||||
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","size":"1024x1024","quality":"high","n":3}`)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
@ -497,7 +497,7 @@ func TestOpenAIGatewayServiceForwardImages_OAuthUsesResponsesAPI(t *testing.T) {
|
||||
"X-Request-Id": []string{"req_img_123"},
|
||||
},
|
||||
Body: io.NopCloser(strings.NewReader(
|
||||
"data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000000,\"usage\":{\"input_tokens\":11,\"output_tokens\":22,\"input_tokens_details\":{\"cached_tokens\":3},\"output_tokens_details\":{\"image_tokens\":7}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[{\"type\":\"image_generation_call\",\"result\":\"aGVsbG8=\",\"revised_prompt\":\"draw a cat\",\"output_format\":\"png\",\"quality\":\"high\",\"size\":\"1024x1024\"}]}}\n\n" +
|
||||
"data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000000,\"usage\":{\"input_tokens\":11,\"output_tokens\":22,\"input_tokens_details\":{\"cached_tokens\":3},\"output_tokens_details\":{\"image_tokens\":7}},\"tool_usage\":{\"image_gen\":{\"images\":3}},\"output\":[{\"type\":\"image_generation_call\",\"result\":\"aW1hZ2UtMQ==\",\"revised_prompt\":\"draw a cat 1\",\"output_format\":\"png\",\"quality\":\"high\",\"size\":\"1024x1024\"},{\"type\":\"image_generation_call\",\"result\":\"aW1hZ2UtMg==\",\"revised_prompt\":\"draw a cat 2\",\"output_format\":\"png\",\"quality\":\"high\",\"size\":\"1024x1024\"},{\"type\":\"image_generation_call\",\"result\":\"aW1hZ2UtMw==\",\"revised_prompt\":\"draw a cat 3\",\"output_format\":\"png\",\"quality\":\"high\",\"size\":\"1024x1024\"}]}}\n\n" +
|
||||
"data: [DONE]\n\n",
|
||||
)),
|
||||
},
|
||||
@ -520,7 +520,7 @@ func TestOpenAIGatewayServiceForwardImages_OAuthUsesResponsesAPI(t *testing.T) {
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, "gpt-image-2", result.Model)
|
||||
require.Equal(t, "gpt-image-2", result.UpstreamModel)
|
||||
require.Equal(t, 1, result.ImageCount)
|
||||
require.Equal(t, 3, result.ImageCount)
|
||||
require.Equal(t, 11, result.Usage.InputTokens)
|
||||
require.Equal(t, 22, result.Usage.OutputTokens)
|
||||
require.Equal(t, 7, result.Usage.ImageOutputTokens)
|
||||
@ -540,13 +540,17 @@ func TestOpenAIGatewayServiceForwardImages_OAuthUsesResponsesAPI(t *testing.T) {
|
||||
require.Equal(t, "gpt-image-2", gjson.GetBytes(upstream.lastBody, "tools.0.model").String())
|
||||
require.Equal(t, "1024x1024", gjson.GetBytes(upstream.lastBody, "tools.0.size").String())
|
||||
require.Equal(t, "high", gjson.GetBytes(upstream.lastBody, "tools.0.quality").String())
|
||||
require.False(t, gjson.GetBytes(upstream.lastBody, "tools.0.n").Exists())
|
||||
require.Equal(t, int64(3), gjson.GetBytes(upstream.lastBody, "tools.0.n").Int())
|
||||
require.Equal(t, "draw a cat", gjson.GetBytes(upstream.lastBody, "input.0.content.0.text").String())
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.Equal(t, "gpt-image-2", gjson.Get(rec.Body.String(), "model").String())
|
||||
require.Equal(t, "aGVsbG8=", gjson.Get(rec.Body.String(), "data.0.b64_json").String())
|
||||
require.Equal(t, "draw a cat", gjson.Get(rec.Body.String(), "data.0.revised_prompt").String())
|
||||
require.Len(t, gjson.Get(rec.Body.String(), "data").Array(), 3)
|
||||
require.Equal(t, "aW1hZ2UtMQ==", gjson.Get(rec.Body.String(), "data.0.b64_json").String())
|
||||
require.Equal(t, "aW1hZ2UtMg==", gjson.Get(rec.Body.String(), "data.1.b64_json").String())
|
||||
require.Equal(t, "aW1hZ2UtMw==", gjson.Get(rec.Body.String(), "data.2.b64_json").String())
|
||||
require.Equal(t, "draw a cat 1", gjson.Get(rec.Body.String(), "data.0.revised_prompt").String())
|
||||
require.Equal(t, "draw a cat 3", gjson.Get(rec.Body.String(), "data.2.revised_prompt").String())
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceForwardImages_APIKeyGenerationUsesConfiguredV1BaseURL(t *testing.T) {
|
||||
@ -1112,7 +1116,7 @@ func TestOpenAIGatewayServiceForwardImages_OAuthEditsStreamingTransformsEvents(t
|
||||
require.False(t, gjson.Get(completed.Data, "revised_prompt").Exists())
|
||||
}
|
||||
|
||||
func TestBuildOpenAIImagesResponsesRequest_DowngradesMultipleImagesToSingle(t *testing.T) {
|
||||
func TestBuildOpenAIImagesResponsesRequest_PassesThroughNForMultiImageModels(t *testing.T) {
|
||||
parsed := &OpenAIImagesRequest{
|
||||
Endpoint: openAIImagesGenerationsEndpoint,
|
||||
Model: "gpt-image-2",
|
||||
@ -1123,11 +1127,26 @@ func TestBuildOpenAIImagesResponsesRequest_DowngradesMultipleImagesToSingle(t *t
|
||||
body, err := buildOpenAIImagesResponsesRequest(parsed, "gpt-image-2")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, body)
|
||||
require.False(t, gjson.GetBytes(body, "tools.0.n").Exists())
|
||||
require.Equal(t, int64(2), gjson.GetBytes(body, "tools.0.n").Int())
|
||||
require.Equal(t, "gpt-image-2", gjson.GetBytes(body, "tools.0.model").String())
|
||||
require.Equal(t, "draw a cat", gjson.GetBytes(body, "input.0.content.0.text").String())
|
||||
}
|
||||
|
||||
func TestBuildOpenAIImagesResponsesRequest_DoesNotPassNForDallE3(t *testing.T) {
|
||||
parsed := &OpenAIImagesRequest{
|
||||
Endpoint: openAIImagesGenerationsEndpoint,
|
||||
Model: "dall-e-3",
|
||||
Prompt: "draw a cat",
|
||||
N: 2,
|
||||
}
|
||||
|
||||
body, err := buildOpenAIImagesResponsesRequest(parsed, "dall-e-3")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, body)
|
||||
require.False(t, gjson.GetBytes(body, "tools.0.n").Exists())
|
||||
require.Equal(t, "dall-e-3", gjson.GetBytes(body, "tools.0.model").String())
|
||||
}
|
||||
|
||||
func TestBuildOpenAIImagesResponsesRequest_StripsInputFidelity(t *testing.T) {
|
||||
parsed := &OpenAIImagesRequest{
|
||||
Endpoint: openAIImagesEditsEndpoint,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user