sub2api/backend/internal/service/gateway_tier_fallback.go

134 lines
3.8 KiB
Go

package service
import (
"context"
"errors"
)
// accountTierLevel returns the scheduling tier of an account. Lower tiers are
// tried earlier by the tier fallback chain:
//
//	0 = subscription (OAuth / SetupToken) — tried first
//	1 = API Key — first fallback
//	2 = Bedrock — last resort
//
// A nil account, and any account whose type is neither API Key nor Bedrock
// (including unknown types), lands in tier 0. This keeps it in the primary
// selection instead of letting it drop out silently.
func accountTierLevel(account *Account) int {
	tier := 0
	if account != nil {
		switch account.Type {
		case AccountTypeAPIKey:
			tier = 1
		case AccountTypeBedrock:
			tier = 2
		}
	}
	return tier
}
// enableTierFallbackChain reports whether Gateway.Scheduling.EnableTierFallbackChain
// is switched on. A nil service or a missing config counts as disabled, so the
// feature is off by default.
func (s *GatewayService) enableTierFallbackChain() bool {
	if s == nil || s.cfg == nil {
		return false
	}
	return s.cfg.Gateway.Scheduling.EnableTierFallbackChain
}
// selectAccountWithTierFallback picks an Anthropic account by walking the
// scheduling tiers in order:
// tier 0 (OAuth/SetupToken subscription) → tier 1 (API Key) → tier 2 (Bedrock).
//
// An account becomes a candidate only if it passes every filter: Anthropic
// platform, not in excludedIDs, generally schedulable, supports requestedModel
// (when one is given), model-level scheduling, quota, window cost and RPM.
//
// Sticky sessions: when sessionHash is non-empty and the cache holds a bound
// account that is still a candidate:
//   - If shouldClearStickySession reports the binding should be dropped, the
//     binding is deleted and normal tier selection runs.
//   - Otherwise the account is returned directly if it also passes the sticky
//     window-cost and RPM checks.
//   - A bound account that fails those checks is skipped, but its binding is
//     left in place.
//
// Without a sticky hit, each non-empty tier is sorted by the configured
// fallback selection mode and passed to tryAcquireByLegacyOrder. The first
// tier that acquires an account wins.
//
// Returns the listing error if schedulable accounts cannot be listed, or a
// "no available accounts in any tier" error if no tier yields an account.
func (s *GatewayService) selectAccountWithTierFallback(
	ctx context.Context,
	groupID *int64,
	sessionHash string,
	requestedModel string,
	excludedIDs map[int64]struct{},
) (*Account, error) {
	accounts, _, err := s.listSchedulableAccounts(ctx, groupID, PlatformAnthropic, false)
	if err != nil {
		return nil, err
	}
	// The filters below run against the prefetch-enriched context.
	ctx = s.withWindowCostPrefetch(ctx, accounts)
	ctx = s.withRPMPrefetch(ctx, accounts)

	// Build per-tier candidate lists (pointers into `accounts`).
	const numTiers = 3
	tierCandidates := [numTiers][]*Account{}
	for i := range accounts {
		acc := &accounts[i]
		if acc.Platform != PlatformAnthropic {
			continue
		}
		if _, excluded := excludedIDs[acc.ID]; excluded {
			continue
		}
		if !s.isAccountSchedulableForSelection(acc) {
			continue
		}
		if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
			continue
		}
		if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
			continue
		}
		if !s.isAccountSchedulableForQuota(acc) {
			continue
		}
		if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
			continue
		}
		if !s.isAccountSchedulableForRPM(ctx, acc, false) {
			continue
		}
		tier := accountTierLevel(acc)
		if tier < numTiers {
			tierCandidates[tier] = append(tierCandidates[tier], acc)
		}
	}

	cfg := s.schedulingConfig()
	selectionMode := cfg.FallbackSelectionMode

	// Sticky session: reuse the bound account if it is still a candidate.
	if sessionHash != "" && s.cache != nil {
		accountID, cacheErr := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
		if cacheErr == nil && accountID > 0 {
			if _, excluded := excludedIDs[accountID]; !excluded {
			stickyLookup:
				for tier := 0; tier < numTiers; tier++ {
					for _, acc := range tierCandidates[tier] {
						if acc.ID != accountID {
							continue
						}
						if shouldClearStickySession(acc, requestedModel) {
							// Best-effort cleanup: the delete error is ignored and
							// selection proceeds through the tiers either way.
							_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
						} else if s.isAccountSchedulableForWindowCost(ctx, acc, true) &&
							s.isAccountSchedulableForRPM(ctx, acc, true) {
							// NOTE(review): unlike the tier path below, this return does
							// not go through tryAcquireByLegacyOrder — confirm the caller
							// handles acquisition for sticky hits.
							return acc, nil
						}
						// The bound account has been handled; stop scanning the
						// remaining tiers (assumes account IDs are unique in the listing).
						break stickyLookup
					}
				}
			}
		}
	}

	// Try each tier in order.
	for tier := 0; tier < numTiers; tier++ {
		candidates := tierCandidates[tier]
		if len(candidates) == 0 {
			continue
		}
		s.sortCandidatesForFallback(candidates, false, selectionMode)
		// NOTE(review): the third return value is discarded, and only
		// result.Account is handed back. Anything else the result carries is
		// dropped here — confirm this is intended.
		result, acquired, _ := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, false)
		if acquired && result != nil {
			return result.Account, nil
		}
	}
	return nil, errors.New("no available accounts in any tier")
}