- New RPMTokenBucketService: per-account continuous-refill token buckets (rate = rpm/60 tokens/sec, capacity = rpm). No new dependencies.
- GatewayService.AcquireRPMToken() delegates to the bucket service.
- Gateway handler inserts RPM token wait BEFORE wrapReleaseOnDone in both Gemini and Anthropic dispatch paths; timeout returns 429 and releases slot.
- Config: gateway.rpm_smoothing.enabled (default false) + max_wait_ms (default 5000).
- 7 unit tests covering: immediate acquire, zero RPM, timeout, wait+refill, context cancel, account isolation, bucket reset on RPM change.
121 lines
3.4 KiB
Go
121 lines
3.4 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"math"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// ErrRPMWaitTimeout is the sentinel error reported by AcquireWithWait when no
// token can be obtained within the caller's maxWait budget. Callers should
// match it with errors.Is.
var (
	ErrRPMWaitTimeout = errors.New("rpm smoothing: timed out waiting for rate limit slot")
)
|
|
|
|
// RPMTokenBucketService smooths per-account request rates with token buckets.
// Every account owns a bucket that refills continuously at rpm/60 tokens per
// second, so a caller whose account has used up its RPM budget can wait a
// bounded amount of time for the next token instead of being rejected with a
// 429 straight away. The zero value is ready to use.
type RPMTokenBucketService struct {
	// buckets maps an account ID (int64) to its *rpmEntry.
	buckets sync.Map
}
|
|
|
|
// NewRPMTokenBucketService creates a ready-to-use RPMTokenBucketService.
|
|
func NewRPMTokenBucketService() *RPMTokenBucketService {
|
|
return &RPMTokenBucketService{}
|
|
}
|
|
|
|
type rpmEntry struct {
|
|
bucket *tokenBucket
|
|
rpm int
|
|
}
|
|
|
|
// getBucket returns (or creates) the token bucket for accountID.
|
|
// If the account's RPM limit has changed since the bucket was created, the bucket is replaced.
|
|
func (s *RPMTokenBucketService) getBucket(accountID int64, rpm int) *tokenBucket {
|
|
if v, ok := s.buckets.Load(accountID); ok {
|
|
e := v.(*rpmEntry)
|
|
if e.rpm == rpm {
|
|
return e.bucket
|
|
}
|
|
// RPM limit changed — replace with a fresh bucket.
|
|
fresh := &rpmEntry{rpm: rpm, bucket: newTokenBucket(rpm)}
|
|
s.buckets.Store(accountID, fresh)
|
|
return fresh.bucket
|
|
}
|
|
entry := &rpmEntry{rpm: rpm, bucket: newTokenBucket(rpm)}
|
|
actual, _ := s.buckets.LoadOrStore(accountID, entry)
|
|
return actual.(*rpmEntry).bucket
|
|
}
|
|
|
|
// AcquireWithWait attempts to consume one token for the given account.
|
|
// It blocks up to maxWait for a token to become available.
|
|
// Returns nil on success, ErrRPMWaitTimeout if the deadline is exceeded,
|
|
// or ctx.Err() if the context is cancelled.
|
|
// If rpm <= 0 the call returns immediately with nil.
|
|
func (s *RPMTokenBucketService) AcquireWithWait(ctx context.Context, accountID int64, rpm int, maxWait time.Duration) error {
|
|
if rpm <= 0 {
|
|
return nil
|
|
}
|
|
bucket := s.getBucket(accountID, rpm)
|
|
deadline := time.Now().Add(maxWait)
|
|
|
|
for {
|
|
ok, waitDur := bucket.tryAcquire()
|
|
if ok {
|
|
return nil
|
|
}
|
|
|
|
remaining := time.Until(deadline)
|
|
if remaining <= 0 || waitDur > remaining {
|
|
return ErrRPMWaitTimeout
|
|
}
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
case <-time.After(waitDur):
|
|
// token may be available now; retry
|
|
}
|
|
}
|
|
}
|
|
|
|
// tokenBucket holds the state of one account's continuously refilling token
// bucket. All fields other than mu are guarded by mu.
type tokenBucket struct {
	mu sync.Mutex

	// tokens is the current (fractional) number of available tokens.
	tokens float64
	// maxTokens is the capacity; refills never push tokens above it.
	maxTokens float64
	// rateSec is the refill rate in tokens per second (rpm / 60).
	rateSec float64
	// lastFill is the instant up to which refill has already been credited.
	lastFill time.Time
}
|
|
|
|
func newTokenBucket(rpm int) *tokenBucket {
|
|
max := float64(rpm)
|
|
return &tokenBucket{
|
|
tokens: max,
|
|
maxTokens: max,
|
|
rateSec: float64(rpm) / 60.0,
|
|
lastFill: time.Now(),
|
|
}
|
|
}
|
|
|
|
// tryAcquire refills the bucket based on elapsed time, then attempts to consume one token.
|
|
// Returns (true, 0) on success, or (false, waitDur) indicating how long until a token is available.
|
|
func (b *tokenBucket) tryAcquire() (bool, time.Duration) {
|
|
b.mu.Lock()
|
|
defer b.mu.Unlock()
|
|
|
|
now := time.Now()
|
|
elapsed := now.Sub(b.lastFill).Seconds()
|
|
b.tokens = math.Min(b.maxTokens, b.tokens+elapsed*b.rateSec)
|
|
b.lastFill = now
|
|
|
|
if b.tokens >= 1.0 {
|
|
b.tokens -= 1.0
|
|
return true, 0
|
|
}
|
|
|
|
deficit := 1.0 - b.tokens
|
|
waitSecs := deficit / b.rateSec
|
|
return false, time.Duration(waitSecs * float64(time.Second))
|
|
}
|