diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index ffb53780..43ebc292 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -180,7 +180,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oauthRefreshAPI) digestSessionStore := service.NewDigestSessionStore() - gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService) + rpmTokenBucketService := service.NewRPMTokenBucketService() + gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, rpmTokenBucketService) openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oauthRefreshAPI) openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 270a0b98..3d5e151f 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -481,6 +481,10 @@ type GatewayConfig struct { // UserMessageQueue: 用户消息串行队列配置 // 对 role:"user" 的真实用户消息实施账号级串行化 + RPM 自适应延迟 UserMessageQueue UserMessageQueueConfig `mapstructure:"user_message_queue"` + + // RPMSmoothing: RPM 令牌桶平滑配置 + // 启用后,RPM 配额耗尽时请求等待令牌(最多 MaxWaitMS 毫秒)而非立即返回 429 + RPMSmoothing RPMSmoothingConfig `mapstructure:"rpm_smoothing"` } type GatewayAntigravityLSWorkerConfig struct { @@ -535,6 +539,23 @@ func (c *UserMessageQueueConfig) GetEffectiveMode() string { return "" } +// RPMSmoothingConfig RPM 令牌桶平滑配置 +type RPMSmoothingConfig struct { + // Enabled: 是否启用 RPM 令牌桶平滑(默认 false) + // 启用后,当账号 RPM 配额耗尽时,请求最多等待 MaxWaitMS 毫秒,而非立即返回 429。 + Enabled bool `mapstructure:"enabled"` + // MaxWaitMS: 等待令牌的最大时间(毫秒),超时后返回 429(默认 5000) + MaxWaitMS int `mapstructure:"max_wait_ms"` +} + +// MaxWait returns the configured wait duration, defaulting to 5s. +func (c *RPMSmoothingConfig) MaxWait() time.Duration { + if c.MaxWaitMS <= 0 { + return 5 * time.Second + } + return time.Duration(c.MaxWaitMS) * time.Millisecond +} + // GatewayOpenAIWSConfig OpenAI Responses WebSocket 配置。 // 注意:默认全局开启;如需回滚可使用 force_http 或关闭 enabled。 type GatewayOpenAIWSConfig struct { diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index a0d8b2e9..babb9448 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -383,6 +383,21 @@ func (h *GatewayHandler) Messages(c *gin.Context) { reqLog.Warn("gateway.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err)) } } + // RPM 令牌桶平滑:在让出请求前等待令牌(最多 MaxWaitMS 毫秒) + // 必须在 wrapReleaseOnDone 之前执行,以便超时时能安全释放原始槽位。 + if h.cfg.Gateway.RPMSmoothing.Enabled && account.IsAnthropicOAuthOrSetupToken() && account.GetBaseRPM() > 0 { + rpmWaitCtx, rpmCancel := context.WithTimeout(c.Request.Context(), h.cfg.Gateway.RPMSmoothing.MaxWait()) + rpmErr := h.gatewayService.AcquireRPMToken(rpmWaitCtx, account.ID, account.GetBaseRPM(), h.cfg.Gateway.RPMSmoothing.MaxWait()) + rpmCancel() + if rpmErr != nil { + if accountReleaseFunc != nil { + accountReleaseFunc() + } + h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "RPM rate limit exceeded, please retry later", streamStarted) + return + } + } + // 账号槽位/等待计数需要在超时或断开时安全回收 accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) @@ -605,6 +620,21 @@ func (h *GatewayHandler) Messages(c *gin.Context) { reqLog.Warn("gateway.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err)) } } + // RPM 令牌桶平滑:在让出请求前等待令牌(最多 MaxWaitMS 毫秒) + // 必须在 wrapReleaseOnDone 之前执行,以便超时时能安全释放原始槽位。 + if h.cfg.Gateway.RPMSmoothing.Enabled && account.IsAnthropicOAuthOrSetupToken() && account.GetBaseRPM() > 0 { + rpmWaitCtx, rpmCancel := context.WithTimeout(c.Request.Context(), h.cfg.Gateway.RPMSmoothing.MaxWait()) + rpmErr := h.gatewayService.AcquireRPMToken(rpmWaitCtx, account.ID, account.GetBaseRPM(), h.cfg.Gateway.RPMSmoothing.MaxWait()) + rpmCancel() + if rpmErr != nil { + if accountReleaseFunc != nil { + accountReleaseFunc() + } + h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "RPM rate limit exceeded, please retry later", streamStarted) + return + } + } + // 账号槽位/等待计数需要在超时或断开时安全回收 accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 107e5086..dfe3fe34 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -558,7 +558,8 @@ type GatewayService struct { concurrencyService *ConcurrencyService claudeTokenProvider *ClaudeTokenProvider sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken) - rpmCache RPMCache // RPM 计数缓存(仅 Anthropic OAuth/SetupToken) + rpmCache RPMCache // RPM 计数缓存(仅 Anthropic OAuth/SetupToken) + rpmTokenBucket *RPMTokenBucketService // RPM 令牌桶平滑(可选,由配置开关控制) userGroupRateResolver *userGroupRateResolver userGroupRateCache *gocache.Cache userGroupRateSF singleflight.Group @@ -597,6 +598,7 @@ func NewGatewayService( digestStore *DigestSessionStore, settingService *SettingService, tlsFPProfileService *TLSFingerprintProfileService, + rpmTokenBucketSvc *RPMTokenBucketService, ) *GatewayService { userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg) modelsListTTL := resolveModelsListCacheTTL(cfg) @@ -623,6 +625,7 @@ func NewGatewayService( claudeTokenProvider: claudeTokenProvider, sessionLimitCache: sessionLimitCache, rpmCache: rpmCache, + rpmTokenBucket: rpmTokenBucketSvc, userGroupRateCache: gocache.New(userGroupRateTTL, time.Minute), settingService: settingService, modelsListCache: gocache.New(modelsListTTL, time.Minute), @@ -2433,6 +2436,15 @@ func (s *GatewayService) IncrementAccountRPM(ctx context.Context, accountID int6 return err } +// AcquireRPMToken consumes one RPM token for the given account, waiting up to maxWait if needed. +// Returns nil immediately when RPM smoothing is not configured or the account has no RPM limit. +func (s *GatewayService) AcquireRPMToken(ctx context.Context, accountID int64, rpm int, maxWait time.Duration) error { + if s.rpmTokenBucket == nil { + return nil + } + return s.rpmTokenBucket.AcquireWithWait(ctx, accountID, rpm, maxWait) +} + // checkAndRegisterSession 检查并注册会话,用于会话数量限制 // 仅适用于 Anthropic OAuth/SetupToken 账号 // sessionID: 会话标识符(使用粘性会话的 hash) diff --git a/backend/internal/service/rpm_token_bucket_service.go b/backend/internal/service/rpm_token_bucket_service.go new file mode 100644 index 00000000..dfea5798 --- /dev/null +++ b/backend/internal/service/rpm_token_bucket_service.go @@ -0,0 +1,120 @@ +package service + +import ( + "context" + "errors" + "math" + "sync" + "time" +) + +// ErrRPMWaitTimeout is returned when AcquireWithWait cannot obtain a token within maxWait. +var ErrRPMWaitTimeout = errors.New("rpm smoothing: timed out waiting for rate limit slot") + +// RPMTokenBucketService provides per-account token buckets for RPM smoothing. +// When an account's RPM budget is exhausted, callers can wait up to a configured +// deadline instead of receiving an immediate 429. The bucket refills continuously +// at rpm/60 tokens per second so requests are distributed evenly over time. +type RPMTokenBucketService struct { + buckets sync.Map // map[int64]*rpmEntry +} + +// NewRPMTokenBucketService creates a ready-to-use RPMTokenBucketService. +func NewRPMTokenBucketService() *RPMTokenBucketService { + return &RPMTokenBucketService{} +} + +type rpmEntry struct { + bucket *tokenBucket + rpm int +} + +// getBucket returns (or creates) the token bucket for accountID. +// If the account's RPM limit has changed since the bucket was created, the bucket is replaced. +func (s *RPMTokenBucketService) getBucket(accountID int64, rpm int) *tokenBucket { + if v, ok := s.buckets.Load(accountID); ok { + e := v.(*rpmEntry) + if e.rpm == rpm { + return e.bucket + } + // RPM limit changed — replace with a fresh bucket. + fresh := &rpmEntry{rpm: rpm, bucket: newTokenBucket(rpm)} + s.buckets.Store(accountID, fresh) + return fresh.bucket + } + entry := &rpmEntry{rpm: rpm, bucket: newTokenBucket(rpm)} + actual, _ := s.buckets.LoadOrStore(accountID, entry) + return actual.(*rpmEntry).bucket +} + +// AcquireWithWait attempts to consume one token for the given account. +// It blocks up to maxWait for a token to become available. +// Returns nil on success, ErrRPMWaitTimeout if the deadline is exceeded, +// or ctx.Err() if the context is cancelled. +// If rpm <= 0 the call returns immediately with nil. +func (s *RPMTokenBucketService) AcquireWithWait(ctx context.Context, accountID int64, rpm int, maxWait time.Duration) error { + if rpm <= 0 { + return nil + } + bucket := s.getBucket(accountID, rpm) + deadline := time.Now().Add(maxWait) + + for { + ok, waitDur := bucket.tryAcquire() + if ok { + return nil + } + + remaining := time.Until(deadline) + if remaining <= 0 || waitDur > remaining { + return ErrRPMWaitTimeout + } + + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(waitDur): + // token may be available now; retry + } + } +} + +// tokenBucket is a continuous-refill token bucket for a single account. +type tokenBucket struct { + mu sync.Mutex + tokens float64 + maxTokens float64 + rateSec float64 // tokens refilled per second = rpm / 60 + lastFill time.Time +} + +func newTokenBucket(rpm int) *tokenBucket { + max := float64(rpm) + return &tokenBucket{ + tokens: max, + maxTokens: max, + rateSec: float64(rpm) / 60.0, + lastFill: time.Now(), + } +} + +// tryAcquire refills the bucket based on elapsed time, then attempts to consume one token. +// Returns (true, 0) on success, or (false, waitDur) indicating how long until a token is available. +func (b *tokenBucket) tryAcquire() (bool, time.Duration) { + b.mu.Lock() + defer b.mu.Unlock() + + now := time.Now() + elapsed := now.Sub(b.lastFill).Seconds() + b.tokens = math.Min(b.maxTokens, b.tokens+elapsed*b.rateSec) + b.lastFill = now + + if b.tokens >= 1.0 { + b.tokens -= 1.0 + return true, 0 + } + + deficit := 1.0 - b.tokens + waitSecs := deficit / b.rateSec + return false, time.Duration(waitSecs * float64(time.Second)) +} diff --git a/backend/internal/service/rpm_token_bucket_service_test.go b/backend/internal/service/rpm_token_bucket_service_test.go new file mode 100644 index 00000000..1710c15d --- /dev/null +++ b/backend/internal/service/rpm_token_bucket_service_test.go @@ -0,0 +1,108 @@ +package service + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRPMTokenBucket_ImmediateAcquireWhenFull(t *testing.T) { + svc := NewRPMTokenBucketService() + ctx := context.Background() + // Bucket starts full (rpm=60 tokens). First 60 calls should succeed immediately. + for i := 0; i < 60; i++ { + err := svc.AcquireWithWait(ctx, 1, 60, 0) + require.NoError(t, err, "call %d should succeed immediately", i+1) + } +} + +func TestRPMTokenBucket_ZeroRPMAlwaysOK(t *testing.T) { + svc := NewRPMTokenBucketService() + err := svc.AcquireWithWait(context.Background(), 42, 0, 0) + assert.NoError(t, err) +} + +func TestRPMTokenBucket_TimeoutWhenExhausted(t *testing.T) { + svc := NewRPMTokenBucketService() + ctx := context.Background() + + // rpm=1 → 1 token/minute. One call drains the bucket. + err := svc.AcquireWithWait(ctx, 99, 1, 5*time.Second) + require.NoError(t, err, "first call should succeed") + + // Second call: bucket empty, wait time ≈ 60s which exceeds maxWait=50ms. + start := time.Now() + err = svc.AcquireWithWait(ctx, 99, 1, 50*time.Millisecond) + elapsed := time.Since(start) + assert.ErrorIs(t, err, ErrRPMWaitTimeout) + assert.Less(t, elapsed, 200*time.Millisecond, "should timeout quickly, not block") +} + +func TestRPMTokenBucket_WaitsAndSucceeds(t *testing.T) { + svc := NewRPMTokenBucketService() + ctx := context.Background() + + // rpm=120 → refill rate = 2 tokens/second. Drain the bucket fully. + for i := 0; i < 120; i++ { + require.NoError(t, svc.AcquireWithWait(ctx, 7, 120, 0)) + } + + // Next call needs to wait ~500ms for the next token. Give it 2s. + start := time.Now() + err := svc.AcquireWithWait(ctx, 7, 120, 2*time.Second) + elapsed := time.Since(start) + require.NoError(t, err, "should succeed after waiting for refill") + assert.Greater(t, elapsed, 100*time.Millisecond, "should have actually waited") + assert.Less(t, elapsed, 1500*time.Millisecond, "should not wait excessively long") +} + +func TestRPMTokenBucket_ContextCancellation(t *testing.T) { + svc := NewRPMTokenBucketService() + + // rpm=120 → refill = 2 tokens/second → next token in ~500ms after draining. + // maxWait = 2s (longer than 500ms refill wait) so the code blocks in time.After(~500ms). + // Context is cancelled after 30ms, which is shorter than the 500ms wait, so ctx.Done fires first. + for i := 0; i < 120; i++ { + require.NoError(t, svc.AcquireWithWait(context.Background(), 55, 120, 0)) + } + + ctx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(30 * time.Millisecond) + cancel() + }() + + start := time.Now() + err := svc.AcquireWithWait(ctx, 55, 120, 2*time.Second) + elapsed := time.Since(start) + assert.ErrorIs(t, err, context.Canceled) + assert.Less(t, elapsed, 200*time.Millisecond, "should respect context cancellation promptly") +} + +func TestRPMTokenBucket_DifferentAccountsAreIsolated(t *testing.T) { + svc := NewRPMTokenBucketService() + ctx := context.Background() + + // Drain account 1 (rpm=1). + require.NoError(t, svc.AcquireWithWait(ctx, 1, 1, 0)) + + // Account 2 has its own bucket and should succeed immediately. + err := svc.AcquireWithWait(ctx, 2, 1, 0) + assert.NoError(t, err, "different account should have an independent bucket") +} + +func TestRPMTokenBucket_RPMChangeReplacesBucket(t *testing.T) { + svc := NewRPMTokenBucketService() + ctx := context.Background() + + // Create bucket with rpm=1 and drain it. + require.NoError(t, svc.AcquireWithWait(ctx, 10, 1, 0)) + // Bucket now empty with rpm=1. + + // Changing RPM to 60 should reset the bucket to full (60 tokens). + err := svc.AcquireWithWait(ctx, 10, 60, 0) + assert.NoError(t, err, "new RPM should cause bucket recreation") +} diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index abf437d5..2ce138e0 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -424,6 +424,7 @@ var ProviderSet = wire.NewSet( NewBillingCacheService, NewAnnouncementService, NewAdminService, + NewRPMTokenBucketService, NewGatewayService, ProvideSoraMediaStorage, ProvideSoraMediaCleanupService,