Merge branch 'Wei-Shaw:main' into fix/openai-gateway-content-session-hash-fallback
This commit is contained in:
commit
77ba9e728d
@ -121,7 +121,7 @@ func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) {
|
|||||||
googleError(c, http.StatusBadGateway, err.Error())
|
googleError(c, http.StatusBadGateway, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if shouldFallbackGeminiModels(res) {
|
if shouldFallbackGeminiModel(modelName, res) {
|
||||||
c.JSON(http.StatusOK, gemini.FallbackModel(modelName))
|
c.JSON(http.StatusOK, gemini.FallbackModel(modelName))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -674,6 +674,16 @@ func shouldFallbackGeminiModels(res *service.UpstreamHTTPResult) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func shouldFallbackGeminiModel(modelName string, res *service.UpstreamHTTPResult) bool {
|
||||||
|
if shouldFallbackGeminiModels(res) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if res == nil || res.StatusCode != http.StatusNotFound {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return gemini.HasFallbackModel(modelName)
|
||||||
|
}
|
||||||
|
|
||||||
// extractGeminiCLISessionHash 从 Gemini CLI 请求中提取会话标识。
|
// extractGeminiCLISessionHash 从 Gemini CLI 请求中提取会话标识。
|
||||||
// 组合 x-gemini-api-privileged-user-id header 和请求体中的 tmp 目录哈希。
|
// 组合 x-gemini-api-privileged-user-id header 和请求体中的 tmp 目录哈希。
|
||||||
//
|
//
|
||||||
|
|||||||
@ -3,6 +3,7 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
@ -141,3 +142,28 @@ func TestGeminiV1BetaHandler_GetModelAntigravityFallback(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestShouldFallbackGeminiModel_KnownFallbackOn404(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
res := &service.UpstreamHTTPResult{StatusCode: http.StatusNotFound}
|
||||||
|
require.True(t, shouldFallbackGeminiModel("gemini-3.1-pro-preview-customtools", res))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShouldFallbackGeminiModel_UnknownModelOn404(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
res := &service.UpstreamHTTPResult{StatusCode: http.StatusNotFound}
|
||||||
|
require.False(t, shouldFallbackGeminiModel("gemini-future-model", res))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShouldFallbackGeminiModel_DelegatesScopeFallback(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
res := &service.UpstreamHTTPResult{
|
||||||
|
StatusCode: http.StatusForbidden,
|
||||||
|
Headers: http.Header{"Www-Authenticate": []string{"Bearer error=\"insufficient_scope\""}},
|
||||||
|
Body: []byte("insufficient authentication scopes"),
|
||||||
|
}
|
||||||
|
require.True(t, shouldFallbackGeminiModel("gemini-future-model", res))
|
||||||
|
}
|
||||||
|
|||||||
@ -2,6 +2,8 @@
|
|||||||
// It is used when upstream model listing is unavailable (e.g. OAuth token missing AI Studio scopes).
|
// It is used when upstream model listing is unavailable (e.g. OAuth token missing AI Studio scopes).
|
||||||
package gemini
|
package gemini
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
type Model struct {
|
type Model struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
DisplayName string `json:"displayName,omitempty"`
|
DisplayName string `json:"displayName,omitempty"`
|
||||||
@ -23,10 +25,27 @@ func DefaultModels() []Model {
|
|||||||
{Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods},
|
{Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods},
|
||||||
{Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods},
|
{Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods},
|
||||||
{Name: "models/gemini-3.1-pro-preview", SupportedGenerationMethods: methods},
|
{Name: "models/gemini-3.1-pro-preview", SupportedGenerationMethods: methods},
|
||||||
|
{Name: "models/gemini-3.1-pro-preview-customtools", SupportedGenerationMethods: methods},
|
||||||
{Name: "models/gemini-3.1-flash-image", SupportedGenerationMethods: methods},
|
{Name: "models/gemini-3.1-flash-image", SupportedGenerationMethods: methods},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func HasFallbackModel(model string) bool {
|
||||||
|
trimmed := strings.TrimSpace(model)
|
||||||
|
if trimmed == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(trimmed, "models/") {
|
||||||
|
trimmed = "models/" + trimmed
|
||||||
|
}
|
||||||
|
for _, model := range DefaultModels() {
|
||||||
|
if model.Name == trimmed {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func FallbackModelsList() ModelsListResponse {
|
func FallbackModelsList() ModelsListResponse {
|
||||||
return ModelsListResponse{Models: DefaultModels()}
|
return ModelsListResponse{Models: DefaultModels()}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -2,7 +2,7 @@ package gemini
|
|||||||
|
|
||||||
import "testing"
|
import "testing"
|
||||||
|
|
||||||
func TestDefaultModels_ContainsImageModels(t *testing.T) {
|
func TestDefaultModels_ContainsFallbackCatalogModels(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
models := DefaultModels()
|
models := DefaultModels()
|
||||||
@ -13,6 +13,7 @@ func TestDefaultModels_ContainsImageModels(t *testing.T) {
|
|||||||
|
|
||||||
required := []string{
|
required := []string{
|
||||||
"models/gemini-2.5-flash-image",
|
"models/gemini-2.5-flash-image",
|
||||||
|
"models/gemini-3.1-pro-preview-customtools",
|
||||||
"models/gemini-3.1-flash-image",
|
"models/gemini-3.1-flash-image",
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -26,3 +27,17 @@ func TestDefaultModels_ContainsImageModels(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHasFallbackModel_RecognizesCustomtoolsModel(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
if !HasFallbackModel("gemini-3.1-pro-preview-customtools") {
|
||||||
|
t.Fatalf("expected customtools model to exist in fallback catalog")
|
||||||
|
}
|
||||||
|
if !HasFallbackModel("models/gemini-3.1-pro-preview-customtools") {
|
||||||
|
t.Fatalf("expected prefixed customtools model to exist in fallback catalog")
|
||||||
|
}
|
||||||
|
if HasFallbackModel("gemini-unknown") {
|
||||||
|
t.Fatalf("did not expect unknown model to exist in fallback catalog")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@ -515,6 +515,45 @@ func ensureAntigravityDefaultPassthroughs(mapping map[string]string, models []st
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func normalizeRequestedModelForLookup(platform, requestedModel string) string {
|
||||||
|
trimmed := strings.TrimSpace(requestedModel)
|
||||||
|
if trimmed == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if platform != PlatformGemini && platform != PlatformAntigravity {
|
||||||
|
return trimmed
|
||||||
|
}
|
||||||
|
if trimmed == "gemini-3.1-pro-preview-customtools" {
|
||||||
|
return "gemini-3.1-pro-preview"
|
||||||
|
}
|
||||||
|
return trimmed
|
||||||
|
}
|
||||||
|
|
||||||
|
func mappingSupportsRequestedModel(mapping map[string]string, requestedModel string) bool {
|
||||||
|
if requestedModel == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if _, exists := mapping[requestedModel]; exists {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
for pattern := range mapping {
|
||||||
|
if matchWildcard(pattern, requestedModel) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveRequestedModelInMapping(mapping map[string]string, requestedModel string) (mappedModel string, matched bool) {
|
||||||
|
if requestedModel == "" {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
if mappedModel, exists := mapping[requestedModel]; exists {
|
||||||
|
return mappedModel, true
|
||||||
|
}
|
||||||
|
return matchWildcardMappingResult(mapping, requestedModel)
|
||||||
|
}
|
||||||
|
|
||||||
// IsModelSupported 检查模型是否在 model_mapping 中(支持通配符)
|
// IsModelSupported 检查模型是否在 model_mapping 中(支持通配符)
|
||||||
// 如果未配置 mapping,返回 true(允许所有模型)
|
// 如果未配置 mapping,返回 true(允许所有模型)
|
||||||
func (a *Account) IsModelSupported(requestedModel string) bool {
|
func (a *Account) IsModelSupported(requestedModel string) bool {
|
||||||
@ -522,17 +561,11 @@ func (a *Account) IsModelSupported(requestedModel string) bool {
|
|||||||
if len(mapping) == 0 {
|
if len(mapping) == 0 {
|
||||||
return true // 无映射 = 允许所有
|
return true // 无映射 = 允许所有
|
||||||
}
|
}
|
||||||
// 精确匹配
|
if mappingSupportsRequestedModel(mapping, requestedModel) {
|
||||||
if _, exists := mapping[requestedModel]; exists {
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
// 通配符匹配
|
normalized := normalizeRequestedModelForLookup(a.Platform, requestedModel)
|
||||||
for pattern := range mapping {
|
return normalized != requestedModel && mappingSupportsRequestedModel(mapping, normalized)
|
||||||
if matchWildcard(pattern, requestedModel) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetMappedModel 获取映射后的模型名(支持通配符,最长优先匹配)
|
// GetMappedModel 获取映射后的模型名(支持通配符,最长优先匹配)
|
||||||
@ -549,12 +582,16 @@ func (a *Account) ResolveMappedModel(requestedModel string) (mappedModel string,
|
|||||||
if len(mapping) == 0 {
|
if len(mapping) == 0 {
|
||||||
return requestedModel, false
|
return requestedModel, false
|
||||||
}
|
}
|
||||||
// 精确匹配优先
|
if mappedModel, matched := resolveRequestedModelInMapping(mapping, requestedModel); matched {
|
||||||
if mappedModel, exists := mapping[requestedModel]; exists {
|
|
||||||
return mappedModel, true
|
return mappedModel, true
|
||||||
}
|
}
|
||||||
// 通配符匹配(最长优先)
|
normalized := normalizeRequestedModelForLookup(a.Platform, requestedModel)
|
||||||
return matchWildcardMappingResult(mapping, requestedModel)
|
if normalized != requestedModel {
|
||||||
|
if mappedModel, matched := resolveRequestedModelInMapping(mapping, normalized); matched {
|
||||||
|
return mappedModel, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return requestedModel, false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Account) GetBaseURL() string {
|
func (a *Account) GetBaseURL() string {
|
||||||
|
|||||||
@ -133,6 +133,7 @@ func TestMatchWildcardMappingResult(t *testing.T) {
|
|||||||
func TestAccountIsModelSupported(t *testing.T) {
|
func TestAccountIsModelSupported(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
platform string
|
||||||
credentials map[string]any
|
credentials map[string]any
|
||||||
requestedModel string
|
requestedModel string
|
||||||
expected bool
|
expected bool
|
||||||
@ -184,6 +185,17 @@ func TestAccountIsModelSupported(t *testing.T) {
|
|||||||
requestedModel: "claude-opus-4-5-thinking",
|
requestedModel: "claude-opus-4-5-thinking",
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "gemini customtools alias matches normalized mapping",
|
||||||
|
platform: PlatformGemini,
|
||||||
|
credentials: map[string]any{
|
||||||
|
"model_mapping": map[string]any{
|
||||||
|
"gemini-3.1-pro-preview": "gemini-3.1-pro-preview",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
requestedModel: "gemini-3.1-pro-preview-customtools",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "wildcard match not supported",
|
name: "wildcard match not supported",
|
||||||
credentials: map[string]any{
|
credentials: map[string]any{
|
||||||
@ -199,6 +211,7 @@ func TestAccountIsModelSupported(t *testing.T) {
|
|||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
account := &Account{
|
account := &Account{
|
||||||
|
Platform: tt.platform,
|
||||||
Credentials: tt.credentials,
|
Credentials: tt.credentials,
|
||||||
}
|
}
|
||||||
result := account.IsModelSupported(tt.requestedModel)
|
result := account.IsModelSupported(tt.requestedModel)
|
||||||
@ -212,6 +225,7 @@ func TestAccountIsModelSupported(t *testing.T) {
|
|||||||
func TestAccountGetMappedModel(t *testing.T) {
|
func TestAccountGetMappedModel(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
platform string
|
||||||
credentials map[string]any
|
credentials map[string]any
|
||||||
requestedModel string
|
requestedModel string
|
||||||
expected string
|
expected string
|
||||||
@ -223,6 +237,13 @@ func TestAccountGetMappedModel(t *testing.T) {
|
|||||||
requestedModel: "claude-sonnet-4-5",
|
requestedModel: "claude-sonnet-4-5",
|
||||||
expected: "claude-sonnet-4-5",
|
expected: "claude-sonnet-4-5",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "no mapping preserves gemini customtools model",
|
||||||
|
platform: PlatformGemini,
|
||||||
|
credentials: nil,
|
||||||
|
requestedModel: "gemini-3.1-pro-preview-customtools",
|
||||||
|
expected: "gemini-3.1-pro-preview-customtools",
|
||||||
|
},
|
||||||
|
|
||||||
// 精确匹配
|
// 精确匹配
|
||||||
{
|
{
|
||||||
@ -250,6 +271,29 @@ func TestAccountGetMappedModel(t *testing.T) {
|
|||||||
},
|
},
|
||||||
|
|
||||||
// 无匹配返回原始模型
|
// 无匹配返回原始模型
|
||||||
|
{
|
||||||
|
name: "gemini customtools alias resolves through normalized mapping",
|
||||||
|
platform: PlatformGemini,
|
||||||
|
credentials: map[string]any{
|
||||||
|
"model_mapping": map[string]any{
|
||||||
|
"gemini-3.1-pro-preview": "gemini-3.1-pro-preview",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
requestedModel: "gemini-3.1-pro-preview-customtools",
|
||||||
|
expected: "gemini-3.1-pro-preview",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "gemini customtools exact mapping wins over normalized fallback",
|
||||||
|
platform: PlatformGemini,
|
||||||
|
credentials: map[string]any{
|
||||||
|
"model_mapping": map[string]any{
|
||||||
|
"gemini-3.1-pro-preview": "gemini-3.1-pro-preview",
|
||||||
|
"gemini-3.1-pro-preview-customtools": "gemini-3.1-pro-preview-customtools",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
requestedModel: "gemini-3.1-pro-preview-customtools",
|
||||||
|
expected: "gemini-3.1-pro-preview-customtools",
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "no match returns original",
|
name: "no match returns original",
|
||||||
credentials: map[string]any{
|
credentials: map[string]any{
|
||||||
@ -265,6 +309,7 @@ func TestAccountGetMappedModel(t *testing.T) {
|
|||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
account := &Account{
|
account := &Account{
|
||||||
|
Platform: tt.platform,
|
||||||
Credentials: tt.credentials,
|
Credentials: tt.credentials,
|
||||||
}
|
}
|
||||||
result := account.GetMappedModel(tt.requestedModel)
|
result := account.GetMappedModel(tt.requestedModel)
|
||||||
@ -278,6 +323,7 @@ func TestAccountGetMappedModel(t *testing.T) {
|
|||||||
func TestAccountResolveMappedModel(t *testing.T) {
|
func TestAccountResolveMappedModel(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
platform string
|
||||||
credentials map[string]any
|
credentials map[string]any
|
||||||
requestedModel string
|
requestedModel string
|
||||||
expectedModel string
|
expectedModel string
|
||||||
@ -312,6 +358,31 @@ func TestAccountResolveMappedModel(t *testing.T) {
|
|||||||
expectedModel: "gpt-5.4",
|
expectedModel: "gpt-5.4",
|
||||||
expectedMatch: true,
|
expectedMatch: true,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "gemini customtools alias reports normalized match",
|
||||||
|
platform: PlatformGemini,
|
||||||
|
credentials: map[string]any{
|
||||||
|
"model_mapping": map[string]any{
|
||||||
|
"gemini-3.1-pro-preview": "gemini-3.1-pro-preview",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
requestedModel: "gemini-3.1-pro-preview-customtools",
|
||||||
|
expectedModel: "gemini-3.1-pro-preview",
|
||||||
|
expectedMatch: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "gemini customtools exact mapping reports exact match",
|
||||||
|
platform: PlatformGemini,
|
||||||
|
credentials: map[string]any{
|
||||||
|
"model_mapping": map[string]any{
|
||||||
|
"gemini-3.1-pro-preview": "gemini-3.1-pro-preview",
|
||||||
|
"gemini-3.1-pro-preview-customtools": "gemini-3.1-pro-preview-customtools",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
requestedModel: "gemini-3.1-pro-preview-customtools",
|
||||||
|
expectedModel: "gemini-3.1-pro-preview-customtools",
|
||||||
|
expectedMatch: true,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "missing mapping reports unmatched",
|
name: "missing mapping reports unmatched",
|
||||||
credentials: map[string]any{
|
credentials: map[string]any{
|
||||||
@ -328,6 +399,7 @@ func TestAccountResolveMappedModel(t *testing.T) {
|
|||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
account := &Account{
|
account := &Account{
|
||||||
|
Platform: tt.platform,
|
||||||
Credentials: tt.credentials,
|
Credentials: tt.credentials,
|
||||||
}
|
}
|
||||||
mappedModel, matched := account.ResolveMappedModel(tt.requestedModel)
|
mappedModel, matched := account.ResolveMappedModel(tt.requestedModel)
|
||||||
|
|||||||
@ -268,6 +268,12 @@ func TestMapAntigravityModel_WildcardTargetEqualsRequest(t *testing.T) {
|
|||||||
requestedModel: "gemini-2.5-flash",
|
requestedModel: "gemini-2.5-flash",
|
||||||
expected: "gemini-2.5-flash",
|
expected: "gemini-2.5-flash",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "customtools alias falls back to normalized preview mapping",
|
||||||
|
modelMapping: map[string]any{"gemini-3.1-pro-preview": "gemini-3.1-pro-high"},
|
||||||
|
requestedModel: "gemini-3.1-pro-preview-customtools",
|
||||||
|
expected: "gemini-3.1-pro-high",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
|||||||
@ -85,7 +85,7 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
|
|||||||
if v, ok := reqBody["model"].(string); ok {
|
if v, ok := reqBody["model"].(string); ok {
|
||||||
model = v
|
model = v
|
||||||
}
|
}
|
||||||
normalizedModel := normalizeCodexModel(model)
|
normalizedModel := strings.TrimSpace(model)
|
||||||
if normalizedModel != "" {
|
if normalizedModel != "" {
|
||||||
if model != normalizedModel {
|
if model != normalizedModel {
|
||||||
reqBody["model"] = normalizedModel
|
reqBody["model"] = normalizedModel
|
||||||
|
|||||||
@ -246,6 +246,7 @@ func TestNormalizeCodexModel_Gpt53(t *testing.T) {
|
|||||||
"gpt-5.3-codex": "gpt-5.3-codex",
|
"gpt-5.3-codex": "gpt-5.3-codex",
|
||||||
"gpt-5.3-codex-xhigh": "gpt-5.3-codex",
|
"gpt-5.3-codex-xhigh": "gpt-5.3-codex",
|
||||||
"gpt-5.3-codex-spark": "gpt-5.3-codex",
|
"gpt-5.3-codex-spark": "gpt-5.3-codex",
|
||||||
|
"gpt 5.3 codex spark": "gpt-5.3-codex",
|
||||||
"gpt-5.3-codex-spark-high": "gpt-5.3-codex",
|
"gpt-5.3-codex-spark-high": "gpt-5.3-codex",
|
||||||
"gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex",
|
"gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex",
|
||||||
"gpt 5.3 codex": "gpt-5.3-codex",
|
"gpt 5.3 codex": "gpt-5.3-codex",
|
||||||
@ -256,6 +257,34 @@ func TestNormalizeCodexModel_Gpt53(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestApplyCodexOAuthTransform_PreservesBareSparkModel(t *testing.T) {
|
||||||
|
reqBody := map[string]any{
|
||||||
|
"model": "gpt-5.3-codex-spark",
|
||||||
|
"input": []any{},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := applyCodexOAuthTransform(reqBody, false, false)
|
||||||
|
|
||||||
|
require.Equal(t, "gpt-5.3-codex-spark", reqBody["model"])
|
||||||
|
require.Equal(t, "gpt-5.3-codex-spark", result.NormalizedModel)
|
||||||
|
store, ok := reqBody["store"].(bool)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.False(t, store)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyCodexOAuthTransform_TrimmedModelWithoutPolicyRewrite(t *testing.T) {
|
||||||
|
reqBody := map[string]any{
|
||||||
|
"model": " gpt-5.3-codex-spark ",
|
||||||
|
"input": []any{},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := applyCodexOAuthTransform(reqBody, false, false)
|
||||||
|
|
||||||
|
require.Equal(t, "gpt-5.3-codex-spark", reqBody["model"])
|
||||||
|
require.Equal(t, "gpt-5.3-codex-spark", result.NormalizedModel)
|
||||||
|
require.True(t, result.Modified)
|
||||||
|
}
|
||||||
|
|
||||||
func TestApplyCodexOAuthTransform_CodexCLI_PreservesExistingInstructions(t *testing.T) {
|
func TestApplyCodexOAuthTransform_CodexCLI_PreservesExistingInstructions(t *testing.T) {
|
||||||
// Codex CLI 场景:已有 instructions 时不修改
|
// Codex CLI 场景:已有 instructions 时不修改
|
||||||
|
|
||||||
|
|||||||
@ -10,8 +10,8 @@ import (
|
|||||||
const compatPromptCacheKeyPrefix = "compat_cc_"
|
const compatPromptCacheKeyPrefix = "compat_cc_"
|
||||||
|
|
||||||
func shouldAutoInjectPromptCacheKeyForCompat(model string) bool {
|
func shouldAutoInjectPromptCacheKeyForCompat(model string) bool {
|
||||||
switch normalizeCodexModel(strings.TrimSpace(model)) {
|
switch resolveOpenAIUpstreamModel(strings.TrimSpace(model)) {
|
||||||
case "gpt-5.4", "gpt-5.3-codex":
|
case "gpt-5.4", "gpt-5.3-codex", "gpt-5.3-codex-spark":
|
||||||
return true
|
return true
|
||||||
default:
|
default:
|
||||||
return false
|
return false
|
||||||
@ -23,9 +23,9 @@ func deriveCompatPromptCacheKey(req *apicompat.ChatCompletionsRequest, mappedMod
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
normalizedModel := normalizeCodexModel(strings.TrimSpace(mappedModel))
|
normalizedModel := resolveOpenAIUpstreamModel(strings.TrimSpace(mappedModel))
|
||||||
if normalizedModel == "" {
|
if normalizedModel == "" {
|
||||||
normalizedModel = normalizeCodexModel(strings.TrimSpace(req.Model))
|
normalizedModel = resolveOpenAIUpstreamModel(strings.TrimSpace(req.Model))
|
||||||
}
|
}
|
||||||
if normalizedModel == "" {
|
if normalizedModel == "" {
|
||||||
normalizedModel = strings.TrimSpace(req.Model)
|
normalizedModel = strings.TrimSpace(req.Model)
|
||||||
|
|||||||
@ -17,6 +17,7 @@ func TestShouldAutoInjectPromptCacheKeyForCompat(t *testing.T) {
|
|||||||
require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.4"))
|
require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.4"))
|
||||||
require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.3"))
|
require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.3"))
|
||||||
require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.3-codex"))
|
require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.3-codex"))
|
||||||
|
require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.3-codex-spark"))
|
||||||
require.False(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-4o"))
|
require.False(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-4o"))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -62,3 +63,17 @@ func TestDeriveCompatPromptCacheKey_DiffersAcrossSessions(t *testing.T) {
|
|||||||
k2 := deriveCompatPromptCacheKey(req2, "gpt-5.4")
|
k2 := deriveCompatPromptCacheKey(req2, "gpt-5.4")
|
||||||
require.NotEqual(t, k1, k2, "different first user messages should yield different keys")
|
require.NotEqual(t, k1, k2, "different first user messages should yield different keys")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDeriveCompatPromptCacheKey_UsesResolvedSparkFamily(t *testing.T) {
|
||||||
|
req := &apicompat.ChatCompletionsRequest{
|
||||||
|
Model: "gpt-5.3-codex-spark",
|
||||||
|
Messages: []apicompat.ChatMessage{
|
||||||
|
{Role: "user", Content: mustRawJSON(t, `"Question A"`)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
k1 := deriveCompatPromptCacheKey(req, "gpt-5.3-codex-spark")
|
||||||
|
k2 := deriveCompatPromptCacheKey(req, " openai/gpt-5.3-codex-spark ")
|
||||||
|
require.NotEmpty(t, k1)
|
||||||
|
require.Equal(t, k1, k2, "resolved spark family should derive a stable compat cache key")
|
||||||
|
}
|
||||||
|
|||||||
@ -45,12 +45,13 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
|
|||||||
|
|
||||||
// 2. Resolve model mapping early so compat prompt_cache_key injection can
|
// 2. Resolve model mapping early so compat prompt_cache_key injection can
|
||||||
// derive a stable seed from the final upstream model family.
|
// derive a stable seed from the final upstream model family.
|
||||||
mappedModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel)
|
billingModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel)
|
||||||
|
upstreamModel := resolveOpenAIUpstreamModel(billingModel)
|
||||||
|
|
||||||
promptCacheKey = strings.TrimSpace(promptCacheKey)
|
promptCacheKey = strings.TrimSpace(promptCacheKey)
|
||||||
compatPromptCacheInjected := false
|
compatPromptCacheInjected := false
|
||||||
if promptCacheKey == "" && account.Type == AccountTypeOAuth && shouldAutoInjectPromptCacheKeyForCompat(mappedModel) {
|
if promptCacheKey == "" && account.Type == AccountTypeOAuth && shouldAutoInjectPromptCacheKeyForCompat(upstreamModel) {
|
||||||
promptCacheKey = deriveCompatPromptCacheKey(&chatReq, mappedModel)
|
promptCacheKey = deriveCompatPromptCacheKey(&chatReq, upstreamModel)
|
||||||
compatPromptCacheInjected = promptCacheKey != ""
|
compatPromptCacheInjected = promptCacheKey != ""
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -60,12 +61,13 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("convert chat completions to responses: %w", err)
|
return nil, fmt.Errorf("convert chat completions to responses: %w", err)
|
||||||
}
|
}
|
||||||
responsesReq.Model = mappedModel
|
responsesReq.Model = upstreamModel
|
||||||
|
|
||||||
logFields := []zap.Field{
|
logFields := []zap.Field{
|
||||||
zap.Int64("account_id", account.ID),
|
zap.Int64("account_id", account.ID),
|
||||||
zap.String("original_model", originalModel),
|
zap.String("original_model", originalModel),
|
||||||
zap.String("mapped_model", mappedModel),
|
zap.String("billing_model", billingModel),
|
||||||
|
zap.String("upstream_model", upstreamModel),
|
||||||
zap.Bool("stream", clientStream),
|
zap.Bool("stream", clientStream),
|
||||||
}
|
}
|
||||||
if compatPromptCacheInjected {
|
if compatPromptCacheInjected {
|
||||||
@ -88,6 +90,9 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
|
|||||||
return nil, fmt.Errorf("unmarshal for codex transform: %w", err)
|
return nil, fmt.Errorf("unmarshal for codex transform: %w", err)
|
||||||
}
|
}
|
||||||
codexResult := applyCodexOAuthTransform(reqBody, false, false)
|
codexResult := applyCodexOAuthTransform(reqBody, false, false)
|
||||||
|
if codexResult.NormalizedModel != "" {
|
||||||
|
upstreamModel = codexResult.NormalizedModel
|
||||||
|
}
|
||||||
if codexResult.PromptCacheKey != "" {
|
if codexResult.PromptCacheKey != "" {
|
||||||
promptCacheKey = codexResult.PromptCacheKey
|
promptCacheKey = codexResult.PromptCacheKey
|
||||||
} else if promptCacheKey != "" {
|
} else if promptCacheKey != "" {
|
||||||
@ -180,9 +185,9 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
|
|||||||
var result *OpenAIForwardResult
|
var result *OpenAIForwardResult
|
||||||
var handleErr error
|
var handleErr error
|
||||||
if clientStream {
|
if clientStream {
|
||||||
result, handleErr = s.handleChatStreamingResponse(resp, c, originalModel, mappedModel, includeUsage, startTime)
|
result, handleErr = s.handleChatStreamingResponse(resp, c, originalModel, billingModel, upstreamModel, includeUsage, startTime)
|
||||||
} else {
|
} else {
|
||||||
result, handleErr = s.handleChatBufferedStreamingResponse(resp, c, originalModel, mappedModel, startTime)
|
result, handleErr = s.handleChatBufferedStreamingResponse(resp, c, originalModel, billingModel, upstreamModel, startTime)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Propagate ServiceTier and ReasoningEffort to result for billing
|
// Propagate ServiceTier and ReasoningEffort to result for billing
|
||||||
@ -224,7 +229,8 @@ func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse(
|
|||||||
resp *http.Response,
|
resp *http.Response,
|
||||||
c *gin.Context,
|
c *gin.Context,
|
||||||
originalModel string,
|
originalModel string,
|
||||||
mappedModel string,
|
billingModel string,
|
||||||
|
upstreamModel string,
|
||||||
startTime time.Time,
|
startTime time.Time,
|
||||||
) (*OpenAIForwardResult, error) {
|
) (*OpenAIForwardResult, error) {
|
||||||
requestID := resp.Header.Get("x-request-id")
|
requestID := resp.Header.Get("x-request-id")
|
||||||
@ -295,8 +301,8 @@ func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse(
|
|||||||
RequestID: requestID,
|
RequestID: requestID,
|
||||||
Usage: usage,
|
Usage: usage,
|
||||||
Model: originalModel,
|
Model: originalModel,
|
||||||
BillingModel: mappedModel,
|
BillingModel: billingModel,
|
||||||
UpstreamModel: mappedModel,
|
UpstreamModel: upstreamModel,
|
||||||
Stream: false,
|
Stream: false,
|
||||||
Duration: time.Since(startTime),
|
Duration: time.Since(startTime),
|
||||||
}, nil
|
}, nil
|
||||||
@ -308,7 +314,8 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
|
|||||||
resp *http.Response,
|
resp *http.Response,
|
||||||
c *gin.Context,
|
c *gin.Context,
|
||||||
originalModel string,
|
originalModel string,
|
||||||
mappedModel string,
|
billingModel string,
|
||||||
|
upstreamModel string,
|
||||||
includeUsage bool,
|
includeUsage bool,
|
||||||
startTime time.Time,
|
startTime time.Time,
|
||||||
) (*OpenAIForwardResult, error) {
|
) (*OpenAIForwardResult, error) {
|
||||||
@ -343,8 +350,8 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
|
|||||||
RequestID: requestID,
|
RequestID: requestID,
|
||||||
Usage: usage,
|
Usage: usage,
|
||||||
Model: originalModel,
|
Model: originalModel,
|
||||||
BillingModel: mappedModel,
|
BillingModel: billingModel,
|
||||||
UpstreamModel: mappedModel,
|
UpstreamModel: upstreamModel,
|
||||||
Stream: true,
|
Stream: true,
|
||||||
Duration: time.Since(startTime),
|
Duration: time.Since(startTime),
|
||||||
FirstTokenMs: firstTokenMs,
|
FirstTokenMs: firstTokenMs,
|
||||||
|
|||||||
@ -41,6 +41,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
|
|||||||
}
|
}
|
||||||
originalModel := anthropicReq.Model
|
originalModel := anthropicReq.Model
|
||||||
applyOpenAICompatModelNormalization(&anthropicReq)
|
applyOpenAICompatModelNormalization(&anthropicReq)
|
||||||
|
normalizedModel := anthropicReq.Model
|
||||||
clientStream := anthropicReq.Stream // client's original stream preference
|
clientStream := anthropicReq.Stream // client's original stream preference
|
||||||
|
|
||||||
// 2. Convert Anthropic → Responses
|
// 2. Convert Anthropic → Responses
|
||||||
@ -60,13 +61,16 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 3. Model mapping
|
// 3. Model mapping
|
||||||
mappedModel := resolveOpenAIForwardModel(account, anthropicReq.Model, defaultMappedModel)
|
billingModel := resolveOpenAIForwardModel(account, normalizedModel, defaultMappedModel)
|
||||||
responsesReq.Model = mappedModel
|
upstreamModel := resolveOpenAIUpstreamModel(billingModel)
|
||||||
|
responsesReq.Model = upstreamModel
|
||||||
|
|
||||||
logger.L().Debug("openai messages: model mapping applied",
|
logger.L().Debug("openai messages: model mapping applied",
|
||||||
zap.Int64("account_id", account.ID),
|
zap.Int64("account_id", account.ID),
|
||||||
zap.String("original_model", originalModel),
|
zap.String("original_model", originalModel),
|
||||||
zap.String("mapped_model", mappedModel),
|
zap.String("normalized_model", normalizedModel),
|
||||||
|
zap.String("billing_model", billingModel),
|
||||||
|
zap.String("upstream_model", upstreamModel),
|
||||||
zap.Bool("stream", isStream),
|
zap.Bool("stream", isStream),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -82,6 +86,9 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
|
|||||||
return nil, fmt.Errorf("unmarshal for codex transform: %w", err)
|
return nil, fmt.Errorf("unmarshal for codex transform: %w", err)
|
||||||
}
|
}
|
||||||
codexResult := applyCodexOAuthTransform(reqBody, false, false)
|
codexResult := applyCodexOAuthTransform(reqBody, false, false)
|
||||||
|
if codexResult.NormalizedModel != "" {
|
||||||
|
upstreamModel = codexResult.NormalizedModel
|
||||||
|
}
|
||||||
if codexResult.PromptCacheKey != "" {
|
if codexResult.PromptCacheKey != "" {
|
||||||
promptCacheKey = codexResult.PromptCacheKey
|
promptCacheKey = codexResult.PromptCacheKey
|
||||||
} else if promptCacheKey != "" {
|
} else if promptCacheKey != "" {
|
||||||
@ -182,10 +189,10 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
|
|||||||
var result *OpenAIForwardResult
|
var result *OpenAIForwardResult
|
||||||
var handleErr error
|
var handleErr error
|
||||||
if clientStream {
|
if clientStream {
|
||||||
result, handleErr = s.handleAnthropicStreamingResponse(resp, c, originalModel, mappedModel, startTime)
|
result, handleErr = s.handleAnthropicStreamingResponse(resp, c, originalModel, billingModel, upstreamModel, startTime)
|
||||||
} else {
|
} else {
|
||||||
// Client wants JSON: buffer the streaming response and assemble a JSON reply.
|
// Client wants JSON: buffer the streaming response and assemble a JSON reply.
|
||||||
result, handleErr = s.handleAnthropicBufferedStreamingResponse(resp, c, originalModel, mappedModel, startTime)
|
result, handleErr = s.handleAnthropicBufferedStreamingResponse(resp, c, originalModel, billingModel, upstreamModel, startTime)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Propagate ServiceTier and ReasoningEffort to result for billing
|
// Propagate ServiceTier and ReasoningEffort to result for billing
|
||||||
@ -230,7 +237,8 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse(
|
|||||||
resp *http.Response,
|
resp *http.Response,
|
||||||
c *gin.Context,
|
c *gin.Context,
|
||||||
originalModel string,
|
originalModel string,
|
||||||
mappedModel string,
|
billingModel string,
|
||||||
|
upstreamModel string,
|
||||||
startTime time.Time,
|
startTime time.Time,
|
||||||
) (*OpenAIForwardResult, error) {
|
) (*OpenAIForwardResult, error) {
|
||||||
requestID := resp.Header.Get("x-request-id")
|
requestID := resp.Header.Get("x-request-id")
|
||||||
@ -303,8 +311,8 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse(
|
|||||||
RequestID: requestID,
|
RequestID: requestID,
|
||||||
Usage: usage,
|
Usage: usage,
|
||||||
Model: originalModel,
|
Model: originalModel,
|
||||||
BillingModel: mappedModel,
|
BillingModel: billingModel,
|
||||||
UpstreamModel: mappedModel,
|
UpstreamModel: upstreamModel,
|
||||||
Stream: false,
|
Stream: false,
|
||||||
Duration: time.Since(startTime),
|
Duration: time.Since(startTime),
|
||||||
}, nil
|
}, nil
|
||||||
@ -319,7 +327,8 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
|||||||
resp *http.Response,
|
resp *http.Response,
|
||||||
c *gin.Context,
|
c *gin.Context,
|
||||||
originalModel string,
|
originalModel string,
|
||||||
mappedModel string,
|
billingModel string,
|
||||||
|
upstreamModel string,
|
||||||
startTime time.Time,
|
startTime time.Time,
|
||||||
) (*OpenAIForwardResult, error) {
|
) (*OpenAIForwardResult, error) {
|
||||||
requestID := resp.Header.Get("x-request-id")
|
requestID := resp.Header.Get("x-request-id")
|
||||||
@ -352,8 +361,8 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
|||||||
RequestID: requestID,
|
RequestID: requestID,
|
||||||
Usage: usage,
|
Usage: usage,
|
||||||
Model: originalModel,
|
Model: originalModel,
|
||||||
BillingModel: mappedModel,
|
BillingModel: billingModel,
|
||||||
UpstreamModel: mappedModel,
|
UpstreamModel: upstreamModel,
|
||||||
Stream: true,
|
Stream: true,
|
||||||
Duration: time.Since(startTime),
|
Duration: time.Since(startTime),
|
||||||
FirstTokenMs: firstTokenMs,
|
FirstTokenMs: firstTokenMs,
|
||||||
|
|||||||
@ -1818,29 +1818,29 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 对所有请求执行模型映射(包含 Codex CLI)。
|
// 对所有请求执行模型映射(包含 Codex CLI)。
|
||||||
mappedModel := account.GetMappedModel(reqModel)
|
billingModel := account.GetMappedModel(reqModel)
|
||||||
if mappedModel != reqModel {
|
if billingModel != reqModel {
|
||||||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Model mapping applied: %s -> %s (account: %s, isCodexCLI: %v)", reqModel, mappedModel, account.Name, isCodexCLI)
|
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Model mapping applied: %s -> %s (account: %s, isCodexCLI: %v)", reqModel, billingModel, account.Name, isCodexCLI)
|
||||||
reqBody["model"] = mappedModel
|
reqBody["model"] = billingModel
|
||||||
bodyModified = true
|
bodyModified = true
|
||||||
markPatchSet("model", mappedModel)
|
markPatchSet("model", billingModel)
|
||||||
}
|
}
|
||||||
|
upstreamModel := billingModel
|
||||||
|
|
||||||
// 针对所有 OpenAI 账号执行 Codex 模型名规范化,确保上游识别一致。
|
// 针对所有 OpenAI 账号执行 Codex 模型名规范化,确保上游识别一致。
|
||||||
if model, ok := reqBody["model"].(string); ok {
|
if model, ok := reqBody["model"].(string); ok {
|
||||||
normalizedModel := normalizeCodexModel(model)
|
upstreamModel = resolveOpenAIUpstreamModel(model)
|
||||||
if normalizedModel != "" && normalizedModel != model {
|
if upstreamModel != "" && upstreamModel != model {
|
||||||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Codex model normalization: %s -> %s (account: %s, type: %s, isCodexCLI: %v)",
|
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Upstream model resolved: %s -> %s (account: %s, type: %s, isCodexCLI: %v)",
|
||||||
model, normalizedModel, account.Name, account.Type, isCodexCLI)
|
model, upstreamModel, account.Name, account.Type, isCodexCLI)
|
||||||
reqBody["model"] = normalizedModel
|
reqBody["model"] = upstreamModel
|
||||||
mappedModel = normalizedModel
|
|
||||||
bodyModified = true
|
bodyModified = true
|
||||||
markPatchSet("model", normalizedModel)
|
markPatchSet("model", upstreamModel)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 移除 gpt-5.2-codex 以下的版本 verbosity 参数
|
// 移除 gpt-5.2-codex 以下的版本 verbosity 参数
|
||||||
// 确保高版本模型向低版本模型映射不报错
|
// 确保高版本模型向低版本模型映射不报错
|
||||||
if !SupportsVerbosity(normalizedModel) {
|
if !SupportsVerbosity(upstreamModel) {
|
||||||
if text, ok := reqBody["text"].(map[string]any); ok {
|
if text, ok := reqBody["text"].(map[string]any); ok {
|
||||||
delete(text, "verbosity")
|
delete(text, "verbosity")
|
||||||
}
|
}
|
||||||
@ -1864,7 +1864,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
|||||||
disablePatch()
|
disablePatch()
|
||||||
}
|
}
|
||||||
if codexResult.NormalizedModel != "" {
|
if codexResult.NormalizedModel != "" {
|
||||||
mappedModel = codexResult.NormalizedModel
|
upstreamModel = codexResult.NormalizedModel
|
||||||
}
|
}
|
||||||
if codexResult.PromptCacheKey != "" {
|
if codexResult.PromptCacheKey != "" {
|
||||||
promptCacheKey = codexResult.PromptCacheKey
|
promptCacheKey = codexResult.PromptCacheKey
|
||||||
@ -1981,7 +1981,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
|||||||
"forward_start account_id=%d account_type=%s model=%s stream=%v has_previous_response_id=%v",
|
"forward_start account_id=%d account_type=%s model=%s stream=%v has_previous_response_id=%v",
|
||||||
account.ID,
|
account.ID,
|
||||||
account.Type,
|
account.Type,
|
||||||
mappedModel,
|
upstreamModel,
|
||||||
reqStream,
|
reqStream,
|
||||||
hasPreviousResponseID,
|
hasPreviousResponseID,
|
||||||
)
|
)
|
||||||
@ -2070,7 +2070,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
|||||||
isCodexCLI,
|
isCodexCLI,
|
||||||
reqStream,
|
reqStream,
|
||||||
originalModel,
|
originalModel,
|
||||||
mappedModel,
|
upstreamModel,
|
||||||
startTime,
|
startTime,
|
||||||
attempt,
|
attempt,
|
||||||
wsLastFailureReason,
|
wsLastFailureReason,
|
||||||
@ -2171,7 +2171,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
|||||||
firstTokenMs,
|
firstTokenMs,
|
||||||
wsAttempts,
|
wsAttempts,
|
||||||
)
|
)
|
||||||
wsResult.UpstreamModel = mappedModel
|
wsResult.UpstreamModel = upstreamModel
|
||||||
return wsResult, nil
|
return wsResult, nil
|
||||||
}
|
}
|
||||||
s.writeOpenAIWSFallbackErrorResponse(c, account, wsErr)
|
s.writeOpenAIWSFallbackErrorResponse(c, account, wsErr)
|
||||||
@ -2276,14 +2276,14 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
|||||||
var usage *OpenAIUsage
|
var usage *OpenAIUsage
|
||||||
var firstTokenMs *int
|
var firstTokenMs *int
|
||||||
if reqStream {
|
if reqStream {
|
||||||
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, mappedModel)
|
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, upstreamModel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
usage = streamResult.usage
|
usage = streamResult.usage
|
||||||
firstTokenMs = streamResult.firstTokenMs
|
firstTokenMs = streamResult.firstTokenMs
|
||||||
} else {
|
} else {
|
||||||
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, mappedModel)
|
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, upstreamModel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -2307,7 +2307,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
|||||||
RequestID: resp.Header.Get("x-request-id"),
|
RequestID: resp.Header.Get("x-request-id"),
|
||||||
Usage: *usage,
|
Usage: *usage,
|
||||||
Model: originalModel,
|
Model: originalModel,
|
||||||
UpstreamModel: mappedModel,
|
UpstreamModel: upstreamModel,
|
||||||
ServiceTier: serviceTier,
|
ServiceTier: serviceTier,
|
||||||
ReasoningEffort: reasoningEffort,
|
ReasoningEffort: reasoningEffort,
|
||||||
Stream: reqStream,
|
Stream: reqStream,
|
||||||
|
|||||||
@ -1,8 +1,10 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
// resolveOpenAIForwardModel determines the upstream model for OpenAI-compatible
|
import "strings"
|
||||||
// forwarding. Group-level default mapping only applies when the account itself
|
|
||||||
// did not match any explicit model_mapping rule.
|
// resolveOpenAIForwardModel resolves the account/group mapping result for
|
||||||
|
// OpenAI-compatible forwarding. Group-level default mapping only applies when
|
||||||
|
// the account itself did not match any explicit model_mapping rule.
|
||||||
func resolveOpenAIForwardModel(account *Account, requestedModel, defaultMappedModel string) string {
|
func resolveOpenAIForwardModel(account *Account, requestedModel, defaultMappedModel string) string {
|
||||||
if account == nil {
|
if account == nil {
|
||||||
if defaultMappedModel != "" {
|
if defaultMappedModel != "" {
|
||||||
@ -17,3 +19,23 @@ func resolveOpenAIForwardModel(account *Account, requestedModel, defaultMappedMo
|
|||||||
}
|
}
|
||||||
return mappedModel
|
return mappedModel
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func resolveOpenAIUpstreamModel(model string) string {
|
||||||
|
if isBareGPT53CodexSparkModel(model) {
|
||||||
|
return "gpt-5.3-codex-spark"
|
||||||
|
}
|
||||||
|
return normalizeCodexModel(strings.TrimSpace(model))
|
||||||
|
}
|
||||||
|
|
||||||
|
func isBareGPT53CodexSparkModel(model string) bool {
|
||||||
|
modelID := strings.TrimSpace(model)
|
||||||
|
if modelID == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if strings.Contains(modelID, "/") {
|
||||||
|
parts := strings.Split(modelID, "/")
|
||||||
|
modelID = parts[len(parts)-1]
|
||||||
|
}
|
||||||
|
normalized := strings.ToLower(strings.TrimSpace(modelID))
|
||||||
|
return normalized == "gpt-5.3-codex-spark" || normalized == "gpt 5.3 codex spark"
|
||||||
|
}
|
||||||
|
|||||||
@ -74,13 +74,30 @@ func TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt51(t *
|
|||||||
Credentials: map[string]any{},
|
Credentials: map[string]any{},
|
||||||
}
|
}
|
||||||
|
|
||||||
withoutDefault := resolveOpenAIForwardModel(account, "claude-opus-4-6", "")
|
withoutDefault := resolveOpenAIUpstreamModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", ""))
|
||||||
if got := normalizeCodexModel(withoutDefault); got != "gpt-5.1" {
|
if withoutDefault != "gpt-5.1" {
|
||||||
t.Fatalf("normalizeCodexModel(%q) = %q, want %q", withoutDefault, got, "gpt-5.1")
|
t.Fatalf("resolveOpenAIUpstreamModel(...) = %q, want %q", withoutDefault, "gpt-5.1")
|
||||||
}
|
}
|
||||||
|
|
||||||
withDefault := resolveOpenAIForwardModel(account, "claude-opus-4-6", "gpt-5.4")
|
withDefault := resolveOpenAIUpstreamModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", "gpt-5.4"))
|
||||||
if got := normalizeCodexModel(withDefault); got != "gpt-5.4" {
|
if withDefault != "gpt-5.4" {
|
||||||
t.Fatalf("normalizeCodexModel(%q) = %q, want %q", withDefault, got, "gpt-5.4")
|
t.Fatalf("resolveOpenAIUpstreamModel(...) = %q, want %q", withDefault, "gpt-5.4")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveOpenAIUpstreamModel(t *testing.T) {
|
||||||
|
cases := map[string]string{
|
||||||
|
"gpt-5.3-codex-spark": "gpt-5.3-codex-spark",
|
||||||
|
"gpt 5.3 codex spark": "gpt-5.3-codex-spark",
|
||||||
|
" openai/gpt-5.3-codex-spark ": "gpt-5.3-codex-spark",
|
||||||
|
"gpt-5.3-codex-spark-high": "gpt-5.3-codex",
|
||||||
|
"gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex",
|
||||||
|
"gpt-5.3": "gpt-5.3-codex",
|
||||||
|
}
|
||||||
|
|
||||||
|
for input, expected := range cases {
|
||||||
|
if got := resolveOpenAIUpstreamModel(input); got != expected {
|
||||||
|
t.Fatalf("resolveOpenAIUpstreamModel(%q) = %q, want %q", input, got, expected)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -2515,12 +2515,9 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
|||||||
}
|
}
|
||||||
normalized = next
|
normalized = next
|
||||||
}
|
}
|
||||||
mappedModel := account.GetMappedModel(originalModel)
|
upstreamModel := resolveOpenAIUpstreamModel(account.GetMappedModel(originalModel))
|
||||||
if normalizedModel := normalizeCodexModel(mappedModel); normalizedModel != "" {
|
if upstreamModel != originalModel {
|
||||||
mappedModel = normalizedModel
|
next, setErr := applyPayloadMutation(normalized, "model", upstreamModel)
|
||||||
}
|
|
||||||
if mappedModel != originalModel {
|
|
||||||
next, setErr := applyPayloadMutation(normalized, "model", mappedModel)
|
|
||||||
if setErr != nil {
|
if setErr != nil {
|
||||||
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", setErr)
|
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", setErr)
|
||||||
}
|
}
|
||||||
@ -2776,10 +2773,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
|||||||
mappedModel := ""
|
mappedModel := ""
|
||||||
var mappedModelBytes []byte
|
var mappedModelBytes []byte
|
||||||
if originalModel != "" {
|
if originalModel != "" {
|
||||||
mappedModel = account.GetMappedModel(originalModel)
|
mappedModel = resolveOpenAIUpstreamModel(account.GetMappedModel(originalModel))
|
||||||
if normalizedModel := normalizeCodexModel(mappedModel); normalizedModel != "" {
|
|
||||||
mappedModel = normalizedModel
|
|
||||||
}
|
|
||||||
needModelReplace = mappedModel != "" && mappedModel != originalModel
|
needModelReplace = mappedModel != "" && mappedModel != originalModel
|
||||||
if needModelReplace {
|
if needModelReplace {
|
||||||
mappedModelBytes = []byte(mappedModel)
|
mappedModelBytes = []byte(mappedModel)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user