diff --git a/backend/internal/handler/openai_chat_completions.go b/backend/internal/handler/openai_chat_completions.go index 17f0d47e..9805bf8a 100644 --- a/backend/internal/handler/openai_chat_completions.go +++ b/backend/internal/handler/openai_chat_completions.go @@ -127,7 +127,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { for { reqLog.Debug("openai_chat_completions.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs))) - selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler( + selection, scheduleDecision, err := h.gatewayService.SelectAccountWithSchedulerForCapability( c.Request.Context(), apiKey.GroupID, "", @@ -135,6 +135,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { reqModel, failedAccountIDs, service.OpenAIUpstreamTransportAny, + service.OpenAIEndpointCapabilityChatCompletions, false, ) if err != nil { diff --git a/backend/internal/handler/openai_embeddings.go b/backend/internal/handler/openai_embeddings.go index bbb67044..81713f7f 100644 --- a/backend/internal/handler/openai_embeddings.go +++ b/backend/internal/handler/openai_embeddings.go @@ -107,7 +107,7 @@ func (h *OpenAIGatewayHandler) Embeddings(c *gin.Context) { routingStart := time.Now() for { - selection, _, err := h.gatewayService.SelectAccountWithScheduler( + selection, _, err := h.gatewayService.SelectAccountWithSchedulerForCapability( c.Request.Context(), apiKey.GroupID, "", @@ -115,6 +115,7 @@ func (h *OpenAIGatewayHandler) Embeddings(c *gin.Context) { reqModel, failedAccountIDs, service.OpenAIUpstreamTransportHTTPSSE, + service.OpenAIEndpointCapabilityEmbeddings, false, ) if err != nil { @@ -140,13 +141,6 @@ func (h *OpenAIGatewayHandler) Embeddings(c *gin.Context) { return } account := selection.Account - if account.Type != service.AccountTypeAPIKey { - if selection.ReleaseFunc != nil { - selection.ReleaseFunc() - } - failedAccountIDs[account.ID] = struct{}{} - continue - } setOpsSelectedAccount(c, account.ID, account.Platform) accountReleaseFunc, accountAcquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, "", selection, false, &streamStarted, reqLog) diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index a51eee86..1d661748 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -266,7 +266,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { for { // Select account supporting the requested model reqLog.Debug("openai.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs))) - selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler( + selection, scheduleDecision, err := h.gatewayService.SelectAccountWithSchedulerForCapability( c.Request.Context(), apiKey.GroupID, previousResponseID, @@ -274,6 +274,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { reqModel, failedAccountIDs, service.OpenAIUpstreamTransportAny, + service.OpenAIEndpointCapabilityChatCompletions, requireCompact, ) if err != nil { @@ -675,7 +676,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { currentRoutingModel = effectiveMappedModel } reqLog.Debug("openai_messages.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs))) - selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler( + selection, scheduleDecision, err := h.gatewayService.SelectAccountWithSchedulerForCapability( c.Request.Context(), apiKey.GroupID, "", // no previous_response_id @@ -683,6 +684,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { currentRoutingModel, failedAccountIDs, service.OpenAIUpstreamTransportAny, + service.OpenAIEndpointCapabilityChatCompletions, false, ) if err != nil { @@ -1273,7 +1275,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { for { reqLog.Debug("openai.websocket_account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs))) - selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler( + selection, scheduleDecision, err := h.gatewayService.SelectAccountWithSchedulerForCapability( ctx, apiKey.GroupID, previousResponseID, @@ -1281,6 +1283,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { reqModel, failedAccountIDs, service.OpenAIUpstreamTransportResponsesWebsocketV2, + service.OpenAIEndpointCapabilityChatCompletions, false, ) if err != nil { diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index d488aa75..e3ca9c5d 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -66,6 +66,15 @@ type Account struct { modelMappingCacheRawSig uint64 } +type OpenAIEndpointCapability string + +const ( + OpenAIEndpointCapabilityChatCompletions OpenAIEndpointCapability = "chat_completions" + OpenAIEndpointCapabilityEmbeddings OpenAIEndpointCapability = "embeddings" +) + +const openAIEndpointCapabilitiesCredentialKey = "openai_capabilities" + type TempUnschedulableRule struct { ErrorCode int `json:"error_code"` Keywords []string `json:"keywords"` @@ -1122,6 +1131,80 @@ func (a *Account) GetOpenAISessionID() string { return strings.TrimSpace(a.GetExtraString("openai_session_id")) } +func (a *Account) SupportsOpenAIEndpointCapability(capability OpenAIEndpointCapability) bool { + if a == nil { + return false + } + if capability == "" { + return true + } + if !a.IsOpenAI() { + return false + } + switch capability { + case OpenAIEndpointCapabilityChatCompletions: + case OpenAIEndpointCapabilityEmbeddings: + if a.Type != AccountTypeAPIKey { + return false + } + default: + return false + } + + configured, found := a.openAIEndpointCapabilitySet() + if !found { + return true + } + return configured[string(capability)] +} + +func (a *Account) openAIEndpointCapabilitySet() (map[string]bool, bool) { + if a == nil || a.Credentials == nil { + return nil, false + } + raw, found := a.Credentials[openAIEndpointCapabilitiesCredentialKey] + if !found || raw == nil { + return nil, false + } + + result := make(map[string]bool) + add := func(value string) { + value = strings.ToLower(strings.TrimSpace(value)) + if value == "" { + return + } + result[value] = true + } + + switch capabilities := raw.(type) { + case []any: + for _, item := range capabilities { + if value, ok := item.(string); ok { + add(value) + } + } + case []string: + for _, value := range capabilities { + add(value) + } + case map[string]any: + for key, value := range capabilities { + enabled, ok := value.(bool) + if ok && enabled { + add(key) + } + } + case map[string]bool: + for key, enabled := range capabilities { + if enabled { + add(key) + } + } + } + + return result, true +} + func (a *Account) SupportsOpenAIImageCapability(capability OpenAIImagesCapability) bool { if !a.IsOpenAI() { return false diff --git a/backend/internal/service/openai_account_scheduler.go b/backend/internal/service/openai_account_scheduler.go index a8ac391a..1eca08b1 100644 --- a/backend/internal/service/openai_account_scheduler.go +++ b/backend/internal/service/openai_account_scheduler.go @@ -44,6 +44,7 @@ type OpenAIAccountScheduleRequest struct { PreviousResponseID string RequestedModel string RequiredTransport OpenAIUpstreamTransport + RequiredCapability OpenAIEndpointCapability RequiredImageCapability OpenAIImagesCapability RequireCompact bool ExcludedIDs map[int64]struct{} @@ -263,12 +264,13 @@ func (s *defaultOpenAIAccountScheduler) Select( previousResponseID := strings.TrimSpace(req.PreviousResponseID) if previousResponseID != "" { - selection, err := s.service.SelectAccountByPreviousResponseID( + selection, err := s.service.selectAccountByPreviousResponseIDForCapability( ctx, req.GroupID, previousResponseID, req.RequestedModel, req.ExcludedIDs, + req.RequiredCapability, req.RequireCompact, ) if err != nil { @@ -363,7 +365,7 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash( _ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash) return nil, nil } - account = s.service.recheckSelectedOpenAIAccountFromDB(ctx, account, req.RequestedModel, req.RequireCompact) + account = s.service.recheckSelectedOpenAIAccountFromDB(ctx, account, req.RequestedModel, req.RequireCompact, req.RequiredCapability) if account == nil || !s.isAccountTransportCompatible(account, req.RequiredTransport) { _ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash) return nil, nil @@ -791,11 +793,11 @@ func (s *defaultOpenAIAccountScheduler) tryAcquireOpenAISelectionOrder( compactBlocked := false for i := 0; i < len(selectionOrder); i++ { candidate := selectionOrder[i] - fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false) + fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false, req.RequiredCapability) if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) { continue } - fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false) + fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false, req.RequiredCapability) if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) { continue } @@ -930,11 +932,11 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( cfg := s.service.schedulingConfig() // WaitPlan.MaxConcurrency 使用 Concurrency(非 EffectiveLoadFactor),因为 WaitPlan 控制的是 Redis 实际并发槽位等待。 for _, candidate := range selectionOrder { - fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false) + fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false, req.RequiredCapability) if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) { continue } - fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false) + fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false, req.RequiredCapability) if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) { continue } @@ -981,7 +983,7 @@ func (s *defaultOpenAIAccountScheduler) isAccountRequestCompatible(ctx context.C s.service.isUpstreamModelRestrictedByChannel(ctx, *req.GroupID, account, req.RequestedModel, req.RequireCompact) { return false } - return account.SupportsOpenAIImageCapability(req.RequiredImageCapability) + return accountSupportsOpenAICapabilities(account, req.RequiredCapability, req.RequiredImageCapability) } func (s *defaultOpenAIAccountScheduler) ReportResult(accountID int64, success bool, firstTokenMs *int) { @@ -1104,7 +1106,21 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler( requiredTransport OpenAIUpstreamTransport, requireCompact bool, ) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) { - return s.selectAccountWithScheduler(ctx, groupID, previousResponseID, sessionHash, requestedModel, excludedIDs, requiredTransport, "", requireCompact) + return s.selectAccountWithScheduler(ctx, groupID, previousResponseID, sessionHash, requestedModel, excludedIDs, requiredTransport, "", "", requireCompact) +} + +func (s *OpenAIGatewayService) SelectAccountWithSchedulerForCapability( + ctx context.Context, + groupID *int64, + previousResponseID string, + sessionHash string, + requestedModel string, + excludedIDs map[int64]struct{}, + requiredTransport OpenAIUpstreamTransport, + requiredCapability OpenAIEndpointCapability, + requireCompact bool, +) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) { + return s.selectAccountWithScheduler(ctx, groupID, previousResponseID, sessionHash, requestedModel, excludedIDs, requiredTransport, requiredCapability, "", requireCompact) } func (s *OpenAIGatewayService) SelectAccountWithSchedulerForImages( @@ -1115,13 +1131,13 @@ func (s *OpenAIGatewayService) SelectAccountWithSchedulerForImages( excludedIDs map[int64]struct{}, requiredCapability OpenAIImagesCapability, ) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) { - selection, decision, err := s.selectAccountWithScheduler(ctx, groupID, "", sessionHash, requestedModel, excludedIDs, OpenAIUpstreamTransportHTTPSSE, requiredCapability, false) + selection, decision, err := s.selectAccountWithScheduler(ctx, groupID, "", sessionHash, requestedModel, excludedIDs, OpenAIUpstreamTransportHTTPSSE, "", requiredCapability, false) if err == nil && selection != nil && selection.Account != nil { return selection, decision, nil } // 如果要求 native 能力(如指定了模型)但没有可用的 APIKey 账号,回退到 basic(OAuth 账号) if requiredCapability == OpenAIImagesCapabilityNative { - return s.selectAccountWithScheduler(ctx, groupID, "", sessionHash, requestedModel, excludedIDs, OpenAIUpstreamTransportHTTPSSE, OpenAIImagesCapabilityBasic, false) + return s.selectAccountWithScheduler(ctx, groupID, "", sessionHash, requestedModel, excludedIDs, OpenAIUpstreamTransportHTTPSSE, "", OpenAIImagesCapabilityBasic, false) } return selection, decision, err } @@ -1134,6 +1150,7 @@ func (s *OpenAIGatewayService) selectAccountWithScheduler( requestedModel string, excludedIDs map[int64]struct{}, requiredTransport OpenAIUpstreamTransport, + requiredCapability OpenAIEndpointCapability, requiredImageCapability OpenAIImagesCapability, requireCompact bool, ) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) { @@ -1144,14 +1161,14 @@ func (s *OpenAIGatewayService) selectAccountWithScheduler( if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE { effectiveExcludedIDs := cloneExcludedAccountIDs(excludedIDs) for { - selection, err := s.selectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, effectiveExcludedIDs, requireCompact) + selection, err := s.selectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, effectiveExcludedIDs, requireCompact, requiredCapability) if err != nil { return nil, decision, err } if selection == nil || selection.Account == nil { return selection, decision, nil } - if selection.Account.SupportsOpenAIImageCapability(requiredImageCapability) { + if accountSupportsOpenAICapabilities(selection.Account, requiredCapability, requiredImageCapability) { return selection, decision, nil } if selection.ReleaseFunc != nil { @@ -1169,14 +1186,15 @@ func (s *OpenAIGatewayService) selectAccountWithScheduler( effectiveExcludedIDs := cloneExcludedAccountIDs(excludedIDs) for { - selection, err := s.selectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, effectiveExcludedIDs, requireCompact) + selection, err := s.selectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, effectiveExcludedIDs, requireCompact, requiredCapability) if err != nil { return nil, decision, err } if selection == nil || selection.Account == nil { return selection, decision, nil } - if s.isOpenAIAccountTransportCompatible(selection.Account, requiredTransport) { + if s.isOpenAIAccountTransportCompatible(selection.Account, requiredTransport) && + accountSupportsOpenAICapabilities(selection.Account, requiredCapability, requiredImageCapability) { return selection, decision, nil } if selection.ReleaseFunc != nil { @@ -1213,12 +1231,21 @@ func (s *OpenAIGatewayService) selectAccountWithScheduler( PreviousResponseID: previousResponseID, RequestedModel: requestedModel, RequiredTransport: requiredTransport, + RequiredCapability: requiredCapability, RequiredImageCapability: requiredImageCapability, RequireCompact: requireCompact, ExcludedIDs: excludedIDs, }) } +func accountSupportsOpenAICapabilities(account *Account, requiredCapability OpenAIEndpointCapability, requiredImageCapability OpenAIImagesCapability) bool { + if account == nil { + return false + } + return account.SupportsOpenAIEndpointCapability(requiredCapability) && + account.SupportsOpenAIImageCapability(requiredImageCapability) +} + func cloneExcludedAccountIDs(excludedIDs map[int64]struct{}) map[int64]struct{} { if len(excludedIDs) == 0 { return nil diff --git a/backend/internal/service/openai_account_scheduler_test.go b/backend/internal/service/openai_account_scheduler_test.go index ba20ee5f..fedf7e9c 100644 --- a/backend/internal/service/openai_account_scheduler_test.go +++ b/backend/internal/service/openai_account_scheduler_test.go @@ -393,6 +393,64 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabled_Require require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer) } +func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabled_EmbeddingsSkipsChatOnlyAccount(t *testing.T) { + resetOpenAIAdvancedSchedulerSettingCacheForTest() + + ctx := context.Background() + groupID := int64(10110) + accounts := []Account{ + { + ID: 36031, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 0, + Credentials: map[string]any{ + "openai_capabilities": []any{"chat_completions"}, + }, + }, + { + ID: 36032, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 5, + Credentials: map[string]any{ + "openai_capabilities": []any{"chat_completions", "embeddings"}, + }, + }, + } + cfg := &config.Config{} + cfg.Gateway.Scheduling.LoadBatchEnabled = false + svc := &OpenAIGatewayService{ + accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts}, + cache: &schedulerTestGatewayCache{}, + cfg: cfg, + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), + } + + selection, decision, err := svc.SelectAccountWithSchedulerForCapability( + ctx, + &groupID, + "", + "", + "text-embedding-3-small", + nil, + OpenAIUpstreamTransportHTTPSSE, + OpenAIEndpointCapabilityEmbeddings, + false, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, int64(36032), selection.Account.ID) + require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer) +} + func TestOpenAIGatewayService_SelectAccountWithScheduler_EnabledUsesAdvancedPreviousResponseRouting(t *testing.T) { resetOpenAIAdvancedSchedulerSettingCacheForTest() @@ -458,6 +516,141 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_EnabledUsesAdvancedPrev require.True(t, decision.StickyPreviousHit) } +func TestOpenAIGatewayService_SelectAccountWithScheduler_Enabled_EmbeddingsSkipsChatOnlyAccount(t *testing.T) { + resetOpenAIAdvancedSchedulerSettingCacheForTest() + + ctx := context.Background() + groupID := int64(10111) + accounts := []Account{ + { + ID: 37011, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 0, + Credentials: map[string]any{ + "openai_capabilities": []any{"chat_completions"}, + }, + }, + { + ID: 37012, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 5, + Credentials: map[string]any{ + "openai_capabilities": []any{"chat_completions", "embeddings"}, + }, + }, + } + cfg := &config.Config{} + cfg.Gateway.Scheduling.LoadBatchEnabled = false + svc := &OpenAIGatewayService{ + accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts}, + cache: &schedulerTestGatewayCache{}, + cfg: cfg, + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), + } + + selection, decision, err := svc.SelectAccountWithSchedulerForCapability( + ctx, + &groupID, + "", + "", + "text-embedding-3-small", + nil, + OpenAIUpstreamTransportHTTPSSE, + OpenAIEndpointCapabilityEmbeddings, + false, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, int64(37012), selection.Account.ID) + require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer) + require.Equal(t, 1, decision.CandidateCount) +} + +func TestOpenAIGatewayService_SelectAccountWithScheduler_Enabled_EmbeddingsSkipsChatOnlyStickyBindings(t *testing.T) { + resetOpenAIAdvancedSchedulerSettingCacheForTest() + + ctx := context.Background() + groupID := int64(10112) + accounts := []Account{ + { + ID: 37021, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 0, + Credentials: map[string]any{ + "openai_capabilities": []any{"chat_completions"}, + }, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + }, + { + ID: 37022, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 5, + Credentials: map[string]any{ + "openai_capabilities": []any{"chat_completions", "embeddings"}, + }, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + }, + } + cfg := newSchedulerTestOpenAIWSV2Config() + cfg.Gateway.Scheduling.LoadBatchEnabled = false + cache := &schedulerTestGatewayCache{ + sessionBindings: map[string]int64{ + "openai:session_hash_embeddings": 37021, + }, + } + svc := &OpenAIGatewayService{ + accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts}, + cache: cache, + cfg: cfg, + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), + } + store := svc.getOpenAIWSStateStore() + require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_embeddings_chat_only", 37021, time.Hour)) + + selection, decision, err := svc.SelectAccountWithSchedulerForCapability( + ctx, + &groupID, + "resp_embeddings_chat_only", + "session_hash_embeddings", + "text-embedding-3-small", + nil, + OpenAIUpstreamTransportHTTPSSE, + OpenAIEndpointCapabilityEmbeddings, + false, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, int64(37022), selection.Account.ID) + require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer) + require.False(t, decision.StickyPreviousHit) + require.False(t, decision.StickySessionHit) + require.Equal(t, int64(37022), cache.sessionBindings["openai:session_hash_embeddings"]) +} + func TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics_DisabledNoOp(t *testing.T) { resetOpenAIAdvancedSchedulerSettingCacheForTest() diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index f93cc221..77587f69 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -1279,7 +1279,7 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI // SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts. // SelectAccountForModelWithExclusions 选择支持指定模型的账号,同时排除指定的账号。 func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) { - return s.selectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs, false, 0) + return s.selectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs, false, 0, "") } // noAvailableOpenAISelectionError builds the standard "no account available" error @@ -1312,13 +1312,16 @@ func openAICompactSupportTier(account *Account) int { // isOpenAIAccountEligibleForRequest centralises the schedulable / OpenAI / model / // compact-support checks used during account selection. -func isOpenAIAccountEligibleForRequest(ctx context.Context, account *Account, requestedModel string, requireCompact bool) bool { +func isOpenAIAccountEligibleForRequest(ctx context.Context, account *Account, requestedModel string, requireCompact bool, requiredCapability OpenAIEndpointCapability) bool { if account == nil || !account.IsOpenAI() || !account.IsSchedulableForModelWithContext(ctx, requestedModel) { return false } if requestedModel != "" && !account.IsModelSupported(requestedModel) { return false } + if !account.SupportsOpenAIEndpointCapability(requiredCapability) { + return false + } if requireCompact && openAICompactSupportTier(account) == 0 { return false } @@ -1366,7 +1369,7 @@ func resolveOpenAIAccountUpstreamModelForRequest(account *Account, requestedMode return upstreamModel } -func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, requireCompact bool, stickyAccountID int64) (*Account, error) { +func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, requireCompact bool, stickyAccountID int64, requiredCapability OpenAIEndpointCapability) (*Account, error) { if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) { slog.Warn("channel pricing restriction blocked request", "group_id", derefGroupID(groupID), @@ -1376,7 +1379,7 @@ func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.C // 1. 尝试粘性会话命中 // Try sticky session hit - if account := s.tryStickySessionHit(ctx, groupID, sessionHash, requestedModel, excludedIDs, requireCompact, stickyAccountID); account != nil { + if account := s.tryStickySessionHit(ctx, groupID, sessionHash, requestedModel, excludedIDs, requireCompact, stickyAccountID, requiredCapability); account != nil { return account, nil } @@ -1389,7 +1392,7 @@ func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.C // 3. 按优先级 + LRU 选择最佳账号 // Select by priority + LRU - selected, compactBlocked := s.selectBestAccount(ctx, groupID, accounts, requestedModel, excludedIDs, requireCompact) + selected, compactBlocked := s.selectBestAccount(ctx, groupID, accounts, requestedModel, excludedIDs, requireCompact, requiredCapability) if selected == nil { return nil, noAvailableOpenAISelectionError(requestedModel, compactBlocked) @@ -1414,7 +1417,7 @@ func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.C // // tryStickySessionHit attempts to get account from sticky session. // Returns account if hit and usable; clears session and returns nil if account is unavailable. -func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID *int64, sessionHash, requestedModel string, excludedIDs map[int64]struct{}, requireCompact bool, stickyAccountID int64) *Account { +func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID *int64, sessionHash, requestedModel string, excludedIDs map[int64]struct{}, requireCompact bool, stickyAccountID int64, requiredCapability OpenAIEndpointCapability) *Account { if sessionHash == "" { return nil } @@ -1446,14 +1449,14 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID // 验证账号是否可用于当前请求 // Verify account is usable for current request - if !isOpenAIAccountEligibleForRequest(ctx, account, requestedModel, false) { + if !isOpenAIAccountEligibleForRequest(ctx, account, requestedModel, false, requiredCapability) { return nil } if s.isOpenAIAccountRuntimeBlocked(account) { _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) return nil } - account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel, requireCompact) + account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel, requireCompact, requiredCapability) if account == nil { _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) return nil @@ -1477,7 +1480,7 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID // Returns nil if no available account. The second return reports whether at // least one candidate was filtered out solely because it lacks compact support // (only meaningful when requireCompact=true). -func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, groupID *int64, accounts []Account, requestedModel string, excludedIDs map[int64]struct{}, requireCompact bool) (*Account, bool) { +func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, groupID *int64, accounts []Account, requestedModel string, excludedIDs map[int64]struct{}, requireCompact bool, requiredCapability OpenAIEndpointCapability) (*Account, bool) { var selected *Account selectedCompactTier := -1 compactBlocked := false @@ -1492,11 +1495,11 @@ func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, groupID *i continue } - fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel, false) + fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel, false, requiredCapability) if fresh == nil { continue } - fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, false) + fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, false, requiredCapability) if fresh == nil { continue } @@ -1573,10 +1576,10 @@ func (s *OpenAIGatewayService) isBetterAccount(candidate, current *Account) bool // SelectAccountWithLoadAwareness selects an account with load-awareness and wait plan. func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) { - return s.selectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, excludedIDs, false) + return s.selectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, excludedIDs, false, "") } -func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, requireCompact bool) (*AccountSelectionResult, error) { +func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, requireCompact bool, requiredCapability OpenAIEndpointCapability) (*AccountSelectionResult, error) { if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) { slog.Warn("channel pricing restriction blocked request", "group_id", derefGroupID(groupID), @@ -1593,7 +1596,7 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex } } if s.concurrencyService == nil || !cfg.LoadBatchEnabled { - account, err := s.selectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs, requireCompact, stickyAccountID) + account, err := s.selectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs, requireCompact, stickyAccountID, requiredCapability) if err != nil { return nil, err } @@ -1646,8 +1649,8 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex if clearSticky { _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) } - if !clearSticky && isOpenAIAccountEligibleForRequest(ctx, account, requestedModel, false) { - account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel, requireCompact) + if !clearSticky && isOpenAIAccountEligibleForRequest(ctx, account, requestedModel, false, requiredCapability) { + account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel, requireCompact, requiredCapability) if account == nil { _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) } else if s.isOpenAIAccountRuntimeBlocked(account) { @@ -1691,15 +1694,12 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex // Scheduler snapshots can be temporarily stale (bucket rebuild is throttled); // re-check schedulability here so recently rate-limited/overloaded accounts // are not selected again before the bucket is rebuilt. - if !acc.IsSchedulable() { + if !isOpenAIAccountEligibleForRequest(ctx, acc, requestedModel, false, requiredCapability) { continue } if s.isOpenAIAccountRuntimeBlocked(acc) { continue } - if requestedModel != "" && !acc.IsModelSupported(requestedModel) { - continue - } if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, acc, requestedModel, requireCompact) { continue } @@ -1779,11 +1779,11 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex } for _, item := range selectionOrder { - fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, item.account, requestedModel, false) + fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, item.account, requestedModel, false, requiredCapability) if fresh == nil { continue } - fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, requireCompact) + fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, requireCompact, requiredCapability) if fresh == nil { continue } @@ -1813,11 +1813,11 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex ordered = prioritizeOpenAICompactAccounts(ordered) } for _, acc := range ordered { - fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel, false) + fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel, false, requiredCapability) if fresh == nil { continue } - fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, requireCompact) + fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, requireCompact, requiredCapability) if fresh == nil { continue } @@ -1858,11 +1858,11 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex candidates = prioritizeOpenAICompactAccounts(candidates) } for _, acc := range candidates { - fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel, false) + fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel, false, requiredCapability) if fresh == nil { continue } - fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, requireCompact) + fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, requireCompact, requiredCapability) if fresh == nil { continue } @@ -1910,7 +1910,7 @@ func (s *OpenAIGatewayService) tryAcquireAccountSlot(ctx context.Context, accoun return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency) } -func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context.Context, account *Account, requestedModel string, requireCompact bool) *Account { +func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context.Context, account *Account, requestedModel string, requireCompact bool, requiredCapability OpenAIEndpointCapability) *Account { if account == nil { return nil } @@ -1924,7 +1924,7 @@ func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context. fresh = current } - if !isOpenAIAccountEligibleForRequest(ctx, fresh, requestedModel, requireCompact) { + if !isOpenAIAccountEligibleForRequest(ctx, fresh, requestedModel, requireCompact, requiredCapability) { return nil } if s.isOpenAIAccountRuntimeBlocked(fresh) { @@ -1933,12 +1933,12 @@ func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context. return fresh } -func (s *OpenAIGatewayService) recheckSelectedOpenAIAccountFromDB(ctx context.Context, account *Account, requestedModel string, requireCompact bool) *Account { +func (s *OpenAIGatewayService) recheckSelectedOpenAIAccountFromDB(ctx context.Context, account *Account, requestedModel string, requireCompact bool, requiredCapability OpenAIEndpointCapability) *Account { if account == nil { return nil } if s.schedulerSnapshot == nil || s.accountRepo == nil { - if !isOpenAIAccountEligibleForRequest(ctx, account, requestedModel, requireCompact) { + if !isOpenAIAccountEligibleForRequest(ctx, account, requestedModel, requireCompact, requiredCapability) { return nil } return account @@ -1948,7 +1948,7 @@ func (s *OpenAIGatewayService) recheckSelectedOpenAIAccountFromDB(ctx context.Co if err != nil || latest == nil { return nil } - if !isOpenAIAccountEligibleForRequest(ctx, latest, requestedModel, requireCompact) { + if !isOpenAIAccountEligibleForRequest(ctx, latest, requestedModel, requireCompact, requiredCapability) { return nil } if s.isOpenAIAccountRuntimeBlocked(latest) { diff --git a/backend/internal/service/openai_images_test.go b/backend/internal/service/openai_images_test.go index 854e9f6d..a87e96c1 100644 --- a/backend/internal/service/openai_images_test.go +++ b/backend/internal/service/openai_images_test.go @@ -413,6 +413,79 @@ func TestAccountSupportsOpenAIImageCapability_OAuthSupportsNative(t *testing.T) require.True(t, account.SupportsOpenAIImageCapability(OpenAIImagesCapabilityNative)) } +func TestAccountSupportsOpenAIEndpointCapability(t *testing.T) { + t.Run("OpenAI APIKey 默认兼容 chat 和 embeddings", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + } + + require.True(t, account.SupportsOpenAIEndpointCapability(OpenAIEndpointCapabilityChatCompletions)) + require.True(t, account.SupportsOpenAIEndpointCapability(OpenAIEndpointCapabilityEmbeddings)) + }) + + t.Run("OpenAI OAuth 默认仅兼容 chat", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + } + + require.True(t, account.SupportsOpenAIEndpointCapability(OpenAIEndpointCapabilityChatCompletions)) + require.False(t, account.SupportsOpenAIEndpointCapability(OpenAIEndpointCapabilityEmbeddings)) + }) + + t.Run("显式列表支持同时声明 chat 和 embeddings", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Credentials: map[string]any{ + "openai_capabilities": []any{"chat_completions", "embeddings"}, + }, + } + + require.True(t, account.SupportsOpenAIEndpointCapability(OpenAIEndpointCapabilityChatCompletions)) + require.True(t, account.SupportsOpenAIEndpointCapability(OpenAIEndpointCapabilityEmbeddings)) + }) + + t.Run("显式列表只声明 chat 时不支持 embeddings", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Credentials: map[string]any{ + "openai_capabilities": []any{"chat_completions"}, + }, + } + + require.True(t, account.SupportsOpenAIEndpointCapability(OpenAIEndpointCapabilityChatCompletions)) + require.False(t, account.SupportsOpenAIEndpointCapability(OpenAIEndpointCapabilityEmbeddings)) + }) + + t.Run("显式 map 支持单独关闭 chat 并开启 embeddings", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Credentials: map[string]any{ + "openai_capabilities": map[string]any{ + "chat_completions": false, + "embeddings": true, + }, + }, + } + + require.False(t, account.SupportsOpenAIEndpointCapability(OpenAIEndpointCapabilityChatCompletions)) + require.True(t, account.SupportsOpenAIEndpointCapability(OpenAIEndpointCapabilityEmbeddings)) + }) + + t.Run("未知能力不应默认放行", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + } + + require.False(t, account.SupportsOpenAIEndpointCapability(OpenAIEndpointCapability("unknown"))) + }) +} + func TestBuildOpenAIImagesURL_HandlesVersionedBaseURL(t *testing.T) { require.Equal(t, "https://image-upstream.example/v1/images/generations", diff --git a/backend/internal/service/openai_ws_account_sticky_test.go b/backend/internal/service/openai_ws_account_sticky_test.go index 4005a921..c8b28a46 100644 --- a/backend/internal/service/openai_ws_account_sticky_test.go +++ b/backend/internal/service/openai_ws_account_sticky_test.go @@ -268,6 +268,52 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_BusyKeepsSticky( require.Equal(t, int64(21), selection.WaitPlan.AccountID) } +func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_CapabilityMismatchKeepsSticky(t *testing.T) { + ctx := context.Background() + groupID := int64(25) + account := Account{ + ID: 31, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "openai_capabilities": []any{"chat_completions"}, + }, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + } + cache := &stubGatewayCache{} + store := NewOpenAIWSStateStore(cache) + cfg := newOpenAIWSV2TestConfig() + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + openaiWSStateStore: store, + } + + require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_capability", account.ID, time.Hour)) + + selection, err := svc.selectAccountByPreviousResponseIDForCapability( + ctx, + &groupID, + "resp_prev_capability", + "text-embedding-3-small", + nil, + OpenAIEndpointCapabilityEmbeddings, + false, + ) + require.NoError(t, err) + require.Nil(t, selection) + boundAccountID, getErr := store.GetResponseAccount(ctx, groupID, "resp_prev_capability") + require.NoError(t, getErr) + require.Equal(t, account.ID, boundAccountID) +} + func newOpenAIWSV2TestConfig() *config.Config { cfg := &config.Config{} cfg.Gateway.OpenAIWS.Enabled = true diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go index b8e558ae..5fd5cffc 100644 --- a/backend/internal/service/openai_ws_forwarder.go +++ b/backend/internal/service/openai_ws_forwarder.go @@ -3987,6 +3987,18 @@ func (s *OpenAIGatewayService) SelectAccountByPreviousResponseID( requestedModel string, excludedIDs map[int64]struct{}, requireCompact bool, +) (*AccountSelectionResult, error) { + return s.selectAccountByPreviousResponseIDForCapability(ctx, groupID, previousResponseID, requestedModel, excludedIDs, "", requireCompact) +} + +func (s *OpenAIGatewayService) selectAccountByPreviousResponseIDForCapability( + ctx context.Context, + groupID *int64, + previousResponseID string, + requestedModel string, + excludedIDs map[int64]struct{}, + requiredCapability OpenAIEndpointCapability, + requireCompact bool, ) (*AccountSelectionResult, error) { if s == nil { return nil, nil @@ -4027,12 +4039,31 @@ func (s *OpenAIGatewayService) SelectAccountByPreviousResponseID( if requestedModel != "" && !account.IsModelSupported(requestedModel) { return nil, nil } - account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel, requireCompact) - if account == nil { - _ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID) + if !account.SupportsOpenAIEndpointCapability(requiredCapability) { return nil, nil } - // 兜底:若上游 compact 能力刚被探测为不支持,但 sticky 还在,需要主动放弃。 + if s.schedulerSnapshot != nil && s.accountRepo != nil { + latest, latestErr := s.accountRepo.GetByID(ctx, account.ID) + if latestErr != nil || latest == nil { + _ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID) + return nil, nil + } + if shouldClearStickySession(latest, requestedModel) || !latest.IsOpenAI() || !latest.IsSchedulable() { + _ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID) + return nil, nil + } + if requestedModel != "" && !latest.IsModelSupported(requestedModel) { + return nil, nil + } + if !latest.SupportsOpenAIEndpointCapability(requiredCapability) { + return nil, nil + } + if s.isOpenAIAccountRuntimeBlocked(latest) { + _ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID) + return nil, nil + } + account = latest + } if requireCompact && openAICompactSupportTier(account) == 0 { _ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID) return nil, nil diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue index 331295f7..665c4695 100644 --- a/frontend/src/components/account/CreateAccountModal.vue +++ b/frontend/src/components/account/CreateAccountModal.vue @@ -2679,7 +2679,7 @@
{{ t('admin.accounts.openai.endpointCapabilitiesDesc') }}
+{{ t('admin.accounts.openai.endpointCapabilitiesDesc') }}
+