fix(openai): handle versioned compatible base URLs

This commit is contained in:
wucm667 2026-05-13 11:25:15 +08:00
parent 18790386a7
commit 679c0865a0
7 changed files with 142 additions and 25 deletions

View File

@ -0,0 +1,78 @@
package service
import (
"net/url"
"strings"
)
func buildOpenAIEndpointURL(base string, endpoint string) string {
normalized := strings.TrimRight(strings.TrimSpace(base), "/")
endpoint = "/" + strings.TrimLeft(strings.TrimSpace(endpoint), "/")
relative := strings.TrimPrefix(endpoint, "/v1")
if strings.HasSuffix(normalized, endpoint) || strings.HasSuffix(normalized, relative) {
return normalized
}
if openAIBaseURLHasVersionSuffix(normalized) {
return normalized + relative
}
return normalized + endpoint
}
func openAIBaseURLHasVersionSuffix(raw string) bool {
trimmed := strings.TrimSpace(raw)
if trimmed == "" {
return false
}
pathValue := ""
if parsed, err := url.Parse(trimmed); err == nil && parsed.Scheme != "" && parsed.Host != "" {
pathValue = parsed.Path
} else if slash := strings.Index(trimmed, "/"); slash >= 0 {
pathValue = trimmed[slash:]
}
pathValue = strings.TrimRight(pathValue, "/")
if pathValue == "" {
return false
}
lastSlash := strings.LastIndex(pathValue, "/")
segment := pathValue
if lastSlash >= 0 {
segment = pathValue[lastSlash+1:]
}
return isOpenAIAPIVersionSegment(segment)
}
func isOpenAIAPIVersionSegment(segment string) bool {
s := strings.ToLower(strings.TrimSpace(segment))
if len(s) < 2 || s[0] != 'v' || !isASCIIDigit(s[1]) {
return false
}
i := 1
for i < len(s) && isASCIIDigit(s[i]) {
i++
}
if i == len(s) {
return true
}
if s[i] == '.' {
i++
if i == len(s) || !isASCIIDigit(s[i]) {
return false
}
for i < len(s) && isASCIIDigit(s[i]) {
i++
}
return i == len(s)
}
suffix := s[i:]
return strings.HasPrefix(suffix, "alpha") ||
strings.HasPrefix(suffix, "beta") ||
strings.HasPrefix(suffix, "preview")
}
func isASCIIDigit(b byte) bool {
return b >= '0' && b <= '9'
}

View File

@ -247,6 +247,16 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
if account.Type == AccountTypeAPIKey &&
openai_compat.ResolveResponsesSupport(account.Extra) == openai_compat.ResponsesSupportUnknown &&
!isResponsesEndpointSupportedByStatus(resp.StatusCode) {
logger.L().Info("openai chat_completions: /responses unsupported, falling back to raw chat completions",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", resp.StatusCode),
zap.String("upstream_message", upstreamMsg),
)
return s.forwardAsRawChatCompletions(ctx, c, account, body, defaultMappedModel)
}
if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) {
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {

View File

@ -422,16 +422,10 @@ func (s *OpenAIGatewayService) bufferRawChatCompletions(
//
// - base 已是 /chat/completions原样返回
// - base 以 /v1 结尾:追加 /chat/completions
// - base 以其他版本段结尾(如 /v4追加 /chat/completions
// - 其他情况:追加 /v1/chat/completions
//
// 与 buildOpenAIResponsesURL 是姐妹函数。
func buildOpenAIChatCompletionsURL(base string) string {
normalized := strings.TrimRight(strings.TrimSpace(base), "/")
if strings.HasSuffix(normalized, "/chat/completions") {
return normalized
}
if strings.HasSuffix(normalized, "/v1") {
return normalized + "/chat/completions"
}
return normalized + "/v1/chat/completions"
return buildOpenAIEndpointURL(base, "/v1/chat/completions")
}

View File

@ -36,6 +36,7 @@ func TestBuildOpenAIChatCompletionsURL(t *testing.T) {
// 第三方上游常见形式
{"third-party bare domain", "https://api.deepseek.com", "https://api.deepseek.com/v1/chat/completions"},
{"third-party with path prefix", "https://api.gptgod.online/api", "https://api.gptgod.online/api/v1/chat/completions"},
{"third-party versioned path", "https://open.bigmodel.cn/api/paas/v4", "https://open.bigmodel.cn/api/paas/v4/chat/completions"},
// 带空白字符
{"whitespace trimmed", " https://api.openai.com/v1 ", "https://api.openai.com/v1/chat/completions"},
}
@ -64,6 +65,7 @@ func TestBuildOpenAIResponsesURL_ProbeURL(t *testing.T) {
{"bare /v1", "https://api.openai.com/v1", "https://api.openai.com/v1/responses"},
{"already /responses", "https://api.openai.com/v1/responses", "https://api.openai.com/v1/responses"},
{"third-party bare domain", "https://api.deepseek.com", "https://api.deepseek.com/v1/responses"},
{"third-party versioned path", "https://open.bigmodel.cn/api/paas/v4", "https://open.bigmodel.cn/api/paas/v4/responses"},
{"only domain, no scheme", "api.gptgod.online", "api.gptgod.online/v1/responses"},
}
@ -193,6 +195,49 @@ func TestForwardAsRawChatCompletions_UpstreamRequestIgnoresClientCancel(t *testi
require.NoError(t, upstream.lastReq.Context().Err())
}
func TestForwardAsChatCompletions_UnknownResponsesSupportFallbackUsesVersionedChatURL(t *testing.T) {
gin.SetMode(gin.TestMode)
body := []byte(`{"model":"glm-4.5-air","messages":[{"role":"user","content":"hello"}],"stream":false}`)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
upstream := &httpUpstreamRecorder{responses: []*http.Response{
{
StatusCode: http.StatusNotFound,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(`{"error":{"message":"not found"}}`)),
},
{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid_raw_fallback"}},
Body: io.NopCloser(strings.NewReader(
`{"id":"chatcmpl_1","object":"chat.completion","model":"glm-4.5-air","choices":[{"index":0,"message":{"role":"assistant","content":"ok"},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":2,"total_tokens":3}}`,
)),
},
}}
svc := &OpenAIGatewayService{
cfg: rawChatCompletionsTestConfig(),
httpUpstream: upstream,
}
account := rawChatCompletionsTestAccount()
account.Credentials["base_url"] = "https://open.bigmodel.cn/api/paas/v4"
result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "")
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, 1, result.Usage.InputTokens)
require.Equal(t, 2, result.Usage.OutputTokens)
require.Len(t, upstream.requests, 2)
require.Equal(t, "https://open.bigmodel.cn/api/paas/v4/responses", upstream.requests[0].URL.String())
require.Equal(t, "https://open.bigmodel.cn/api/paas/v4/chat/completions", upstream.requests[1].URL.String())
require.Equal(t, http.StatusOK, rec.Code)
require.Contains(t, rec.Body.String(), `"content":"ok"`)
}
func TestIsOpenAIChatUsageOnlyStreamChunk(t *testing.T) {
t.Parallel()

View File

@ -4955,17 +4955,11 @@ func (s *OpenAIGatewayService) validateUpstreamBaseURL(raw string) (string, erro
// buildOpenAIResponsesURL 组装 OpenAI Responses 端点。
// - base 以 /v1 结尾:追加 /responses
// - base 以其他版本段结尾(如 /v4追加 /responses
// - base 已是 /responses原样返回
// - 其他情况:追加 /v1/responses
func buildOpenAIResponsesURL(base string) string {
normalized := strings.TrimRight(strings.TrimSpace(base), "/")
if strings.HasSuffix(normalized, "/responses") {
return normalized
}
if strings.HasSuffix(normalized, "/v1") {
return normalized + "/responses"
}
return normalized + "/v1/responses"
return buildOpenAIEndpointURL(base, "/v1/responses")
}
func trimOpenAIEncryptedReasoningItems(reqBody map[string]any) bool {

View File

@ -795,15 +795,7 @@ func (s *OpenAIGatewayService) buildOpenAIImagesRequest(
}
func buildOpenAIImagesURL(base string, endpoint string) string {
normalized := strings.TrimRight(strings.TrimSpace(base), "/")
relative := strings.TrimPrefix(strings.TrimSpace(endpoint), "/v1")
if strings.HasSuffix(normalized, endpoint) || strings.HasSuffix(normalized, relative) {
return normalized
}
if strings.HasSuffix(normalized, "/v1") {
return normalized + relative
}
return normalized + endpoint
return buildOpenAIEndpointURL(base, endpoint)
}
func rewriteOpenAIImagesModel(body []byte, contentType string, model string) ([]byte, string, error) {

View File

@ -418,6 +418,10 @@ func TestBuildOpenAIImagesURL_HandlesVersionedBaseURL(t *testing.T) {
"https://image-upstream.example/v1/images/generations",
buildOpenAIImagesURL("https://image-upstream.example/v1", openAIImagesGenerationsEndpoint),
)
require.Equal(t,
"https://open.bigmodel.cn/api/paas/v4/images/generations",
buildOpenAIImagesURL("https://open.bigmodel.cn/api/paas/v4", openAIImagesGenerationsEndpoint),
)
require.Equal(t,
"https://image-upstream.example/v1/images/edits",
buildOpenAIImagesURL("https://image-upstream.example/v1/", openAIImagesEditsEndpoint),