feat(rpm): add token bucket smoothing for RPM rate limiting

- New RPMTokenBucketService: per-account continuous-refill token buckets
  (rate = rpm/60 tokens/sec, capacity = rpm). No new dependencies.
- GatewayService.AcquireRPMToken() delegates to the bucket service.
- Gateway handler inserts RPM token wait BEFORE wrapReleaseOnDone in both
  Gemini and Anthropic dispatch paths; timeout returns 429 and releases slot.
- Config: gateway.rpm_smoothing.enabled (default false) + max_wait_ms (default 5000).
- 7 unit tests covering: immediate acquire, zero RPM, timeout, wait+refill,
  context cancel, account isolation, bucket reset on RPM change.
This commit is contained in:
win 2026-04-29 01:22:54 +08:00
parent 5c8c15cdb1
commit 95814974de
7 changed files with 295 additions and 2 deletions

View File

@ -180,7 +180,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oauthRefreshAPI)
digestSessionStore := service.NewDigestSessionStore()
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService)
rpmTokenBucketService := service.NewRPMTokenBucketService()
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, rpmTokenBucketService)
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oauthRefreshAPI)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)

View File

@ -481,6 +481,10 @@ type GatewayConfig struct {
// UserMessageQueue: 用户消息串行队列配置
// 对 role:"user" 的真实用户消息实施账号级串行化 + RPM 自适应延迟
UserMessageQueue UserMessageQueueConfig `mapstructure:"user_message_queue"`
// RPMSmoothing: RPM 令牌桶平滑配置
// 启用后RPM 配额耗尽时请求等待令牌(最多 MaxWaitMS 毫秒)而非立即返回 429
RPMSmoothing RPMSmoothingConfig `mapstructure:"rpm_smoothing"`
}
type GatewayAntigravityLSWorkerConfig struct {
@ -535,6 +539,23 @@ func (c *UserMessageQueueConfig) GetEffectiveMode() string {
return ""
}
// RPMSmoothingConfig RPM 令牌桶平滑配置
type RPMSmoothingConfig struct {
// Enabled: 是否启用 RPM 令牌桶平滑(默认 false
// 启用后,当账号 RPM 配额耗尽时,请求最多等待 MaxWaitMS 毫秒,而非立即返回 429。
Enabled bool `mapstructure:"enabled"`
// MaxWaitMS: 等待令牌的最大时间(毫秒),超时后返回 429默认 5000
MaxWaitMS int `mapstructure:"max_wait_ms"`
}
// MaxWait returns the configured wait duration, defaulting to 5s.
func (c *RPMSmoothingConfig) MaxWait() time.Duration {
if c.MaxWaitMS <= 0 {
return 5 * time.Second
}
return time.Duration(c.MaxWaitMS) * time.Millisecond
}
// GatewayOpenAIWSConfig OpenAI Responses WebSocket 配置。
// 注意:默认全局开启;如需回滚可使用 force_http 或关闭 enabled。
type GatewayOpenAIWSConfig struct {

View File

@ -383,6 +383,21 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
reqLog.Warn("gateway.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
}
}
// RPM 令牌桶平滑:在让出请求前等待令牌(最多 MaxWaitMS 毫秒)
// 必须在 wrapReleaseOnDone 之前执行,以便超时时能安全释放原始槽位。
if h.cfg.Gateway.RPMSmoothing.Enabled && account.IsAnthropicOAuthOrSetupToken() && account.GetBaseRPM() > 0 {
rpmWaitCtx, rpmCancel := context.WithTimeout(c.Request.Context(), h.cfg.Gateway.RPMSmoothing.MaxWait())
rpmErr := h.gatewayService.AcquireRPMToken(rpmWaitCtx, account.ID, account.GetBaseRPM(), h.cfg.Gateway.RPMSmoothing.MaxWait())
rpmCancel()
if rpmErr != nil {
if accountReleaseFunc != nil {
accountReleaseFunc()
}
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "RPM rate limit exceeded, please retry later", streamStarted)
return
}
}
// 账号槽位/等待计数需要在超时或断开时安全回收
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
@ -605,6 +620,21 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
reqLog.Warn("gateway.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
}
}
// RPM 令牌桶平滑:在让出请求前等待令牌(最多 MaxWaitMS 毫秒)
// 必须在 wrapReleaseOnDone 之前执行,以便超时时能安全释放原始槽位。
if h.cfg.Gateway.RPMSmoothing.Enabled && account.IsAnthropicOAuthOrSetupToken() && account.GetBaseRPM() > 0 {
rpmWaitCtx, rpmCancel := context.WithTimeout(c.Request.Context(), h.cfg.Gateway.RPMSmoothing.MaxWait())
rpmErr := h.gatewayService.AcquireRPMToken(rpmWaitCtx, account.ID, account.GetBaseRPM(), h.cfg.Gateway.RPMSmoothing.MaxWait())
rpmCancel()
if rpmErr != nil {
if accountReleaseFunc != nil {
accountReleaseFunc()
}
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "RPM rate limit exceeded, please retry later", streamStarted)
return
}
}
// 账号槽位/等待计数需要在超时或断开时安全回收
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)

View File

@ -558,7 +558,8 @@ type GatewayService struct {
concurrencyService *ConcurrencyService
claudeTokenProvider *ClaudeTokenProvider
sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken
rpmCache RPMCache // RPM 计数缓存(仅 Anthropic OAuth/SetupToken
rpmCache RPMCache // RPM 计数缓存(仅 Anthropic OAuth/SetupToken
rpmTokenBucket *RPMTokenBucketService // RPM 令牌桶平滑(可选,由配置开关控制)
userGroupRateResolver *userGroupRateResolver
userGroupRateCache *gocache.Cache
userGroupRateSF singleflight.Group
@ -597,6 +598,7 @@ func NewGatewayService(
digestStore *DigestSessionStore,
settingService *SettingService,
tlsFPProfileService *TLSFingerprintProfileService,
rpmTokenBucketSvc *RPMTokenBucketService,
) *GatewayService {
userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg)
modelsListTTL := resolveModelsListCacheTTL(cfg)
@ -623,6 +625,7 @@ func NewGatewayService(
claudeTokenProvider: claudeTokenProvider,
sessionLimitCache: sessionLimitCache,
rpmCache: rpmCache,
rpmTokenBucket: rpmTokenBucketSvc,
userGroupRateCache: gocache.New(userGroupRateTTL, time.Minute),
settingService: settingService,
modelsListCache: gocache.New(modelsListTTL, time.Minute),
@ -2433,6 +2436,15 @@ func (s *GatewayService) IncrementAccountRPM(ctx context.Context, accountID int6
return err
}
// AcquireRPMToken consumes one RPM token for the given account, waiting up to maxWait if needed.
// Returns nil immediately when RPM smoothing is not configured or the account has no RPM limit.
func (s *GatewayService) AcquireRPMToken(ctx context.Context, accountID int64, rpm int, maxWait time.Duration) error {
if s.rpmTokenBucket == nil {
return nil
}
return s.rpmTokenBucket.AcquireWithWait(ctx, accountID, rpm, maxWait)
}
// checkAndRegisterSession 检查并注册会话,用于会话数量限制
// 仅适用于 Anthropic OAuth/SetupToken 账号
// sessionID: 会话标识符(使用粘性会话的 hash

View File

@ -0,0 +1,120 @@
package service
import (
"context"
"errors"
"math"
"sync"
"time"
)
// ErrRPMWaitTimeout is returned when AcquireWithWait cannot obtain a token within maxWait.
var ErrRPMWaitTimeout = errors.New("rpm smoothing: timed out waiting for rate limit slot")
// RPMTokenBucketService provides per-account token buckets for RPM smoothing.
// When an account's RPM budget is exhausted, callers can wait up to a configured
// deadline instead of receiving an immediate 429. The bucket refills continuously
// at rpm/60 tokens per second so requests are distributed evenly over time.
type RPMTokenBucketService struct {
buckets sync.Map // map[int64]*rpmEntry
}
// NewRPMTokenBucketService creates a ready-to-use RPMTokenBucketService.
func NewRPMTokenBucketService() *RPMTokenBucketService {
return &RPMTokenBucketService{}
}
type rpmEntry struct {
bucket *tokenBucket
rpm int
}
// getBucket returns (or creates) the token bucket for accountID.
// If the account's RPM limit has changed since the bucket was created, the bucket is replaced.
func (s *RPMTokenBucketService) getBucket(accountID int64, rpm int) *tokenBucket {
if v, ok := s.buckets.Load(accountID); ok {
e := v.(*rpmEntry)
if e.rpm == rpm {
return e.bucket
}
// RPM limit changed — replace with a fresh bucket.
fresh := &rpmEntry{rpm: rpm, bucket: newTokenBucket(rpm)}
s.buckets.Store(accountID, fresh)
return fresh.bucket
}
entry := &rpmEntry{rpm: rpm, bucket: newTokenBucket(rpm)}
actual, _ := s.buckets.LoadOrStore(accountID, entry)
return actual.(*rpmEntry).bucket
}
// AcquireWithWait attempts to consume one token for the given account.
// It blocks up to maxWait for a token to become available.
// Returns nil on success, ErrRPMWaitTimeout if the deadline is exceeded,
// or ctx.Err() if the context is cancelled.
// If rpm <= 0 the call returns immediately with nil.
func (s *RPMTokenBucketService) AcquireWithWait(ctx context.Context, accountID int64, rpm int, maxWait time.Duration) error {
if rpm <= 0 {
return nil
}
bucket := s.getBucket(accountID, rpm)
deadline := time.Now().Add(maxWait)
for {
ok, waitDur := bucket.tryAcquire()
if ok {
return nil
}
remaining := time.Until(deadline)
if remaining <= 0 || waitDur > remaining {
return ErrRPMWaitTimeout
}
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(waitDur):
// token may be available now; retry
}
}
}
// tokenBucket is a continuous-refill token bucket for a single account.
type tokenBucket struct {
mu sync.Mutex
tokens float64
maxTokens float64
rateSec float64 // tokens refilled per second = rpm / 60
lastFill time.Time
}
func newTokenBucket(rpm int) *tokenBucket {
max := float64(rpm)
return &tokenBucket{
tokens: max,
maxTokens: max,
rateSec: float64(rpm) / 60.0,
lastFill: time.Now(),
}
}
// tryAcquire refills the bucket based on elapsed time, then attempts to consume one token.
// Returns (true, 0) on success, or (false, waitDur) indicating how long until a token is available.
func (b *tokenBucket) tryAcquire() (bool, time.Duration) {
b.mu.Lock()
defer b.mu.Unlock()
now := time.Now()
elapsed := now.Sub(b.lastFill).Seconds()
b.tokens = math.Min(b.maxTokens, b.tokens+elapsed*b.rateSec)
b.lastFill = now
if b.tokens >= 1.0 {
b.tokens -= 1.0
return true, 0
}
deficit := 1.0 - b.tokens
waitSecs := deficit / b.rateSec
return false, time.Duration(waitSecs * float64(time.Second))
}

View File

@ -0,0 +1,108 @@
package service
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestRPMTokenBucket_ImmediateAcquireWhenFull(t *testing.T) {
svc := NewRPMTokenBucketService()
ctx := context.Background()
// Bucket starts full (rpm=60 tokens). First 60 calls should succeed immediately.
for i := 0; i < 60; i++ {
err := svc.AcquireWithWait(ctx, 1, 60, 0)
require.NoError(t, err, "call %d should succeed immediately", i+1)
}
}
func TestRPMTokenBucket_ZeroRPMAlwaysOK(t *testing.T) {
svc := NewRPMTokenBucketService()
err := svc.AcquireWithWait(context.Background(), 42, 0, 0)
assert.NoError(t, err)
}
func TestRPMTokenBucket_TimeoutWhenExhausted(t *testing.T) {
svc := NewRPMTokenBucketService()
ctx := context.Background()
// rpm=1 → 1 token/minute. One call drains the bucket.
err := svc.AcquireWithWait(ctx, 99, 1, 5*time.Second)
require.NoError(t, err, "first call should succeed")
// Second call: bucket empty, wait time ≈ 60s which exceeds maxWait=50ms.
start := time.Now()
err = svc.AcquireWithWait(ctx, 99, 1, 50*time.Millisecond)
elapsed := time.Since(start)
assert.ErrorIs(t, err, ErrRPMWaitTimeout)
assert.Less(t, elapsed, 200*time.Millisecond, "should timeout quickly, not block")
}
func TestRPMTokenBucket_WaitsAndSucceeds(t *testing.T) {
svc := NewRPMTokenBucketService()
ctx := context.Background()
// rpm=120 → refill rate = 2 tokens/second. Drain the bucket fully.
for i := 0; i < 120; i++ {
require.NoError(t, svc.AcquireWithWait(ctx, 7, 120, 0))
}
// Next call needs to wait ~500ms for the next token. Give it 2s.
start := time.Now()
err := svc.AcquireWithWait(ctx, 7, 120, 2*time.Second)
elapsed := time.Since(start)
require.NoError(t, err, "should succeed after waiting for refill")
assert.Greater(t, elapsed, 100*time.Millisecond, "should have actually waited")
assert.Less(t, elapsed, 1500*time.Millisecond, "should not wait excessively long")
}
func TestRPMTokenBucket_ContextCancellation(t *testing.T) {
svc := NewRPMTokenBucketService()
// rpm=120 → refill = 2 tokens/second → next token in ~500ms after draining.
// maxWait = 2s (longer than 500ms refill wait) so the code blocks in time.After(~500ms).
// Context is cancelled after 30ms, which is shorter than the 500ms wait, so ctx.Done fires first.
for i := 0; i < 120; i++ {
require.NoError(t, svc.AcquireWithWait(context.Background(), 55, 120, 0))
}
ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(30 * time.Millisecond)
cancel()
}()
start := time.Now()
err := svc.AcquireWithWait(ctx, 55, 120, 2*time.Second)
elapsed := time.Since(start)
assert.ErrorIs(t, err, context.Canceled)
assert.Less(t, elapsed, 200*time.Millisecond, "should respect context cancellation promptly")
}
func TestRPMTokenBucket_DifferentAccountsAreIsolated(t *testing.T) {
svc := NewRPMTokenBucketService()
ctx := context.Background()
// Drain account 1 (rpm=1).
require.NoError(t, svc.AcquireWithWait(ctx, 1, 1, 0))
// Account 2 has its own bucket and should succeed immediately.
err := svc.AcquireWithWait(ctx, 2, 1, 0)
assert.NoError(t, err, "different account should have an independent bucket")
}
func TestRPMTokenBucket_RPMChangeReplacesBucket(t *testing.T) {
svc := NewRPMTokenBucketService()
ctx := context.Background()
// Create bucket with rpm=1 and drain it.
require.NoError(t, svc.AcquireWithWait(ctx, 10, 1, 0))
// Bucket now empty with rpm=1.
// Changing RPM to 60 should reset the bucket to full (60 tokens).
err := svc.AcquireWithWait(ctx, 10, 60, 0)
assert.NoError(t, err, "new RPM should cause bucket recreation")
}

View File

@ -424,6 +424,7 @@ var ProviderSet = wire.NewSet(
NewBillingCacheService,
NewAnnouncementService,
NewAdminService,
NewRPMTokenBucketService,
NewGatewayService,
ProvideSoraMediaStorage,
ProvideSoraMediaCleanupService,