diff --git a/backend/internal/handler/openai_chat_completions.go b/backend/internal/handler/openai_chat_completions.go index 17f0d47e..9805bf8a 100644 --- a/backend/internal/handler/openai_chat_completions.go +++ b/backend/internal/handler/openai_chat_completions.go @@ -127,7 +127,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { for { reqLog.Debug("openai_chat_completions.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs))) - selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler( + selection, scheduleDecision, err := h.gatewayService.SelectAccountWithSchedulerForCapability( c.Request.Context(), apiKey.GroupID, "", @@ -135,6 +135,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { reqModel, failedAccountIDs, service.OpenAIUpstreamTransportAny, + service.OpenAIEndpointCapabilityChatCompletions, false, ) if err != nil { diff --git a/backend/internal/handler/openai_embeddings.go b/backend/internal/handler/openai_embeddings.go index bbb67044..81713f7f 100644 --- a/backend/internal/handler/openai_embeddings.go +++ b/backend/internal/handler/openai_embeddings.go @@ -107,7 +107,7 @@ func (h *OpenAIGatewayHandler) Embeddings(c *gin.Context) { routingStart := time.Now() for { - selection, _, err := h.gatewayService.SelectAccountWithScheduler( + selection, _, err := h.gatewayService.SelectAccountWithSchedulerForCapability( c.Request.Context(), apiKey.GroupID, "", @@ -115,6 +115,7 @@ func (h *OpenAIGatewayHandler) Embeddings(c *gin.Context) { reqModel, failedAccountIDs, service.OpenAIUpstreamTransportHTTPSSE, + service.OpenAIEndpointCapabilityEmbeddings, false, ) if err != nil { @@ -140,13 +141,6 @@ func (h *OpenAIGatewayHandler) Embeddings(c *gin.Context) { return } account := selection.Account - if account.Type != service.AccountTypeAPIKey { - if selection.ReleaseFunc != nil { - selection.ReleaseFunc() - } - failedAccountIDs[account.ID] = struct{}{} - continue - } setOpsSelectedAccount(c, account.ID, account.Platform) accountReleaseFunc, accountAcquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, "", selection, false, &streamStarted, reqLog) diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index a51eee86..1d661748 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -266,7 +266,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { for { // Select account supporting the requested model reqLog.Debug("openai.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs))) - selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler( + selection, scheduleDecision, err := h.gatewayService.SelectAccountWithSchedulerForCapability( c.Request.Context(), apiKey.GroupID, previousResponseID, @@ -274,6 +274,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { reqModel, failedAccountIDs, service.OpenAIUpstreamTransportAny, + service.OpenAIEndpointCapabilityChatCompletions, requireCompact, ) if err != nil { @@ -675,7 +676,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { currentRoutingModel = effectiveMappedModel } reqLog.Debug("openai_messages.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs))) - selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler( + selection, scheduleDecision, err := h.gatewayService.SelectAccountWithSchedulerForCapability( c.Request.Context(), apiKey.GroupID, "", // no previous_response_id @@ -683,6 +684,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { currentRoutingModel, failedAccountIDs, service.OpenAIUpstreamTransportAny, + service.OpenAIEndpointCapabilityChatCompletions, false, ) if err != nil { @@ -1273,7 +1275,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { for { reqLog.Debug("openai.websocket_account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs))) - selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler( + selection, scheduleDecision, err := h.gatewayService.SelectAccountWithSchedulerForCapability( ctx, apiKey.GroupID, previousResponseID, @@ -1281,6 +1283,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { reqModel, failedAccountIDs, service.OpenAIUpstreamTransportResponsesWebsocketV2, + service.OpenAIEndpointCapabilityChatCompletions, false, ) if err != nil { diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index d488aa75..e3ca9c5d 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -66,6 +66,15 @@ type Account struct { modelMappingCacheRawSig uint64 } +type OpenAIEndpointCapability string + +const ( + OpenAIEndpointCapabilityChatCompletions OpenAIEndpointCapability = "chat_completions" + OpenAIEndpointCapabilityEmbeddings OpenAIEndpointCapability = "embeddings" +) + +const openAIEndpointCapabilitiesCredentialKey = "openai_capabilities" + type TempUnschedulableRule struct { ErrorCode int `json:"error_code"` Keywords []string `json:"keywords"` @@ -1122,6 +1131,80 @@ func (a *Account) GetOpenAISessionID() string { return strings.TrimSpace(a.GetExtraString("openai_session_id")) } +func (a *Account) SupportsOpenAIEndpointCapability(capability OpenAIEndpointCapability) bool { + if a == nil { + return false + } + if capability == "" { + return true + } + if !a.IsOpenAI() { + return false + } + switch capability { + case OpenAIEndpointCapabilityChatCompletions: + case OpenAIEndpointCapabilityEmbeddings: + if a.Type != AccountTypeAPIKey { + return false + } + default: + return false + } + + configured, found := a.openAIEndpointCapabilitySet() + if !found { + return true + } + return configured[string(capability)] +} + +func (a *Account) openAIEndpointCapabilitySet() (map[string]bool, bool) { + if a == nil || a.Credentials == nil { + return nil, false + } + raw, found := a.Credentials[openAIEndpointCapabilitiesCredentialKey] + if !found || raw == nil { + return nil, false + } + + result := make(map[string]bool) + add := func(value string) { + value = strings.ToLower(strings.TrimSpace(value)) + if value == "" { + return + } + result[value] = true + } + + switch capabilities := raw.(type) { + case []any: + for _, item := range capabilities { + if value, ok := item.(string); ok { + add(value) + } + } + case []string: + for _, value := range capabilities { + add(value) + } + case map[string]any: + for key, value := range capabilities { + enabled, ok := value.(bool) + if ok && enabled { + add(key) + } + } + case map[string]bool: + for key, enabled := range capabilities { + if enabled { + add(key) + } + } + } + + return result, true +} + func (a *Account) SupportsOpenAIImageCapability(capability OpenAIImagesCapability) bool { if !a.IsOpenAI() { return false diff --git a/backend/internal/service/openai_account_scheduler.go b/backend/internal/service/openai_account_scheduler.go index a8ac391a..1eca08b1 100644 --- a/backend/internal/service/openai_account_scheduler.go +++ b/backend/internal/service/openai_account_scheduler.go @@ -44,6 +44,7 @@ type OpenAIAccountScheduleRequest struct { PreviousResponseID string RequestedModel string RequiredTransport OpenAIUpstreamTransport + RequiredCapability OpenAIEndpointCapability RequiredImageCapability OpenAIImagesCapability RequireCompact bool ExcludedIDs map[int64]struct{} @@ -263,12 +264,13 @@ func (s *defaultOpenAIAccountScheduler) Select( previousResponseID := strings.TrimSpace(req.PreviousResponseID) if previousResponseID != "" { - selection, err := s.service.SelectAccountByPreviousResponseID( + selection, err := s.service.selectAccountByPreviousResponseIDForCapability( ctx, req.GroupID, previousResponseID, req.RequestedModel, req.ExcludedIDs, + req.RequiredCapability, req.RequireCompact, ) if err != nil { @@ -363,7 +365,7 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash( _ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash) return nil, nil } - account = s.service.recheckSelectedOpenAIAccountFromDB(ctx, account, req.RequestedModel, req.RequireCompact) + account = s.service.recheckSelectedOpenAIAccountFromDB(ctx, account, req.RequestedModel, req.RequireCompact, req.RequiredCapability) if account == nil || !s.isAccountTransportCompatible(account, req.RequiredTransport) { _ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash) return nil, nil @@ -791,11 +793,11 @@ func (s *defaultOpenAIAccountScheduler) tryAcquireOpenAISelectionOrder( compactBlocked := false for i := 0; i < len(selectionOrder); i++ { candidate := selectionOrder[i] - fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false) + fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false, req.RequiredCapability) if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) { continue } - fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false) + fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false, req.RequiredCapability) if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) { continue } @@ -930,11 +932,11 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( cfg := s.service.schedulingConfig() // WaitPlan.MaxConcurrency 使用 Concurrency(非 EffectiveLoadFactor),因为 WaitPlan 控制的是 Redis 实际并发槽位等待。 for _, candidate := range selectionOrder { - fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false) + fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false, req.RequiredCapability) if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) { continue } - fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false) + fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false, req.RequiredCapability) if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) { continue } @@ -981,7 +983,7 @@ func (s *defaultOpenAIAccountScheduler) isAccountRequestCompatible(ctx context.C s.service.isUpstreamModelRestrictedByChannel(ctx, *req.GroupID, account, req.RequestedModel, req.RequireCompact) { return false } - return account.SupportsOpenAIImageCapability(req.RequiredImageCapability) + return accountSupportsOpenAICapabilities(account, req.RequiredCapability, req.RequiredImageCapability) } func (s *defaultOpenAIAccountScheduler) ReportResult(accountID int64, success bool, firstTokenMs *int) { @@ -1104,7 +1106,21 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler( requiredTransport OpenAIUpstreamTransport, requireCompact bool, ) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) { - return s.selectAccountWithScheduler(ctx, groupID, previousResponseID, sessionHash, requestedModel, excludedIDs, requiredTransport, "", requireCompact) + return s.selectAccountWithScheduler(ctx, groupID, previousResponseID, sessionHash, requestedModel, excludedIDs, requiredTransport, "", "", requireCompact) +} + +func (s *OpenAIGatewayService) SelectAccountWithSchedulerForCapability( + ctx context.Context, + groupID *int64, + previousResponseID string, + sessionHash string, + requestedModel string, + excludedIDs map[int64]struct{}, + requiredTransport OpenAIUpstreamTransport, + requiredCapability OpenAIEndpointCapability, + requireCompact bool, +) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) { + return s.selectAccountWithScheduler(ctx, groupID, previousResponseID, sessionHash, requestedModel, excludedIDs, requiredTransport, requiredCapability, "", requireCompact) } func (s *OpenAIGatewayService) SelectAccountWithSchedulerForImages( @@ -1115,13 +1131,13 @@ func (s *OpenAIGatewayService) SelectAccountWithSchedulerForImages( excludedIDs map[int64]struct{}, requiredCapability OpenAIImagesCapability, ) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) { - selection, decision, err := s.selectAccountWithScheduler(ctx, groupID, "", sessionHash, requestedModel, excludedIDs, OpenAIUpstreamTransportHTTPSSE, requiredCapability, false) + selection, decision, err := s.selectAccountWithScheduler(ctx, groupID, "", sessionHash, requestedModel, excludedIDs, OpenAIUpstreamTransportHTTPSSE, "", requiredCapability, false) if err == nil && selection != nil && selection.Account != nil { return selection, decision, nil } // 如果要求 native 能力(如指定了模型)但没有可用的 APIKey 账号,回退到 basic(OAuth 账号) if requiredCapability == OpenAIImagesCapabilityNative { - return s.selectAccountWithScheduler(ctx, groupID, "", sessionHash, requestedModel, excludedIDs, OpenAIUpstreamTransportHTTPSSE, OpenAIImagesCapabilityBasic, false) + return s.selectAccountWithScheduler(ctx, groupID, "", sessionHash, requestedModel, excludedIDs, OpenAIUpstreamTransportHTTPSSE, "", OpenAIImagesCapabilityBasic, false) } return selection, decision, err } @@ -1134,6 +1150,7 @@ func (s *OpenAIGatewayService) selectAccountWithScheduler( requestedModel string, excludedIDs map[int64]struct{}, requiredTransport OpenAIUpstreamTransport, + requiredCapability OpenAIEndpointCapability, requiredImageCapability OpenAIImagesCapability, requireCompact bool, ) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) { @@ -1144,14 +1161,14 @@ func (s *OpenAIGatewayService) selectAccountWithScheduler( if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE { effectiveExcludedIDs := cloneExcludedAccountIDs(excludedIDs) for { - selection, err := s.selectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, effectiveExcludedIDs, requireCompact) + selection, err := s.selectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, effectiveExcludedIDs, requireCompact, requiredCapability) if err != nil { return nil, decision, err } if selection == nil || selection.Account == nil { return selection, decision, nil } - if selection.Account.SupportsOpenAIImageCapability(requiredImageCapability) { + if accountSupportsOpenAICapabilities(selection.Account, requiredCapability, requiredImageCapability) { return selection, decision, nil } if selection.ReleaseFunc != nil { @@ -1169,14 +1186,15 @@ func (s *OpenAIGatewayService) selectAccountWithScheduler( effectiveExcludedIDs := cloneExcludedAccountIDs(excludedIDs) for { - selection, err := s.selectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, effectiveExcludedIDs, requireCompact) + selection, err := s.selectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, effectiveExcludedIDs, requireCompact, requiredCapability) if err != nil { return nil, decision, err } if selection == nil || selection.Account == nil { return selection, decision, nil } - if s.isOpenAIAccountTransportCompatible(selection.Account, requiredTransport) { + if s.isOpenAIAccountTransportCompatible(selection.Account, requiredTransport) && + accountSupportsOpenAICapabilities(selection.Account, requiredCapability, requiredImageCapability) { return selection, decision, nil } if selection.ReleaseFunc != nil { @@ -1213,12 +1231,21 @@ func (s *OpenAIGatewayService) selectAccountWithScheduler( PreviousResponseID: previousResponseID, RequestedModel: requestedModel, RequiredTransport: requiredTransport, + RequiredCapability: requiredCapability, RequiredImageCapability: requiredImageCapability, RequireCompact: requireCompact, ExcludedIDs: excludedIDs, }) } +func accountSupportsOpenAICapabilities(account *Account, requiredCapability OpenAIEndpointCapability, requiredImageCapability OpenAIImagesCapability) bool { + if account == nil { + return false + } + return account.SupportsOpenAIEndpointCapability(requiredCapability) && + account.SupportsOpenAIImageCapability(requiredImageCapability) +} + func cloneExcludedAccountIDs(excludedIDs map[int64]struct{}) map[int64]struct{} { if len(excludedIDs) == 0 { return nil diff --git a/backend/internal/service/openai_account_scheduler_test.go b/backend/internal/service/openai_account_scheduler_test.go index ba20ee5f..fedf7e9c 100644 --- a/backend/internal/service/openai_account_scheduler_test.go +++ b/backend/internal/service/openai_account_scheduler_test.go @@ -393,6 +393,64 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabled_Require require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer) } +func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabled_EmbeddingsSkipsChatOnlyAccount(t *testing.T) { + resetOpenAIAdvancedSchedulerSettingCacheForTest() + + ctx := context.Background() + groupID := int64(10110) + accounts := []Account{ + { + ID: 36031, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 0, + Credentials: map[string]any{ + "openai_capabilities": []any{"chat_completions"}, + }, + }, + { + ID: 36032, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 5, + Credentials: map[string]any{ + "openai_capabilities": []any{"chat_completions", "embeddings"}, + }, + }, + } + cfg := &config.Config{} + cfg.Gateway.Scheduling.LoadBatchEnabled = false + svc := &OpenAIGatewayService{ + accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts}, + cache: &schedulerTestGatewayCache{}, + cfg: cfg, + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), + } + + selection, decision, err := svc.SelectAccountWithSchedulerForCapability( + ctx, + &groupID, + "", + "", + "text-embedding-3-small", + nil, + OpenAIUpstreamTransportHTTPSSE, + OpenAIEndpointCapabilityEmbeddings, + false, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, int64(36032), selection.Account.ID) + require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer) +} + func TestOpenAIGatewayService_SelectAccountWithScheduler_EnabledUsesAdvancedPreviousResponseRouting(t *testing.T) { resetOpenAIAdvancedSchedulerSettingCacheForTest() @@ -458,6 +516,141 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_EnabledUsesAdvancedPrev require.True(t, decision.StickyPreviousHit) } +func TestOpenAIGatewayService_SelectAccountWithScheduler_Enabled_EmbeddingsSkipsChatOnlyAccount(t *testing.T) { + resetOpenAIAdvancedSchedulerSettingCacheForTest() + + ctx := context.Background() + groupID := int64(10111) + accounts := []Account{ + { + ID: 37011, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 0, + Credentials: map[string]any{ + "openai_capabilities": []any{"chat_completions"}, + }, + }, + { + ID: 37012, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 5, + Credentials: map[string]any{ + "openai_capabilities": []any{"chat_completions", "embeddings"}, + }, + }, + } + cfg := &config.Config{} + cfg.Gateway.Scheduling.LoadBatchEnabled = false + svc := &OpenAIGatewayService{ + accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts}, + cache: &schedulerTestGatewayCache{}, + cfg: cfg, + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), + } + + selection, decision, err := svc.SelectAccountWithSchedulerForCapability( + ctx, + &groupID, + "", + "", + "text-embedding-3-small", + nil, + OpenAIUpstreamTransportHTTPSSE, + OpenAIEndpointCapabilityEmbeddings, + false, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, int64(37012), selection.Account.ID) + require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer) + require.Equal(t, 1, decision.CandidateCount) +} + +func TestOpenAIGatewayService_SelectAccountWithScheduler_Enabled_EmbeddingsSkipsChatOnlyStickyBindings(t *testing.T) { + resetOpenAIAdvancedSchedulerSettingCacheForTest() + + ctx := context.Background() + groupID := int64(10112) + accounts := []Account{ + { + ID: 37021, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 0, + Credentials: map[string]any{ + "openai_capabilities": []any{"chat_completions"}, + }, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + }, + { + ID: 37022, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 5, + Credentials: map[string]any{ + "openai_capabilities": []any{"chat_completions", "embeddings"}, + }, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + }, + } + cfg := newSchedulerTestOpenAIWSV2Config() + cfg.Gateway.Scheduling.LoadBatchEnabled = false + cache := &schedulerTestGatewayCache{ + sessionBindings: map[string]int64{ + "openai:session_hash_embeddings": 37021, + }, + } + svc := &OpenAIGatewayService{ + accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts}, + cache: cache, + cfg: cfg, + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), + } + store := svc.getOpenAIWSStateStore() + require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_embeddings_chat_only", 37021, time.Hour)) + + selection, decision, err := svc.SelectAccountWithSchedulerForCapability( + ctx, + &groupID, + "resp_embeddings_chat_only", + "session_hash_embeddings", + "text-embedding-3-small", + nil, + OpenAIUpstreamTransportHTTPSSE, + OpenAIEndpointCapabilityEmbeddings, + false, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, int64(37022), selection.Account.ID) + require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer) + require.False(t, decision.StickyPreviousHit) + require.False(t, decision.StickySessionHit) + require.Equal(t, int64(37022), cache.sessionBindings["openai:session_hash_embeddings"]) +} + func TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics_DisabledNoOp(t *testing.T) { resetOpenAIAdvancedSchedulerSettingCacheForTest() diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index f93cc221..77587f69 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -1279,7 +1279,7 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI // SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts. // SelectAccountForModelWithExclusions 选择支持指定模型的账号,同时排除指定的账号。 func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) { - return s.selectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs, false, 0) + return s.selectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs, false, 0, "") } // noAvailableOpenAISelectionError builds the standard "no account available" error @@ -1312,13 +1312,16 @@ func openAICompactSupportTier(account *Account) int { // isOpenAIAccountEligibleForRequest centralises the schedulable / OpenAI / model / // compact-support checks used during account selection. -func isOpenAIAccountEligibleForRequest(ctx context.Context, account *Account, requestedModel string, requireCompact bool) bool { +func isOpenAIAccountEligibleForRequest(ctx context.Context, account *Account, requestedModel string, requireCompact bool, requiredCapability OpenAIEndpointCapability) bool { if account == nil || !account.IsOpenAI() || !account.IsSchedulableForModelWithContext(ctx, requestedModel) { return false } if requestedModel != "" && !account.IsModelSupported(requestedModel) { return false } + if !account.SupportsOpenAIEndpointCapability(requiredCapability) { + return false + } if requireCompact && openAICompactSupportTier(account) == 0 { return false } @@ -1366,7 +1369,7 @@ func resolveOpenAIAccountUpstreamModelForRequest(account *Account, requestedMode return upstreamModel } -func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, requireCompact bool, stickyAccountID int64) (*Account, error) { +func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, requireCompact bool, stickyAccountID int64, requiredCapability OpenAIEndpointCapability) (*Account, error) { if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) { slog.Warn("channel pricing restriction blocked request", "group_id", derefGroupID(groupID), @@ -1376,7 +1379,7 @@ func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.C // 1. 尝试粘性会话命中 // Try sticky session hit - if account := s.tryStickySessionHit(ctx, groupID, sessionHash, requestedModel, excludedIDs, requireCompact, stickyAccountID); account != nil { + if account := s.tryStickySessionHit(ctx, groupID, sessionHash, requestedModel, excludedIDs, requireCompact, stickyAccountID, requiredCapability); account != nil { return account, nil } @@ -1389,7 +1392,7 @@ func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.C // 3. 按优先级 + LRU 选择最佳账号 // Select by priority + LRU - selected, compactBlocked := s.selectBestAccount(ctx, groupID, accounts, requestedModel, excludedIDs, requireCompact) + selected, compactBlocked := s.selectBestAccount(ctx, groupID, accounts, requestedModel, excludedIDs, requireCompact, requiredCapability) if selected == nil { return nil, noAvailableOpenAISelectionError(requestedModel, compactBlocked) @@ -1414,7 +1417,7 @@ func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.C // // tryStickySessionHit attempts to get account from sticky session. // Returns account if hit and usable; clears session and returns nil if account is unavailable. -func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID *int64, sessionHash, requestedModel string, excludedIDs map[int64]struct{}, requireCompact bool, stickyAccountID int64) *Account { +func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID *int64, sessionHash, requestedModel string, excludedIDs map[int64]struct{}, requireCompact bool, stickyAccountID int64, requiredCapability OpenAIEndpointCapability) *Account { if sessionHash == "" { return nil } @@ -1446,14 +1449,14 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID // 验证账号是否可用于当前请求 // Verify account is usable for current request - if !isOpenAIAccountEligibleForRequest(ctx, account, requestedModel, false) { + if !isOpenAIAccountEligibleForRequest(ctx, account, requestedModel, false, requiredCapability) { return nil } if s.isOpenAIAccountRuntimeBlocked(account) { _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) return nil } - account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel, requireCompact) + account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel, requireCompact, requiredCapability) if account == nil { _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) return nil @@ -1477,7 +1480,7 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID // Returns nil if no available account. The second return reports whether at // least one candidate was filtered out solely because it lacks compact support // (only meaningful when requireCompact=true). -func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, groupID *int64, accounts []Account, requestedModel string, excludedIDs map[int64]struct{}, requireCompact bool) (*Account, bool) { +func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, groupID *int64, accounts []Account, requestedModel string, excludedIDs map[int64]struct{}, requireCompact bool, requiredCapability OpenAIEndpointCapability) (*Account, bool) { var selected *Account selectedCompactTier := -1 compactBlocked := false @@ -1492,11 +1495,11 @@ func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, groupID *i continue } - fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel, false) + fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel, false, requiredCapability) if fresh == nil { continue } - fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, false) + fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, false, requiredCapability) if fresh == nil { continue } @@ -1573,10 +1576,10 @@ func (s *OpenAIGatewayService) isBetterAccount(candidate, current *Account) bool // SelectAccountWithLoadAwareness selects an account with load-awareness and wait plan. func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) { - return s.selectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, excludedIDs, false) + return s.selectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, excludedIDs, false, "") } -func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, requireCompact bool) (*AccountSelectionResult, error) { +func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, requireCompact bool, requiredCapability OpenAIEndpointCapability) (*AccountSelectionResult, error) { if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) { slog.Warn("channel pricing restriction blocked request", "group_id", derefGroupID(groupID), @@ -1593,7 +1596,7 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex } } if s.concurrencyService == nil || !cfg.LoadBatchEnabled { - account, err := s.selectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs, requireCompact, stickyAccountID) + account, err := s.selectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs, requireCompact, stickyAccountID, requiredCapability) if err != nil { return nil, err } @@ -1646,8 +1649,8 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex if clearSticky { _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) } - if !clearSticky && isOpenAIAccountEligibleForRequest(ctx, account, requestedModel, false) { - account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel, requireCompact) + if !clearSticky && isOpenAIAccountEligibleForRequest(ctx, account, requestedModel, false, requiredCapability) { + account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel, requireCompact, requiredCapability) if account == nil { _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) } else if s.isOpenAIAccountRuntimeBlocked(account) { @@ -1691,15 +1694,12 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex // Scheduler snapshots can be temporarily stale (bucket rebuild is throttled); // re-check schedulability here so recently rate-limited/overloaded accounts // are not selected again before the bucket is rebuilt. - if !acc.IsSchedulable() { + if !isOpenAIAccountEligibleForRequest(ctx, acc, requestedModel, false, requiredCapability) { continue } if s.isOpenAIAccountRuntimeBlocked(acc) { continue } - if requestedModel != "" && !acc.IsModelSupported(requestedModel) { - continue - } if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, acc, requestedModel, requireCompact) { continue } @@ -1779,11 +1779,11 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex } for _, item := range selectionOrder { - fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, item.account, requestedModel, false) + fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, item.account, requestedModel, false, requiredCapability) if fresh == nil { continue } - fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, requireCompact) + fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, requireCompact, requiredCapability) if fresh == nil { continue } @@ -1813,11 +1813,11 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex ordered = prioritizeOpenAICompactAccounts(ordered) } for _, acc := range ordered { - fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel, false) + fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel, false, requiredCapability) if fresh == nil { continue } - fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, requireCompact) + fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, requireCompact, requiredCapability) if fresh == nil { continue } @@ -1858,11 +1858,11 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex candidates = prioritizeOpenAICompactAccounts(candidates) } for _, acc := range candidates { - fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel, false) + fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel, false, requiredCapability) if fresh == nil { continue } - fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, requireCompact) + fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, requireCompact, requiredCapability) if fresh == nil { continue } @@ -1910,7 +1910,7 @@ func (s *OpenAIGatewayService) tryAcquireAccountSlot(ctx context.Context, accoun return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency) } -func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context.Context, account *Account, requestedModel string, requireCompact bool) *Account { +func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context.Context, account *Account, requestedModel string, requireCompact bool, requiredCapability OpenAIEndpointCapability) *Account { if account == nil { return nil } @@ -1924,7 +1924,7 @@ func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context. fresh = current } - if !isOpenAIAccountEligibleForRequest(ctx, fresh, requestedModel, requireCompact) { + if !isOpenAIAccountEligibleForRequest(ctx, fresh, requestedModel, requireCompact, requiredCapability) { return nil } if s.isOpenAIAccountRuntimeBlocked(fresh) { @@ -1933,12 +1933,12 @@ func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context. return fresh } -func (s *OpenAIGatewayService) recheckSelectedOpenAIAccountFromDB(ctx context.Context, account *Account, requestedModel string, requireCompact bool) *Account { +func (s *OpenAIGatewayService) recheckSelectedOpenAIAccountFromDB(ctx context.Context, account *Account, requestedModel string, requireCompact bool, requiredCapability OpenAIEndpointCapability) *Account { if account == nil { return nil } if s.schedulerSnapshot == nil || s.accountRepo == nil { - if !isOpenAIAccountEligibleForRequest(ctx, account, requestedModel, requireCompact) { + if !isOpenAIAccountEligibleForRequest(ctx, account, requestedModel, requireCompact, requiredCapability) { return nil } return account @@ -1948,7 +1948,7 @@ func (s *OpenAIGatewayService) recheckSelectedOpenAIAccountFromDB(ctx context.Co if err != nil || latest == nil { return nil } - if !isOpenAIAccountEligibleForRequest(ctx, latest, requestedModel, requireCompact) { + if !isOpenAIAccountEligibleForRequest(ctx, latest, requestedModel, requireCompact, requiredCapability) { return nil } if s.isOpenAIAccountRuntimeBlocked(latest) { diff --git a/backend/internal/service/openai_images_test.go b/backend/internal/service/openai_images_test.go index 854e9f6d..a87e96c1 100644 --- a/backend/internal/service/openai_images_test.go +++ b/backend/internal/service/openai_images_test.go @@ -413,6 +413,79 @@ func TestAccountSupportsOpenAIImageCapability_OAuthSupportsNative(t *testing.T) require.True(t, account.SupportsOpenAIImageCapability(OpenAIImagesCapabilityNative)) } +func TestAccountSupportsOpenAIEndpointCapability(t *testing.T) { + t.Run("OpenAI APIKey 默认兼容 chat 和 embeddings", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + } + + require.True(t, account.SupportsOpenAIEndpointCapability(OpenAIEndpointCapabilityChatCompletions)) + require.True(t, account.SupportsOpenAIEndpointCapability(OpenAIEndpointCapabilityEmbeddings)) + }) + + t.Run("OpenAI OAuth 默认仅兼容 chat", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + } + + require.True(t, account.SupportsOpenAIEndpointCapability(OpenAIEndpointCapabilityChatCompletions)) + require.False(t, account.SupportsOpenAIEndpointCapability(OpenAIEndpointCapabilityEmbeddings)) + }) + + t.Run("显式列表支持同时声明 chat 和 embeddings", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Credentials: map[string]any{ + "openai_capabilities": []any{"chat_completions", "embeddings"}, + }, + } + + require.True(t, account.SupportsOpenAIEndpointCapability(OpenAIEndpointCapabilityChatCompletions)) + require.True(t, account.SupportsOpenAIEndpointCapability(OpenAIEndpointCapabilityEmbeddings)) + }) + + t.Run("显式列表只声明 chat 时不支持 embeddings", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Credentials: map[string]any{ + "openai_capabilities": []any{"chat_completions"}, + }, + } + + require.True(t, account.SupportsOpenAIEndpointCapability(OpenAIEndpointCapabilityChatCompletions)) + require.False(t, account.SupportsOpenAIEndpointCapability(OpenAIEndpointCapabilityEmbeddings)) + }) + + t.Run("显式 map 支持单独关闭 chat 并开启 embeddings", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Credentials: map[string]any{ + "openai_capabilities": map[string]any{ + "chat_completions": false, + "embeddings": true, + }, + }, + } + + require.False(t, account.SupportsOpenAIEndpointCapability(OpenAIEndpointCapabilityChatCompletions)) + require.True(t, account.SupportsOpenAIEndpointCapability(OpenAIEndpointCapabilityEmbeddings)) + }) + + t.Run("未知能力不应默认放行", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + } + + require.False(t, account.SupportsOpenAIEndpointCapability(OpenAIEndpointCapability("unknown"))) + }) +} + func TestBuildOpenAIImagesURL_HandlesVersionedBaseURL(t *testing.T) { require.Equal(t, "https://image-upstream.example/v1/images/generations", diff --git a/backend/internal/service/openai_ws_account_sticky_test.go b/backend/internal/service/openai_ws_account_sticky_test.go index 4005a921..c8b28a46 100644 --- a/backend/internal/service/openai_ws_account_sticky_test.go +++ b/backend/internal/service/openai_ws_account_sticky_test.go @@ -268,6 +268,52 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_BusyKeepsSticky( require.Equal(t, int64(21), selection.WaitPlan.AccountID) } +func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_CapabilityMismatchKeepsSticky(t *testing.T) { + ctx := context.Background() + groupID := int64(25) + account := Account{ + ID: 31, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "openai_capabilities": []any{"chat_completions"}, + }, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + } + cache := &stubGatewayCache{} + store := NewOpenAIWSStateStore(cache) + cfg := newOpenAIWSV2TestConfig() + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + openaiWSStateStore: store, + } + + require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_capability", account.ID, time.Hour)) + + selection, err := svc.selectAccountByPreviousResponseIDForCapability( + ctx, + &groupID, + "resp_prev_capability", + "text-embedding-3-small", + nil, + OpenAIEndpointCapabilityEmbeddings, + false, + ) + require.NoError(t, err) + require.Nil(t, selection) + boundAccountID, getErr := store.GetResponseAccount(ctx, groupID, "resp_prev_capability") + require.NoError(t, getErr) + require.Equal(t, account.ID, boundAccountID) +} + func newOpenAIWSV2TestConfig() *config.Config { cfg := &config.Config{} cfg.Gateway.OpenAIWS.Enabled = true diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go index b8e558ae..5fd5cffc 100644 --- a/backend/internal/service/openai_ws_forwarder.go +++ b/backend/internal/service/openai_ws_forwarder.go @@ -3987,6 +3987,18 @@ func (s *OpenAIGatewayService) SelectAccountByPreviousResponseID( requestedModel string, excludedIDs map[int64]struct{}, requireCompact bool, +) (*AccountSelectionResult, error) { + return s.selectAccountByPreviousResponseIDForCapability(ctx, groupID, previousResponseID, requestedModel, excludedIDs, "", requireCompact) +} + +func (s *OpenAIGatewayService) selectAccountByPreviousResponseIDForCapability( + ctx context.Context, + groupID *int64, + previousResponseID string, + requestedModel string, + excludedIDs map[int64]struct{}, + requiredCapability OpenAIEndpointCapability, + requireCompact bool, ) (*AccountSelectionResult, error) { if s == nil { return nil, nil @@ -4027,12 +4039,31 @@ func (s *OpenAIGatewayService) SelectAccountByPreviousResponseID( if requestedModel != "" && !account.IsModelSupported(requestedModel) { return nil, nil } - account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel, requireCompact) - if account == nil { - _ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID) + if !account.SupportsOpenAIEndpointCapability(requiredCapability) { return nil, nil } - // 兜底:若上游 compact 能力刚被探测为不支持,但 sticky 还在,需要主动放弃。 + if s.schedulerSnapshot != nil && s.accountRepo != nil { + latest, latestErr := s.accountRepo.GetByID(ctx, account.ID) + if latestErr != nil || latest == nil { + _ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID) + return nil, nil + } + if shouldClearStickySession(latest, requestedModel) || !latest.IsOpenAI() || !latest.IsSchedulable() { + _ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID) + return nil, nil + } + if requestedModel != "" && !latest.IsModelSupported(requestedModel) { + return nil, nil + } + if !latest.SupportsOpenAIEndpointCapability(requiredCapability) { + return nil, nil + } + if s.isOpenAIAccountRuntimeBlocked(latest) { + _ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID) + return nil, nil + } + account = latest + } if requireCompact && openAICompactSupportTier(account) == 0 { _ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID) return nil, nil diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue index 331295f7..665c4695 100644 --- a/frontend/src/components/account/CreateAccountModal.vue +++ b/frontend/src/components/account/CreateAccountModal.vue @@ -2679,7 +2679,7 @@
@@ -2696,6 +2696,26 @@ />
+
+ +
+ +
+

{{ t('admin.accounts.openai.endpointCapabilitiesDesc') }}

+
@@ -3172,7 +3192,8 @@ import type { CreateAccountRequest, CodexSessionImportMessage, OpenAICompactMode, - OpenAIResponsesMode + OpenAIResponsesMode, + OpenAIEndpointCapability } from '@/types' import BaseDialog from '@/components/common/BaseDialog.vue' import ConfirmDialog from '@/components/common/ConfirmDialog.vue' @@ -3350,6 +3371,7 @@ const autoPauseOnExpired = ref(true) const openaiPassthroughEnabled = ref(false) const openAICompactMode = ref('auto') const openAIResponsesMode = ref('auto') +const openAIEndpointCapabilities = ref(['chat_completions', 'embeddings']) const openaiOAuthResponsesWebSocketV2Mode = ref(OPENAI_WS_MODE_OFF) const openaiAPIKeyResponsesWebSocketV2Mode = ref(OPENAI_WS_MODE_OFF) const codexCLIOnlyEnabled = ref(false) @@ -3412,6 +3434,43 @@ const openAIResponsesModeOptions = computed(() => [ { value: 'force_responses', label: t('admin.accounts.openai.responsesModeForceResponses') }, { value: 'force_chat_completions', label: t('admin.accounts.openai.responsesModeForceChatCompletions') } ]) +const openAIEndpointCapabilityOptions = computed<{ value: OpenAIEndpointCapability; label: string }[]>(() => [ + { value: 'chat_completions', label: t('admin.accounts.openai.capabilityChatCompletions') }, + { value: 'embeddings', label: t('admin.accounts.openai.capabilityEmbeddings') } +]) + +const normalizeOpenAIEndpointCapabilities = (values: OpenAIEndpointCapability[]) => { + const allowed: OpenAIEndpointCapability[] = ['chat_completions', 'embeddings'] + const selected = allowed.filter((value) => values.includes(value)) + return selected.length > 0 ? selected : allowed +} + +const toggleOpenAIEndpointCapability = (capability: OpenAIEndpointCapability, event?: Event) => { + if (openAIEndpointCapabilities.value.includes(capability)) { + if (openAIEndpointCapabilities.value.length <= 1) { + const input = event?.target as HTMLInputElement | null + if (input) input.checked = true + return + } + openAIEndpointCapabilities.value = openAIEndpointCapabilities.value.filter( + (value) => value !== capability + ) + return + } + openAIEndpointCapabilities.value = normalizeOpenAIEndpointCapabilities([ + ...openAIEndpointCapabilities.value, + capability + ]) +} + +const applyOpenAIEndpointCapabilities = (credentials: Record) => { + const capabilities = normalizeOpenAIEndpointCapabilities(openAIEndpointCapabilities.value) + if (capabilities.length === 2) { + delete credentials.openai_capabilities + return + } + credentials.openai_capabilities = capabilities +} function buildAntigravityExtra(): Record | undefined { const extra: Record = {} @@ -3721,6 +3780,7 @@ watch( } if (newPlatform !== 'openai') { openaiPassthroughEnabled.value = false + openAIEndpointCapabilities.value = ['chat_completions', 'embeddings'] openaiOAuthResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF openaiAPIKeyResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF codexCLIOnlyEnabled.value = false @@ -4120,6 +4180,7 @@ const resetForm = () => { openaiPassthroughEnabled.value = false openAICompactMode.value = 'auto' openAIResponsesMode.value = 'auto' + openAIEndpointCapabilities.value = ['chat_completions', 'embeddings'] openaiOAuthResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF openaiAPIKeyResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF codexCLIOnlyEnabled.value = false @@ -4498,6 +4559,7 @@ const handleSubmit = async () => { } } if (form.platform === 'openai') { + applyOpenAIEndpointCapabilities(credentials) const compactModelMapping = buildOpenAICompactModelMapping() if (compactModelMapping) { credentials.compact_model_mapping = compactModelMapping @@ -4620,6 +4682,9 @@ const createAccountAndFinish = async ( } } if (platform === 'openai') { + if (type === 'apikey') { + applyOpenAIEndpointCapabilities(credentials) + } const compactModelMapping = buildOpenAICompactModelMapping() if (compactModelMapping) { credentials.compact_model_mapping = compactModelMapping diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue index f44b5d38..3cb10591 100644 --- a/frontend/src/components/account/EditAccountModal.vue +++ b/frontend/src/components/account/EditAccountModal.vue @@ -1439,7 +1439,7 @@
@@ -1459,6 +1459,26 @@
{{ t(openAIResponsesStatusKey) }}
+
+ +
+ +
+

{{ t('admin.accounts.openai.endpointCapabilitiesDesc') }}

+
@@ -2245,7 +2265,15 @@ import { useAppStore } from '@/stores/app' import { useAuthStore } from '@/stores/auth' import { adminAPI } from '@/api/admin' import { useQuotaNotifyState } from '@/composables/useQuotaNotifyState' -import type { Account, Proxy, AdminGroup, CheckMixedChannelResponse, OpenAICompactMode, OpenAIResponsesMode } from '@/types' +import type { + Account, + Proxy, + AdminGroup, + CheckMixedChannelResponse, + OpenAICompactMode, + OpenAIResponsesMode, + OpenAIEndpointCapability +} from '@/types' import BaseDialog from '@/components/common/BaseDialog.vue' import ConfirmDialog from '@/components/common/ConfirmDialog.vue' import Select from '@/components/common/Select.vue' @@ -2433,6 +2461,7 @@ const customBaseUrl = ref('') const openaiPassthroughEnabled = ref(false) const openAICompactMode = ref('auto') const openAIResponsesMode = ref('auto') +const openAIEndpointCapabilities = ref(['chat_completions', 'embeddings']) const openaiOAuthResponsesWebSocketV2Mode = ref(OPENAI_WS_MODE_OFF) const openaiAPIKeyResponsesWebSocketV2Mode = ref(OPENAI_WS_MODE_OFF) const codexCLIOnlyEnabled = ref(false) @@ -2539,6 +2568,63 @@ const openAIResponsesModeOptions = computed(() => [ { value: 'force_responses', label: t('admin.accounts.openai.responsesModeForceResponses') }, { value: 'force_chat_completions', label: t('admin.accounts.openai.responsesModeForceChatCompletions') } ]) +const openAIEndpointCapabilityOptions = computed<{ value: OpenAIEndpointCapability; label: string }[]>(() => [ + { value: 'chat_completions', label: t('admin.accounts.openai.capabilityChatCompletions') }, + { value: 'embeddings', label: t('admin.accounts.openai.capabilityEmbeddings') } +]) + +const normalizeOpenAIEndpointCapabilities = (values: OpenAIEndpointCapability[]) => { + const allowed: OpenAIEndpointCapability[] = ['chat_completions', 'embeddings'] + const selected = allowed.filter((value) => values.includes(value)) + return selected.length > 0 ? selected : allowed +} + +const readOpenAIEndpointCapabilities = (credentials?: Record): OpenAIEndpointCapability[] => { + const raw = credentials?.openai_capabilities + if (Array.isArray(raw)) { + return normalizeOpenAIEndpointCapabilities( + raw.filter((value): value is OpenAIEndpointCapability => + value === 'chat_completions' || value === 'embeddings' + ) + ) + } + if (raw !== null && typeof raw === 'object') { + const capabilityMap = raw as Record + return normalizeOpenAIEndpointCapabilities( + openAIEndpointCapabilityOptions.value + .map((option) => option.value) + .filter((value) => capabilityMap[value] === true) + ) + } + return ['chat_completions', 'embeddings'] +} + +const toggleOpenAIEndpointCapability = (capability: OpenAIEndpointCapability, event?: Event) => { + if (openAIEndpointCapabilities.value.includes(capability)) { + if (openAIEndpointCapabilities.value.length <= 1) { + const input = event?.target as HTMLInputElement | null + if (input) input.checked = true + return + } + openAIEndpointCapabilities.value = openAIEndpointCapabilities.value.filter( + (value) => value !== capability + ) + return + } + openAIEndpointCapabilities.value = normalizeOpenAIEndpointCapabilities([ + ...openAIEndpointCapabilities.value, + capability + ]) +} + +const applyOpenAIEndpointCapabilities = (credentials: Record) => { + const capabilities = normalizeOpenAIEndpointCapabilities(openAIEndpointCapabilities.value) + if (capabilities.length === 2) { + delete credentials.openai_capabilities + return + } + credentials.openai_capabilities = capabilities +} const normalizeOpenAIResponsesMode = (mode: unknown): OpenAIResponsesMode => { if (mode === 'force_responses' || mode === 'force_chat_completions') { return mode @@ -2724,6 +2810,7 @@ const syncFormFromAccount = (newAccount: Account | null) => { openaiPassthroughEnabled.value = false openAICompactMode.value = 'auto' openAIResponsesMode.value = 'auto' + openAIEndpointCapabilities.value = ['chat_completions', 'embeddings'] openAICompactModelMappings.value = [] openaiOAuthResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF openaiAPIKeyResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF @@ -2736,6 +2823,9 @@ const syncFormFromAccount = (newAccount: Account | null) => { openAICompactMode.value = (extra?.openai_compact_mode as OpenAICompactMode) || 'auto' if (newAccount.type === 'apikey') { openAIResponsesMode.value = normalizeOpenAIResponsesMode(extra?.openai_responses_mode) + openAIEndpointCapabilities.value = readOpenAIEndpointCapabilities( + newAccount.credentials as Record | undefined + ) } const codexImageGenerationBridgeValue = typeof extra?.codex_image_generation_bridge === 'boolean' ? extra.codex_image_generation_bridge @@ -3476,6 +3566,7 @@ const handleSubmit = async () => { newCredentials.model_mapping = currentCredentials.model_mapping } if (props.account.platform === 'openai') { + applyOpenAIEndpointCapabilities(newCredentials) const compactModelMapping = buildModelMappingObject('mapping', [], openAICompactModelMappings.value) if (compactModelMapping) { newCredentials.compact_model_mapping = compactModelMapping diff --git a/frontend/src/components/account/__tests__/EditAccountModal.spec.ts b/frontend/src/components/account/__tests__/EditAccountModal.spec.ts index 0b8e939c..db012a30 100644 --- a/frontend/src/components/account/__tests__/EditAccountModal.spec.ts +++ b/frontend/src/components/account/__tests__/EditAccountModal.spec.ts @@ -310,6 +310,63 @@ describe('EditAccountModal', () => { expect(updateAccountMock.mock.calls[0]?.[1]?.extra?.openai_responses_supported).toBe(true) }) + it('submits OpenAI APIKey endpoint capabilities from credentials', async () => { + const account = buildAccount() + account.credentials.openai_capabilities = ['chat_completions'] + updateAccountMock.mockReset() + checkMixedChannelRiskMock.mockReset() + checkMixedChannelRiskMock.mockResolvedValue({ has_risk: false }) + updateAccountMock.mockResolvedValue(account) + + const wrapper = mountModal(account) + + expect(wrapper.findAll('input[type="checkbox"]').some((input) => (input.element as HTMLInputElement).checked)).toBe(true) + + await wrapper.get('form#edit-account-form').trigger('submit.prevent') + + expect(updateAccountMock).toHaveBeenCalledTimes(1) + expect(updateAccountMock.mock.calls[0]?.[1]?.credentials?.openai_capabilities).toEqual([ + 'chat_completions' + ]) + }) + + it('keeps at least one OpenAI APIKey endpoint capability selected', async () => { + const account = buildAccount() + updateAccountMock.mockReset() + checkMixedChannelRiskMock.mockReset() + checkMixedChannelRiskMock.mockResolvedValue({ has_risk: false }) + updateAccountMock.mockResolvedValue(account) + + const wrapper = mountModal(account) + + const chatCheckbox = wrapper.get( + '[data-testid="openai-endpoint-capability-chat_completions"]' + ) + const embeddingsCheckbox = wrapper.get( + '[data-testid="openai-endpoint-capability-embeddings"]' + ) + + expect(chatCheckbox.element.checked).toBe(true) + expect(embeddingsCheckbox.element.checked).toBe(true) + + await embeddingsCheckbox.setValue(false) + + expect(chatCheckbox.element.checked).toBe(true) + expect(embeddingsCheckbox.element.checked).toBe(false) + + await chatCheckbox.setValue(false) + + expect(chatCheckbox.element.checked).toBe(true) + expect(embeddingsCheckbox.element.checked).toBe(false) + + await wrapper.get('form#edit-account-form').trigger('submit.prevent') + + expect(updateAccountMock).toHaveBeenCalledTimes(1) + expect(updateAccountMock.mock.calls[0]?.[1]?.credentials?.openai_capabilities).toEqual([ + 'chat_completions' + ]) + }) + it('submits account-level Codex image generation bridge override', async () => { const account = buildAccount() account.extra = { diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 41c3c495..ec659dd4 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -3353,6 +3353,11 @@ export default { responsesModeAuto: 'Auto', responsesModeForceResponses: 'Force Responses', responsesModeForceChatCompletions: 'Force Chat Completions', + endpointCapabilities: 'Endpoint capabilities', + endpointCapabilitiesDesc: + 'Used by account routing. Both endpoints are allowed by default; if the upstream only supports one, select only the supported endpoint.', + capabilityChatCompletions: 'Chat Completions', + capabilityEmbeddings: 'Embeddings', responsesStatusAutoSupported: 'Auto probe: Responses', responsesStatusAutoUnsupported: 'Auto probe: Chat Completions', responsesStatusAutoUnknown: 'Auto probe: unknown', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 8ff8ea80..36b0d8c3 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -3499,6 +3499,11 @@ export default { responsesModeAuto: '自动', responsesModeForceResponses: '强制 Responses', responsesModeForceChatCompletions: '强制 Chat Completions', + endpointCapabilities: '端点能力', + endpointCapabilitiesDesc: + '用于调度筛选。默认两个端点都可用;如果上游只支持其中一个,请只勾选实际支持的端点。', + capabilityChatCompletions: 'Chat Completions', + capabilityEmbeddings: 'Embeddings', responsesStatusAutoSupported: '自动探测:Responses', responsesStatusAutoUnsupported: '自动探测:Chat Completions', responsesStatusAutoUnknown: '自动探测:未探测', diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index eae5e455..c2136169 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -997,6 +997,7 @@ export interface CodexUsageSnapshot { export type OpenAICompactMode = 'auto' | 'force_on' | 'force_off' export type OpenAIResponsesMode = 'auto' | 'force_responses' | 'force_chat_completions' +export type OpenAIEndpointCapability = 'chat_completions' | 'embeddings' export interface OpenAICompactState { openai_compact_mode?: OpenAICompactMode