feat(rpm): add token bucket smoothing for RPM rate limiting
- New RPMTokenBucketService: per-account continuous-refill token buckets (rate = rpm/60 tokens/sec, capacity = rpm). No new dependencies. - GatewayService.AcquireRPMToken() delegates to the bucket service. - Gateway handler inserts RPM token wait BEFORE wrapReleaseOnDone in both Gemini and Anthropic dispatch paths; timeout returns 429 and releases slot. - Config: gateway.rpm_smoothing.enabled (default false) + max_wait_ms (default 5000). - 7 unit tests covering: immediate acquire, zero RPM, timeout, wait+refill, context cancel, account isolation, bucket reset on RPM change.
This commit is contained in:
parent
5c8c15cdb1
commit
95814974de
@ -180,7 +180,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
||||
claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oauthRefreshAPI)
|
||||
digestSessionStore := service.NewDigestSessionStore()
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService)
|
||||
rpmTokenBucketService := service.NewRPMTokenBucketService()
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, rpmTokenBucketService)
|
||||
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oauthRefreshAPI)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
|
||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||
|
||||
@ -481,6 +481,10 @@ type GatewayConfig struct {
|
||||
// UserMessageQueue: 用户消息串行队列配置
|
||||
// 对 role:"user" 的真实用户消息实施账号级串行化 + RPM 自适应延迟
|
||||
UserMessageQueue UserMessageQueueConfig `mapstructure:"user_message_queue"`
|
||||
|
||||
// RPMSmoothing: RPM 令牌桶平滑配置
|
||||
// 启用后,RPM 配额耗尽时请求等待令牌(最多 MaxWaitMS 毫秒)而非立即返回 429
|
||||
RPMSmoothing RPMSmoothingConfig `mapstructure:"rpm_smoothing"`
|
||||
}
|
||||
|
||||
type GatewayAntigravityLSWorkerConfig struct {
|
||||
@ -535,6 +539,23 @@ func (c *UserMessageQueueConfig) GetEffectiveMode() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// RPMSmoothingConfig RPM 令牌桶平滑配置
|
||||
type RPMSmoothingConfig struct {
|
||||
// Enabled: 是否启用 RPM 令牌桶平滑(默认 false)
|
||||
// 启用后,当账号 RPM 配额耗尽时,请求最多等待 MaxWaitMS 毫秒,而非立即返回 429。
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
// MaxWaitMS: 等待令牌的最大时间(毫秒),超时后返回 429(默认 5000)
|
||||
MaxWaitMS int `mapstructure:"max_wait_ms"`
|
||||
}
|
||||
|
||||
// MaxWait returns the configured wait duration, defaulting to 5s.
|
||||
func (c *RPMSmoothingConfig) MaxWait() time.Duration {
|
||||
if c.MaxWaitMS <= 0 {
|
||||
return 5 * time.Second
|
||||
}
|
||||
return time.Duration(c.MaxWaitMS) * time.Millisecond
|
||||
}
|
||||
|
||||
// GatewayOpenAIWSConfig OpenAI Responses WebSocket 配置。
|
||||
// 注意:默认全局开启;如需回滚可使用 force_http 或关闭 enabled。
|
||||
type GatewayOpenAIWSConfig struct {
|
||||
|
||||
@ -383,6 +383,21 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
reqLog.Warn("gateway.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
// RPM 令牌桶平滑:在让出请求前等待令牌(最多 MaxWaitMS 毫秒)
|
||||
// 必须在 wrapReleaseOnDone 之前执行,以便超时时能安全释放原始槽位。
|
||||
if h.cfg.Gateway.RPMSmoothing.Enabled && account.IsAnthropicOAuthOrSetupToken() && account.GetBaseRPM() > 0 {
|
||||
rpmWaitCtx, rpmCancel := context.WithTimeout(c.Request.Context(), h.cfg.Gateway.RPMSmoothing.MaxWait())
|
||||
rpmErr := h.gatewayService.AcquireRPMToken(rpmWaitCtx, account.ID, account.GetBaseRPM(), h.cfg.Gateway.RPMSmoothing.MaxWait())
|
||||
rpmCancel()
|
||||
if rpmErr != nil {
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "RPM rate limit exceeded, please retry later", streamStarted)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 账号槽位/等待计数需要在超时或断开时安全回收
|
||||
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
|
||||
|
||||
@ -605,6 +620,21 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
reqLog.Warn("gateway.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
// RPM 令牌桶平滑:在让出请求前等待令牌(最多 MaxWaitMS 毫秒)
|
||||
// 必须在 wrapReleaseOnDone 之前执行,以便超时时能安全释放原始槽位。
|
||||
if h.cfg.Gateway.RPMSmoothing.Enabled && account.IsAnthropicOAuthOrSetupToken() && account.GetBaseRPM() > 0 {
|
||||
rpmWaitCtx, rpmCancel := context.WithTimeout(c.Request.Context(), h.cfg.Gateway.RPMSmoothing.MaxWait())
|
||||
rpmErr := h.gatewayService.AcquireRPMToken(rpmWaitCtx, account.ID, account.GetBaseRPM(), h.cfg.Gateway.RPMSmoothing.MaxWait())
|
||||
rpmCancel()
|
||||
if rpmErr != nil {
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "RPM rate limit exceeded, please retry later", streamStarted)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 账号槽位/等待计数需要在超时或断开时安全回收
|
||||
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
|
||||
|
||||
|
||||
@ -558,7 +558,8 @@ type GatewayService struct {
|
||||
concurrencyService *ConcurrencyService
|
||||
claudeTokenProvider *ClaudeTokenProvider
|
||||
sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken)
|
||||
rpmCache RPMCache // RPM 计数缓存(仅 Anthropic OAuth/SetupToken)
|
||||
rpmCache RPMCache // RPM 计数缓存(仅 Anthropic OAuth/SetupToken)
|
||||
rpmTokenBucket *RPMTokenBucketService // RPM 令牌桶平滑(可选,由配置开关控制)
|
||||
userGroupRateResolver *userGroupRateResolver
|
||||
userGroupRateCache *gocache.Cache
|
||||
userGroupRateSF singleflight.Group
|
||||
@ -597,6 +598,7 @@ func NewGatewayService(
|
||||
digestStore *DigestSessionStore,
|
||||
settingService *SettingService,
|
||||
tlsFPProfileService *TLSFingerprintProfileService,
|
||||
rpmTokenBucketSvc *RPMTokenBucketService,
|
||||
) *GatewayService {
|
||||
userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg)
|
||||
modelsListTTL := resolveModelsListCacheTTL(cfg)
|
||||
@ -623,6 +625,7 @@ func NewGatewayService(
|
||||
claudeTokenProvider: claudeTokenProvider,
|
||||
sessionLimitCache: sessionLimitCache,
|
||||
rpmCache: rpmCache,
|
||||
rpmTokenBucket: rpmTokenBucketSvc,
|
||||
userGroupRateCache: gocache.New(userGroupRateTTL, time.Minute),
|
||||
settingService: settingService,
|
||||
modelsListCache: gocache.New(modelsListTTL, time.Minute),
|
||||
@ -2433,6 +2436,15 @@ func (s *GatewayService) IncrementAccountRPM(ctx context.Context, accountID int6
|
||||
return err
|
||||
}
|
||||
|
||||
// AcquireRPMToken consumes one RPM token for the given account, waiting up to maxWait if needed.
|
||||
// Returns nil immediately when RPM smoothing is not configured or the account has no RPM limit.
|
||||
func (s *GatewayService) AcquireRPMToken(ctx context.Context, accountID int64, rpm int, maxWait time.Duration) error {
|
||||
if s.rpmTokenBucket == nil {
|
||||
return nil
|
||||
}
|
||||
return s.rpmTokenBucket.AcquireWithWait(ctx, accountID, rpm, maxWait)
|
||||
}
|
||||
|
||||
// checkAndRegisterSession 检查并注册会话,用于会话数量限制
|
||||
// 仅适用于 Anthropic OAuth/SetupToken 账号
|
||||
// sessionID: 会话标识符(使用粘性会话的 hash)
|
||||
|
||||
120
backend/internal/service/rpm_token_bucket_service.go
Normal file
120
backend/internal/service/rpm_token_bucket_service.go
Normal file
@ -0,0 +1,120 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"math"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ErrRPMWaitTimeout is returned when AcquireWithWait cannot obtain a token within maxWait.
|
||||
var ErrRPMWaitTimeout = errors.New("rpm smoothing: timed out waiting for rate limit slot")
|
||||
|
||||
// RPMTokenBucketService provides per-account token buckets for RPM smoothing.
|
||||
// When an account's RPM budget is exhausted, callers can wait up to a configured
|
||||
// deadline instead of receiving an immediate 429. The bucket refills continuously
|
||||
// at rpm/60 tokens per second so requests are distributed evenly over time.
|
||||
type RPMTokenBucketService struct {
|
||||
buckets sync.Map // map[int64]*rpmEntry
|
||||
}
|
||||
|
||||
// NewRPMTokenBucketService creates a ready-to-use RPMTokenBucketService.
|
||||
func NewRPMTokenBucketService() *RPMTokenBucketService {
|
||||
return &RPMTokenBucketService{}
|
||||
}
|
||||
|
||||
type rpmEntry struct {
|
||||
bucket *tokenBucket
|
||||
rpm int
|
||||
}
|
||||
|
||||
// getBucket returns (or creates) the token bucket for accountID.
|
||||
// If the account's RPM limit has changed since the bucket was created, the bucket is replaced.
|
||||
func (s *RPMTokenBucketService) getBucket(accountID int64, rpm int) *tokenBucket {
|
||||
if v, ok := s.buckets.Load(accountID); ok {
|
||||
e := v.(*rpmEntry)
|
||||
if e.rpm == rpm {
|
||||
return e.bucket
|
||||
}
|
||||
// RPM limit changed — replace with a fresh bucket.
|
||||
fresh := &rpmEntry{rpm: rpm, bucket: newTokenBucket(rpm)}
|
||||
s.buckets.Store(accountID, fresh)
|
||||
return fresh.bucket
|
||||
}
|
||||
entry := &rpmEntry{rpm: rpm, bucket: newTokenBucket(rpm)}
|
||||
actual, _ := s.buckets.LoadOrStore(accountID, entry)
|
||||
return actual.(*rpmEntry).bucket
|
||||
}
|
||||
|
||||
// AcquireWithWait attempts to consume one token for the given account.
|
||||
// It blocks up to maxWait for a token to become available.
|
||||
// Returns nil on success, ErrRPMWaitTimeout if the deadline is exceeded,
|
||||
// or ctx.Err() if the context is cancelled.
|
||||
// If rpm <= 0 the call returns immediately with nil.
|
||||
func (s *RPMTokenBucketService) AcquireWithWait(ctx context.Context, accountID int64, rpm int, maxWait time.Duration) error {
|
||||
if rpm <= 0 {
|
||||
return nil
|
||||
}
|
||||
bucket := s.getBucket(accountID, rpm)
|
||||
deadline := time.Now().Add(maxWait)
|
||||
|
||||
for {
|
||||
ok, waitDur := bucket.tryAcquire()
|
||||
if ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
remaining := time.Until(deadline)
|
||||
if remaining <= 0 || waitDur > remaining {
|
||||
return ErrRPMWaitTimeout
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(waitDur):
|
||||
// token may be available now; retry
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// tokenBucket is a continuous-refill token bucket for a single account.
|
||||
type tokenBucket struct {
|
||||
mu sync.Mutex
|
||||
tokens float64
|
||||
maxTokens float64
|
||||
rateSec float64 // tokens refilled per second = rpm / 60
|
||||
lastFill time.Time
|
||||
}
|
||||
|
||||
func newTokenBucket(rpm int) *tokenBucket {
|
||||
max := float64(rpm)
|
||||
return &tokenBucket{
|
||||
tokens: max,
|
||||
maxTokens: max,
|
||||
rateSec: float64(rpm) / 60.0,
|
||||
lastFill: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// tryAcquire refills the bucket based on elapsed time, then attempts to consume one token.
|
||||
// Returns (true, 0) on success, or (false, waitDur) indicating how long until a token is available.
|
||||
func (b *tokenBucket) tryAcquire() (bool, time.Duration) {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
elapsed := now.Sub(b.lastFill).Seconds()
|
||||
b.tokens = math.Min(b.maxTokens, b.tokens+elapsed*b.rateSec)
|
||||
b.lastFill = now
|
||||
|
||||
if b.tokens >= 1.0 {
|
||||
b.tokens -= 1.0
|
||||
return true, 0
|
||||
}
|
||||
|
||||
deficit := 1.0 - b.tokens
|
||||
waitSecs := deficit / b.rateSec
|
||||
return false, time.Duration(waitSecs * float64(time.Second))
|
||||
}
|
||||
108
backend/internal/service/rpm_token_bucket_service_test.go
Normal file
108
backend/internal/service/rpm_token_bucket_service_test.go
Normal file
@ -0,0 +1,108 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRPMTokenBucket_ImmediateAcquireWhenFull(t *testing.T) {
|
||||
svc := NewRPMTokenBucketService()
|
||||
ctx := context.Background()
|
||||
// Bucket starts full (rpm=60 tokens). First 60 calls should succeed immediately.
|
||||
for i := 0; i < 60; i++ {
|
||||
err := svc.AcquireWithWait(ctx, 1, 60, 0)
|
||||
require.NoError(t, err, "call %d should succeed immediately", i+1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRPMTokenBucket_ZeroRPMAlwaysOK(t *testing.T) {
|
||||
svc := NewRPMTokenBucketService()
|
||||
err := svc.AcquireWithWait(context.Background(), 42, 0, 0)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestRPMTokenBucket_TimeoutWhenExhausted(t *testing.T) {
|
||||
svc := NewRPMTokenBucketService()
|
||||
ctx := context.Background()
|
||||
|
||||
// rpm=1 → 1 token/minute. One call drains the bucket.
|
||||
err := svc.AcquireWithWait(ctx, 99, 1, 5*time.Second)
|
||||
require.NoError(t, err, "first call should succeed")
|
||||
|
||||
// Second call: bucket empty, wait time ≈ 60s which exceeds maxWait=50ms.
|
||||
start := time.Now()
|
||||
err = svc.AcquireWithWait(ctx, 99, 1, 50*time.Millisecond)
|
||||
elapsed := time.Since(start)
|
||||
assert.ErrorIs(t, err, ErrRPMWaitTimeout)
|
||||
assert.Less(t, elapsed, 200*time.Millisecond, "should timeout quickly, not block")
|
||||
}
|
||||
|
||||
func TestRPMTokenBucket_WaitsAndSucceeds(t *testing.T) {
|
||||
svc := NewRPMTokenBucketService()
|
||||
ctx := context.Background()
|
||||
|
||||
// rpm=120 → refill rate = 2 tokens/second. Drain the bucket fully.
|
||||
for i := 0; i < 120; i++ {
|
||||
require.NoError(t, svc.AcquireWithWait(ctx, 7, 120, 0))
|
||||
}
|
||||
|
||||
// Next call needs to wait ~500ms for the next token. Give it 2s.
|
||||
start := time.Now()
|
||||
err := svc.AcquireWithWait(ctx, 7, 120, 2*time.Second)
|
||||
elapsed := time.Since(start)
|
||||
require.NoError(t, err, "should succeed after waiting for refill")
|
||||
assert.Greater(t, elapsed, 100*time.Millisecond, "should have actually waited")
|
||||
assert.Less(t, elapsed, 1500*time.Millisecond, "should not wait excessively long")
|
||||
}
|
||||
|
||||
func TestRPMTokenBucket_ContextCancellation(t *testing.T) {
|
||||
svc := NewRPMTokenBucketService()
|
||||
|
||||
// rpm=120 → refill = 2 tokens/second → next token in ~500ms after draining.
|
||||
// maxWait = 2s (longer than 500ms refill wait) so the code blocks in time.After(~500ms).
|
||||
// Context is cancelled after 30ms, which is shorter than the 500ms wait, so ctx.Done fires first.
|
||||
for i := 0; i < 120; i++ {
|
||||
require.NoError(t, svc.AcquireWithWait(context.Background(), 55, 120, 0))
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
go func() {
|
||||
time.Sleep(30 * time.Millisecond)
|
||||
cancel()
|
||||
}()
|
||||
|
||||
start := time.Now()
|
||||
err := svc.AcquireWithWait(ctx, 55, 120, 2*time.Second)
|
||||
elapsed := time.Since(start)
|
||||
assert.ErrorIs(t, err, context.Canceled)
|
||||
assert.Less(t, elapsed, 200*time.Millisecond, "should respect context cancellation promptly")
|
||||
}
|
||||
|
||||
func TestRPMTokenBucket_DifferentAccountsAreIsolated(t *testing.T) {
|
||||
svc := NewRPMTokenBucketService()
|
||||
ctx := context.Background()
|
||||
|
||||
// Drain account 1 (rpm=1).
|
||||
require.NoError(t, svc.AcquireWithWait(ctx, 1, 1, 0))
|
||||
|
||||
// Account 2 has its own bucket and should succeed immediately.
|
||||
err := svc.AcquireWithWait(ctx, 2, 1, 0)
|
||||
assert.NoError(t, err, "different account should have an independent bucket")
|
||||
}
|
||||
|
||||
func TestRPMTokenBucket_RPMChangeReplacesBucket(t *testing.T) {
|
||||
svc := NewRPMTokenBucketService()
|
||||
ctx := context.Background()
|
||||
|
||||
// Create bucket with rpm=1 and drain it.
|
||||
require.NoError(t, svc.AcquireWithWait(ctx, 10, 1, 0))
|
||||
// Bucket now empty with rpm=1.
|
||||
|
||||
// Changing RPM to 60 should reset the bucket to full (60 tokens).
|
||||
err := svc.AcquireWithWait(ctx, 10, 60, 0)
|
||||
assert.NoError(t, err, "new RPM should cause bucket recreation")
|
||||
}
|
||||
@ -424,6 +424,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewBillingCacheService,
|
||||
NewAnnouncementService,
|
||||
NewAdminService,
|
||||
NewRPMTokenBucketService,
|
||||
NewGatewayService,
|
||||
ProvideSoraMediaStorage,
|
||||
ProvideSoraMediaCleanupService,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user