为用户在 anthropic/openai/gemini/antigravity 四个平台上提供日/周/月 三个窗口的 USD 配额管控。配额语义:未设置=不限制,0=禁用,>0=美元上限。 两层模型: - 配置层:系统默认配额,以及 email/linuxdo/oidc/wechat/github/google/ dingtalk 七个鉴权来源的默认配额,存于 settings,以嵌套 JSON 整体读写 (系统 1 个 key + 每个来源 1 个 key),整体替换语义。 - 运行时层:user_platform_quota 表按用户记录实际配额,与配置层解耦。 后端:新增 ent schema 与 140_user_platform_quotas.sql 迁移、repository 与 service 端口、计费链路集成、管理端与用户端读写接口。 前端:管理端设置页配额编辑、用户配额管理 Modal、用户 Dashboard 展示、 中英文案。 Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1281 lines
44 KiB
Go
1281 lines
44 KiB
Go
package service
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"strconv"
|
||
"sync"
|
||
"sync/atomic"
|
||
"time"
|
||
|
||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||
"golang.org/x/sync/singleflight"
|
||
)
|
||
|
||
// Error definitions.
//
// Note: ErrInsufficientBalance is defined in redeem_service.go.
// Note: ErrDailyLimitExceeded/ErrWeeklyLimitExceeded/ErrMonthlyLimitExceeded
// are defined in subscription_service.go.

// errBillingCacheUnavailable is an internal sentinel used by the quota check
// path: when cache == nil it lets the caller take the same branch as a Redis
// failure (fail-open plus a one-off DB check).
var errBillingCacheUnavailable = fmt.Errorf("billing cache unavailable")
|
||
|
||
var (
|
||
ErrSubscriptionInvalid = infraerrors.Forbidden("SUBSCRIPTION_INVALID", "subscription is invalid or expired")
|
||
ErrBillingServiceUnavailable = infraerrors.ServiceUnavailable("BILLING_SERVICE_ERROR", "Billing service temporarily unavailable. Please retry later.")
|
||
// RPM 超限错误。gateway_handler 负责映射为 HTTP 429。
|
||
ErrGroupRPMExceeded = infraerrors.TooManyRequests("GROUP_RPM_EXCEEDED", "group requests-per-minute limit exceeded")
|
||
ErrUserRPMExceeded = infraerrors.TooManyRequests("USER_RPM_EXCEEDED", "user requests-per-minute limit exceeded")
|
||
|
||
// user × platform quota(HTTP 429 Too Many Requests + Retry-After header)。
|
||
// 选用 429 而非 403:限额耗尽属于"暂时性资源用尽,重试可恢复"的场景(RFC 6585),
|
||
// 大量 SDK(如 OpenAI 兼容客户端)只对 429 触发自动退避并读取 Retry-After,
|
||
// 用 403 会被视为"权限不足,重试无意义"导致客户端直接报错且不退避。
|
||
ErrUserPlatformDailyQuotaExhausted = infraerrors.TooManyRequests("USER_PLATFORM_DAILY_QUOTA_EXHAUSTED", "Daily usage quota exhausted for this platform.")
|
||
ErrUserPlatformWeeklyQuotaExhausted = infraerrors.TooManyRequests("USER_PLATFORM_WEEKLY_QUOTA_EXHAUSTED", "Weekly usage quota exhausted for this platform.")
|
||
ErrUserPlatformMonthlyQuotaExhausted = infraerrors.TooManyRequests("USER_PLATFORM_MONTHLY_QUOTA_EXHAUSTED", "Monthly usage quota exhausted for this platform.")
|
||
)
|
||
|
||
// subscriptionCacheData is the service-internal form of a cached subscription
// snapshot. It mirrors the exported SubscriptionCacheData field for field.
type subscriptionCacheData struct {
	Status    string    // subscription status; compared against SubscriptionStatusActive
	ExpiresAt time.Time // instant after which the subscription is treated as invalid

	// Usage accumulated in the daily / weekly / monthly windows; populated
	// from the *UsageUSD columns when loaded from the database.
	DailyUsage, WeeklyUsage, MonthlyUsage float64

	Version int64 // set to the subscription's UpdatedAt (Unix seconds) when loaded from the database
}
|
||
|
||
// cacheWriteKind identifies which cache operation a cacheWriteTask carries.
type cacheWriteKind int

const (
	// cacheWriteSetBalance overwrites the cached balance of a user.
	cacheWriteSetBalance cacheWriteKind = iota
	// cacheWriteSetSubscription stores a subscription snapshot.
	cacheWriteSetSubscription
	// cacheWriteUpdateSubscriptionUsage adds an amount to cached subscription usage.
	cacheWriteUpdateSubscriptionUsage
	// cacheWriteDeductBalance subtracts an amount from the cached balance.
	cacheWriteDeductBalance
	// cacheWriteUpdateRateLimitUsage adds an amount to cached API-key rate-limit usage.
	cacheWriteUpdateRateLimitUsage
)
|
||
|
||
// Configuration of the asynchronous cache-write worker pool.
//
// Performance background:
// the previous implementation updated the cache from a fresh goroutine on the
// request hot path, which had these problems:
//  1. one new goroutine per request, producing many short-lived goroutines
//     under high concurrency;
//  2. no bound on concurrency, which could exhaust Redis connections;
//  3. extra overhead from creating and destroying goroutines.
//
// The current implementation uses a fixed-size worker pool:
//  1. 10 worker goroutines are created up front, avoiding churn;
//  2. a buffered channel (1000) serves as the task queue and smooths write peaks;
//  3. enqueueing never blocks; when the queue is full, critical tasks fall
//     back to a synchronous write and non-critical tasks are dropped with a warning;
//  4. a single timeout applies to every write so slow operations cannot stall the pool.
const (
	// cacheWriteWorkerCount is the number of worker goroutines.
	cacheWriteWorkerCount = 10
	// cacheWriteBufferSize is the capacity of the task queue.
	cacheWriteBufferSize = 1000
	// cacheWriteTimeout bounds a single cache write.
	cacheWriteTimeout = 2 * time.Second
	// cacheWriteDropLogInterval throttles the "task dropped" log line.
	cacheWriteDropLogInterval = 5 * time.Second
	// balanceLoadTimeout bounds the DB load performed on a balance cache miss.
	balanceLoadTimeout = 3 * time.Second
)
|
||
|
||
// cacheWriteTask 缓存写入任务
|
||
type cacheWriteTask struct {
|
||
kind cacheWriteKind
|
||
userID int64
|
||
groupID int64
|
||
apiKeyID int64
|
||
balance float64
|
||
amount float64
|
||
subscriptionData *subscriptionCacheData
|
||
}
|
||
|
||
// apiKeyRateLimitLoader defines the interface for loading rate limit data from DB.
|
||
type apiKeyRateLimitLoader interface {
|
||
GetRateLimitData(ctx context.Context, keyID int64) (*APIKeyRateLimitData, error)
|
||
}
|
||
|
||
// BillingCacheService 计费缓存服务
|
||
// 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查
|
||
type BillingCacheService struct {
|
||
cache BillingCache
|
||
userRepo UserRepository
|
||
subRepo UserSubscriptionRepository
|
||
apiKeyRateLimitLoader apiKeyRateLimitLoader
|
||
userRPMCache UserRPMCache
|
||
userGroupRateRepo UserGroupRateRepository
|
||
cfg *config.Config
|
||
circuitBreaker *billingCircuitBreaker
|
||
userPlatformQuotaRepo UserPlatformQuotaRepository
|
||
|
||
cacheWriteChan chan cacheWriteTask
|
||
cacheWriteWg sync.WaitGroup
|
||
cacheWriteStopOnce sync.Once
|
||
cacheWriteMu sync.RWMutex
|
||
stopped atomic.Bool
|
||
balanceLoadSF singleflight.Group
|
||
quotaLoadSF singleflight.Group
|
||
// 丢弃日志节流计数器(减少高负载下日志噪音)
|
||
cacheWriteDropFullCount uint64
|
||
cacheWriteDropFullLastLog int64
|
||
cacheWriteDropClosedCount uint64
|
||
cacheWriteDropClosedLastLog int64
|
||
}
|
||
|
||
// NewBillingCacheService 创建计费缓存服务
|
||
func NewBillingCacheService(
|
||
cache BillingCache,
|
||
userRepo UserRepository,
|
||
subRepo UserSubscriptionRepository,
|
||
apiKeyRepo APIKeyRepository,
|
||
userRPMCache UserRPMCache,
|
||
userGroupRateRepo UserGroupRateRepository,
|
||
cfg *config.Config,
|
||
userPlatformQuotaRepo UserPlatformQuotaRepository,
|
||
) *BillingCacheService {
|
||
svc := &BillingCacheService{
|
||
cache: cache,
|
||
userRepo: userRepo,
|
||
subRepo: subRepo,
|
||
apiKeyRateLimitLoader: apiKeyRepo,
|
||
userRPMCache: userRPMCache,
|
||
userGroupRateRepo: userGroupRateRepo,
|
||
cfg: cfg,
|
||
userPlatformQuotaRepo: userPlatformQuotaRepo,
|
||
}
|
||
svc.circuitBreaker = newBillingCircuitBreaker(cfg.Billing.CircuitBreaker)
|
||
svc.startCacheWriteWorkers()
|
||
return svc
|
||
}
|
||
|
||
// Stop 关闭缓存写入工作池
|
||
func (s *BillingCacheService) Stop() {
|
||
s.cacheWriteStopOnce.Do(func() {
|
||
s.stopped.Store(true)
|
||
|
||
s.cacheWriteMu.Lock()
|
||
ch := s.cacheWriteChan
|
||
if ch != nil {
|
||
close(ch)
|
||
}
|
||
s.cacheWriteMu.Unlock()
|
||
|
||
if ch == nil {
|
||
return
|
||
}
|
||
s.cacheWriteWg.Wait()
|
||
|
||
s.cacheWriteMu.Lock()
|
||
if s.cacheWriteChan == ch {
|
||
s.cacheWriteChan = nil
|
||
}
|
||
s.cacheWriteMu.Unlock()
|
||
})
|
||
}
|
||
|
||
func (s *BillingCacheService) startCacheWriteWorkers() {
|
||
ch := make(chan cacheWriteTask, cacheWriteBufferSize)
|
||
s.cacheWriteChan = ch
|
||
for i := 0; i < cacheWriteWorkerCount; i++ {
|
||
s.cacheWriteWg.Add(1)
|
||
go s.cacheWriteWorker(ch)
|
||
}
|
||
}
|
||
|
||
// enqueueCacheWrite 尝试将任务入队,队列满时返回 false(并记录告警)。
|
||
func (s *BillingCacheService) enqueueCacheWrite(task cacheWriteTask) (enqueued bool) {
|
||
if s.stopped.Load() {
|
||
s.logCacheWriteDrop(task, "closed")
|
||
return false
|
||
}
|
||
|
||
s.cacheWriteMu.RLock()
|
||
defer s.cacheWriteMu.RUnlock()
|
||
|
||
if s.cacheWriteChan == nil {
|
||
s.logCacheWriteDrop(task, "closed")
|
||
return false
|
||
}
|
||
|
||
select {
|
||
case s.cacheWriteChan <- task:
|
||
return true
|
||
default:
|
||
// 队列满时不阻塞主流程,交由调用方决定是否同步回退。
|
||
s.logCacheWriteDrop(task, "full")
|
||
return false
|
||
}
|
||
}
|
||
|
||
func (s *BillingCacheService) cacheWriteWorker(ch <-chan cacheWriteTask) {
|
||
defer s.cacheWriteWg.Done()
|
||
for task := range ch {
|
||
ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout)
|
||
switch task.kind {
|
||
case cacheWriteSetBalance:
|
||
s.setBalanceCache(ctx, task.userID, task.balance)
|
||
case cacheWriteSetSubscription:
|
||
s.setSubscriptionCache(ctx, task.userID, task.groupID, task.subscriptionData)
|
||
case cacheWriteUpdateSubscriptionUsage:
|
||
if s.cache != nil {
|
||
if err := s.cache.UpdateSubscriptionUsage(ctx, task.userID, task.groupID, task.amount); err != nil {
|
||
logger.LegacyPrintf("service.billing_cache", "Warning: update subscription cache failed for user %d group %d: %v", task.userID, task.groupID, err)
|
||
}
|
||
}
|
||
case cacheWriteDeductBalance:
|
||
if s.cache != nil {
|
||
if err := s.cache.DeductUserBalance(ctx, task.userID, task.amount); err != nil {
|
||
logger.LegacyPrintf("service.billing_cache", "Warning: deduct balance cache failed for user %d: %v", task.userID, err)
|
||
}
|
||
}
|
||
case cacheWriteUpdateRateLimitUsage:
|
||
if s.cache != nil {
|
||
if err := s.cache.UpdateAPIKeyRateLimitUsage(ctx, task.apiKeyID, task.amount); err != nil {
|
||
logger.LegacyPrintf("service.billing_cache", "Warning: update rate limit usage cache failed for api key %d: %v", task.apiKeyID, err)
|
||
}
|
||
}
|
||
}
|
||
cancel()
|
||
}
|
||
}
|
||
|
||
// cacheWriteKindName 用于日志中的任务类型标识,便于排查丢弃原因。
|
||
func cacheWriteKindName(kind cacheWriteKind) string {
|
||
switch kind {
|
||
case cacheWriteSetBalance:
|
||
return "set_balance"
|
||
case cacheWriteSetSubscription:
|
||
return "set_subscription"
|
||
case cacheWriteUpdateSubscriptionUsage:
|
||
return "update_subscription_usage"
|
||
case cacheWriteDeductBalance:
|
||
return "deduct_balance"
|
||
case cacheWriteUpdateRateLimitUsage:
|
||
return "update_rate_limit_usage"
|
||
default:
|
||
return "unknown"
|
||
}
|
||
}
|
||
|
||
// logCacheWriteDrop 使用节流方式记录丢弃情况,并汇总丢弃数量。
|
||
func (s *BillingCacheService) logCacheWriteDrop(task cacheWriteTask, reason string) {
|
||
var (
|
||
countPtr *uint64
|
||
lastPtr *int64
|
||
)
|
||
switch reason {
|
||
case "full":
|
||
countPtr = &s.cacheWriteDropFullCount
|
||
lastPtr = &s.cacheWriteDropFullLastLog
|
||
case "closed":
|
||
countPtr = &s.cacheWriteDropClosedCount
|
||
lastPtr = &s.cacheWriteDropClosedLastLog
|
||
default:
|
||
return
|
||
}
|
||
|
||
atomic.AddUint64(countPtr, 1)
|
||
now := time.Now().UnixNano()
|
||
last := atomic.LoadInt64(lastPtr)
|
||
if now-last < int64(cacheWriteDropLogInterval) {
|
||
return
|
||
}
|
||
if !atomic.CompareAndSwapInt64(lastPtr, last, now) {
|
||
return
|
||
}
|
||
dropped := atomic.SwapUint64(countPtr, 0)
|
||
if dropped == 0 {
|
||
return
|
||
}
|
||
logger.LegacyPrintf("service.billing_cache", "Warning: cache write queue %s, dropped %d tasks in last %s (latest kind=%s user %d group %d)",
|
||
reason,
|
||
dropped,
|
||
cacheWriteDropLogInterval,
|
||
cacheWriteKindName(task.kind),
|
||
task.userID,
|
||
task.groupID,
|
||
)
|
||
}
|
||
|
||
// ============================================
|
||
// 余额缓存方法
|
||
// ============================================
|
||
|
||
// GetUserBalance 获取用户余额(优先从缓存读取)
|
||
func (s *BillingCacheService) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
|
||
if s.cache == nil {
|
||
// Redis不可用,直接查询数据库
|
||
return s.getUserBalanceFromDB(ctx, userID)
|
||
}
|
||
|
||
// 尝试从缓存读取
|
||
balance, err := s.cache.GetUserBalance(ctx, userID)
|
||
if err == nil {
|
||
return balance, nil
|
||
}
|
||
|
||
// 缓存未命中:singleflight 合并同一 userID 的并发回源请求。
|
||
value, err, _ := s.balanceLoadSF.Do(strconv.FormatInt(userID, 10), func() (any, error) {
|
||
loadCtx, cancel := context.WithTimeout(context.Background(), balanceLoadTimeout)
|
||
defer cancel()
|
||
|
||
balance, err := s.getUserBalanceFromDB(loadCtx, userID)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 异步建立缓存
|
||
_ = s.enqueueCacheWrite(cacheWriteTask{
|
||
kind: cacheWriteSetBalance,
|
||
userID: userID,
|
||
balance: balance,
|
||
})
|
||
return balance, nil
|
||
})
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
balance, ok := value.(float64)
|
||
if !ok {
|
||
return 0, fmt.Errorf("unexpected balance type: %T", value)
|
||
}
|
||
return balance, nil
|
||
}
|
||
|
||
// getUserBalanceFromDB 从数据库获取用户余额
|
||
func (s *BillingCacheService) getUserBalanceFromDB(ctx context.Context, userID int64) (float64, error) {
|
||
user, err := s.userRepo.GetByID(ctx, userID)
|
||
if err != nil {
|
||
return 0, fmt.Errorf("get user balance: %w", err)
|
||
}
|
||
return user.Balance, nil
|
||
}
|
||
|
||
// setBalanceCache 设置余额缓存
|
||
func (s *BillingCacheService) setBalanceCache(ctx context.Context, userID int64, balance float64) {
|
||
if s.cache == nil {
|
||
return
|
||
}
|
||
if err := s.cache.SetUserBalance(ctx, userID, balance); err != nil {
|
||
logger.LegacyPrintf("service.billing_cache", "Warning: set balance cache failed for user %d: %v", userID, err)
|
||
}
|
||
}
|
||
|
||
// DeductBalanceCache 扣减余额缓存(同步调用)
|
||
func (s *BillingCacheService) DeductBalanceCache(ctx context.Context, userID int64, amount float64) error {
|
||
if s.cache == nil {
|
||
return nil
|
||
}
|
||
return s.cache.DeductUserBalance(ctx, userID, amount)
|
||
}
|
||
|
||
// QueueDeductBalance 异步扣减余额缓存
|
||
func (s *BillingCacheService) QueueDeductBalance(userID int64, amount float64) {
|
||
if s.cache == nil {
|
||
return
|
||
}
|
||
// 队列满时同步回退,避免关键扣减被静默丢弃。
|
||
if s.enqueueCacheWrite(cacheWriteTask{
|
||
kind: cacheWriteDeductBalance,
|
||
userID: userID,
|
||
amount: amount,
|
||
}) {
|
||
return
|
||
}
|
||
ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout)
|
||
defer cancel()
|
||
if err := s.DeductBalanceCache(ctx, userID, amount); err != nil {
|
||
logger.LegacyPrintf("service.billing_cache", "Warning: deduct balance cache fallback failed for user %d: %v", userID, err)
|
||
}
|
||
}
|
||
|
||
// InvalidateUserBalance 失效用户余额缓存
|
||
func (s *BillingCacheService) InvalidateUserBalance(ctx context.Context, userID int64) error {
|
||
if s.cache == nil {
|
||
return nil
|
||
}
|
||
if err := s.cache.InvalidateUserBalance(ctx, userID); err != nil {
|
||
logger.LegacyPrintf("service.billing_cache", "Warning: invalidate balance cache failed for user %d: %v", userID, err)
|
||
return err
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// ============================================
|
||
// 订阅缓存方法
|
||
// ============================================
|
||
|
||
// GetSubscriptionStatus 获取订阅状态(优先从缓存读取)
|
||
func (s *BillingCacheService) GetSubscriptionStatus(ctx context.Context, userID, groupID int64) (*subscriptionCacheData, error) {
|
||
if s.cache == nil {
|
||
return s.getSubscriptionFromDB(ctx, userID, groupID)
|
||
}
|
||
|
||
// 尝试从缓存读取
|
||
cacheData, err := s.cache.GetSubscriptionCache(ctx, userID, groupID)
|
||
if err == nil && cacheData != nil {
|
||
return s.convertFromPortsData(cacheData), nil
|
||
}
|
||
|
||
// 缓存未命中,从数据库读取
|
||
data, err := s.getSubscriptionFromDB(ctx, userID, groupID)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 异步建立缓存
|
||
_ = s.enqueueCacheWrite(cacheWriteTask{
|
||
kind: cacheWriteSetSubscription,
|
||
userID: userID,
|
||
groupID: groupID,
|
||
subscriptionData: data,
|
||
})
|
||
|
||
return data, nil
|
||
}
|
||
|
||
func (s *BillingCacheService) convertFromPortsData(data *SubscriptionCacheData) *subscriptionCacheData {
|
||
return &subscriptionCacheData{
|
||
Status: data.Status,
|
||
ExpiresAt: data.ExpiresAt,
|
||
DailyUsage: data.DailyUsage,
|
||
WeeklyUsage: data.WeeklyUsage,
|
||
MonthlyUsage: data.MonthlyUsage,
|
||
Version: data.Version,
|
||
}
|
||
}
|
||
|
||
func (s *BillingCacheService) convertToPortsData(data *subscriptionCacheData) *SubscriptionCacheData {
|
||
return &SubscriptionCacheData{
|
||
Status: data.Status,
|
||
ExpiresAt: data.ExpiresAt,
|
||
DailyUsage: data.DailyUsage,
|
||
WeeklyUsage: data.WeeklyUsage,
|
||
MonthlyUsage: data.MonthlyUsage,
|
||
Version: data.Version,
|
||
}
|
||
}
|
||
|
||
// getSubscriptionFromDB 从数据库获取订阅数据
|
||
func (s *BillingCacheService) getSubscriptionFromDB(ctx context.Context, userID, groupID int64) (*subscriptionCacheData, error) {
|
||
sub, err := s.subRepo.GetActiveByUserIDAndGroupID(ctx, userID, groupID)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("get subscription: %w", err)
|
||
}
|
||
|
||
return &subscriptionCacheData{
|
||
Status: sub.Status,
|
||
ExpiresAt: sub.ExpiresAt,
|
||
DailyUsage: sub.DailyUsageUSD,
|
||
WeeklyUsage: sub.WeeklyUsageUSD,
|
||
MonthlyUsage: sub.MonthlyUsageUSD,
|
||
Version: sub.UpdatedAt.Unix(),
|
||
}, nil
|
||
}
|
||
|
||
// setSubscriptionCache 设置订阅缓存
|
||
func (s *BillingCacheService) setSubscriptionCache(ctx context.Context, userID, groupID int64, data *subscriptionCacheData) {
|
||
if s.cache == nil || data == nil {
|
||
return
|
||
}
|
||
if err := s.cache.SetSubscriptionCache(ctx, userID, groupID, s.convertToPortsData(data)); err != nil {
|
||
logger.LegacyPrintf("service.billing_cache", "Warning: set subscription cache failed for user %d group %d: %v", userID, groupID, err)
|
||
}
|
||
}
|
||
|
||
// UpdateSubscriptionUsage 更新订阅用量缓存(同步调用)
|
||
func (s *BillingCacheService) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, costUSD float64) error {
|
||
if s.cache == nil {
|
||
return nil
|
||
}
|
||
return s.cache.UpdateSubscriptionUsage(ctx, userID, groupID, costUSD)
|
||
}
|
||
|
||
// QueueUpdateSubscriptionUsage 异步更新订阅用量缓存
|
||
func (s *BillingCacheService) QueueUpdateSubscriptionUsage(userID, groupID int64, costUSD float64) {
|
||
if s.cache == nil {
|
||
return
|
||
}
|
||
// 队列满时同步回退,确保订阅用量及时更新。
|
||
if s.enqueueCacheWrite(cacheWriteTask{
|
||
kind: cacheWriteUpdateSubscriptionUsage,
|
||
userID: userID,
|
||
groupID: groupID,
|
||
amount: costUSD,
|
||
}) {
|
||
return
|
||
}
|
||
ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout)
|
||
defer cancel()
|
||
if err := s.UpdateSubscriptionUsage(ctx, userID, groupID, costUSD); err != nil {
|
||
logger.LegacyPrintf("service.billing_cache", "Warning: update subscription cache fallback failed for user %d group %d: %v", userID, groupID, err)
|
||
}
|
||
}
|
||
|
||
// InvalidateSubscription 失效指定订阅缓存
|
||
func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID, groupID int64) error {
|
||
if s.cache == nil {
|
||
return nil
|
||
}
|
||
if err := s.cache.InvalidateSubscriptionCache(ctx, userID, groupID); err != nil {
|
||
logger.LegacyPrintf("service.billing_cache", "Warning: invalidate subscription cache failed for user %d group %d: %v", userID, groupID, err)
|
||
return err
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// InvalidateAPIKeyRateLimit invalidates the Redis rate-limit usage cache for an API key.
|
||
func (s *BillingCacheService) InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error {
|
||
if s.cache == nil {
|
||
return nil
|
||
}
|
||
if err := s.cache.InvalidateAPIKeyRateLimit(ctx, keyID); err != nil {
|
||
logger.LegacyPrintf("service.billing_cache", "Warning: invalidate api key rate limit cache failed for key %d: %v", keyID, err)
|
||
return err
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// ============================================
|
||
// API Key 限速缓存方法
|
||
// ============================================
|
||
|
||
// checkAPIKeyRateLimits checks rate limit windows for an API key.
|
||
// It loads usage from Redis cache (falling back to DB on cache miss),
|
||
// resets expired windows in-memory and triggers async DB reset,
|
||
// and returns an error if any window limit is exceeded.
|
||
func (s *BillingCacheService) checkAPIKeyRateLimits(ctx context.Context, apiKey *APIKey) error {
|
||
if s.cache == nil {
|
||
// No cache: fall back to reading from DB directly
|
||
if s.apiKeyRateLimitLoader == nil {
|
||
return nil
|
||
}
|
||
data, err := s.apiKeyRateLimitLoader.GetRateLimitData(ctx, apiKey.ID)
|
||
if err != nil {
|
||
return nil // Don't block requests on DB errors
|
||
}
|
||
return s.evaluateRateLimits(ctx, apiKey, data.Usage5h, data.Usage1d, data.Usage7d,
|
||
data.Window5hStart, data.Window1dStart, data.Window7dStart)
|
||
}
|
||
|
||
cacheData, err := s.cache.GetAPIKeyRateLimit(ctx, apiKey.ID)
|
||
if err != nil {
|
||
// Cache miss: load from DB and populate cache
|
||
if s.apiKeyRateLimitLoader == nil {
|
||
return nil
|
||
}
|
||
dbData, dbErr := s.apiKeyRateLimitLoader.GetRateLimitData(ctx, apiKey.ID)
|
||
if dbErr != nil {
|
||
return nil // Don't block requests on DB errors
|
||
}
|
||
// Build cache entry from DB data
|
||
cacheEntry := &APIKeyRateLimitCacheData{
|
||
Usage5h: dbData.Usage5h,
|
||
Usage1d: dbData.Usage1d,
|
||
Usage7d: dbData.Usage7d,
|
||
}
|
||
if dbData.Window5hStart != nil {
|
||
cacheEntry.Window5h = dbData.Window5hStart.Unix()
|
||
}
|
||
if dbData.Window1dStart != nil {
|
||
cacheEntry.Window1d = dbData.Window1dStart.Unix()
|
||
}
|
||
if dbData.Window7dStart != nil {
|
||
cacheEntry.Window7d = dbData.Window7dStart.Unix()
|
||
}
|
||
_ = s.cache.SetAPIKeyRateLimit(ctx, apiKey.ID, cacheEntry)
|
||
cacheData = cacheEntry
|
||
}
|
||
|
||
var w5h, w1d, w7d *time.Time
|
||
if cacheData.Window5h > 0 {
|
||
t := time.Unix(cacheData.Window5h, 0)
|
||
w5h = &t
|
||
}
|
||
if cacheData.Window1d > 0 {
|
||
t := time.Unix(cacheData.Window1d, 0)
|
||
w1d = &t
|
||
}
|
||
if cacheData.Window7d > 0 {
|
||
t := time.Unix(cacheData.Window7d, 0)
|
||
w7d = &t
|
||
}
|
||
return s.evaluateRateLimits(ctx, apiKey, cacheData.Usage5h, cacheData.Usage1d, cacheData.Usage7d, w5h, w1d, w7d)
|
||
}
|
||
|
||
// evaluateRateLimits checks usage against limits, triggering async resets for expired windows.
|
||
func (s *BillingCacheService) evaluateRateLimits(ctx context.Context, apiKey *APIKey, usage5h, usage1d, usage7d float64, w5h, w1d, w7d *time.Time) error {
|
||
needsReset := false
|
||
|
||
// Reset expired windows in-memory for check purposes
|
||
if IsWindowExpired(w5h, RateLimitWindow5h) {
|
||
usage5h = 0
|
||
needsReset = true
|
||
}
|
||
if IsWindowExpired(w1d, RateLimitWindow1d) {
|
||
usage1d = 0
|
||
needsReset = true
|
||
}
|
||
if IsWindowExpired(w7d, RateLimitWindow7d) {
|
||
usage7d = 0
|
||
needsReset = true
|
||
}
|
||
|
||
// Trigger async DB reset if any window expired
|
||
if needsReset {
|
||
keyID := apiKey.ID
|
||
go func() {
|
||
resetCtx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout)
|
||
defer cancel()
|
||
if s.apiKeyRateLimitLoader != nil {
|
||
// Use the repo directly - reset then reload cache
|
||
if loader, ok := s.apiKeyRateLimitLoader.(interface {
|
||
ResetRateLimitWindows(ctx context.Context, id int64) error
|
||
}); ok {
|
||
if err := loader.ResetRateLimitWindows(resetCtx, keyID); err != nil {
|
||
logger.LegacyPrintf("service.billing_cache", "Warning: reset rate limit windows failed for api key %d: %v", keyID, err)
|
||
}
|
||
}
|
||
}
|
||
// Invalidate cache so next request loads fresh data
|
||
if s.cache != nil {
|
||
if err := s.cache.InvalidateAPIKeyRateLimit(resetCtx, keyID); err != nil {
|
||
logger.LegacyPrintf("service.billing_cache", "Warning: invalidate rate limit cache failed for api key %d: %v", keyID, err)
|
||
}
|
||
}
|
||
}()
|
||
}
|
||
|
||
// Check limits
|
||
if apiKey.RateLimit5h > 0 && usage5h >= apiKey.RateLimit5h {
|
||
return ErrAPIKeyRateLimit5hExceeded
|
||
}
|
||
if apiKey.RateLimit1d > 0 && usage1d >= apiKey.RateLimit1d {
|
||
return ErrAPIKeyRateLimit1dExceeded
|
||
}
|
||
if apiKey.RateLimit7d > 0 && usage7d >= apiKey.RateLimit7d {
|
||
return ErrAPIKeyRateLimit7dExceeded
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// QueueUpdateAPIKeyRateLimitUsage asynchronously updates rate limit usage in the cache.
|
||
func (s *BillingCacheService) QueueUpdateAPIKeyRateLimitUsage(apiKeyID int64, cost float64) {
|
||
if s.cache == nil {
|
||
return
|
||
}
|
||
s.enqueueCacheWrite(cacheWriteTask{
|
||
kind: cacheWriteUpdateRateLimitUsage,
|
||
apiKeyID: apiKeyID,
|
||
amount: cost,
|
||
})
|
||
}
|
||
|
||
// IncrementUserPlatformQuotaUsage 同步累加 user × platform usage 到 Redis 缓存。
|
||
//
|
||
// 设计:同步写入而非异步入队。同步写确保下次 preflight 立即看到最新 usage,
|
||
// 把 TOCTOU 超支窗口限制在并发 in-flight 请求数量内(而非随时间无限累积)。
|
||
// 写延迟通常 < 1ms(本地 Redis),换取 quota 视图实时性的取舍合理。
|
||
//
|
||
// Redis 写失败用 ALERT 级 log;DB 持久化由 caller 单独 goroutine 兜底(gateway_service.go)。
|
||
func (s *BillingCacheService) IncrementUserPlatformQuotaUsage(userID int64, platform string, cost float64) {
|
||
if s.cache == nil {
|
||
return
|
||
}
|
||
if platform == "" || cost <= 0 {
|
||
return
|
||
}
|
||
ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout)
|
||
defer cancel()
|
||
ttl := time.Duration(s.cfg.Billing.UserPlatformQuotaCacheTTLSeconds) * time.Second
|
||
if err := s.cache.IncrUserPlatformQuotaUsageCache(ctx, userID, platform, cost, ttl); err != nil {
|
||
logger.LegacyPrintf("service.billing_cache",
|
||
"ALERT: incr user platform quota cache failed user=%d platform=%s cost=%f: %v",
|
||
userID, platform, cost, err)
|
||
}
|
||
}
|
||
|
||
// ============================================
|
||
// 统一检查方法
|
||
// ============================================
|
||
|
||
// CheckBillingEligibility 检查用户是否有资格发起请求
|
||
// 余额模式:检查缓存余额 > 0
|
||
// 订阅模式:检查缓存用量未超过限额(Group限额从参数传入)
|
||
// platform 为请求的目标平台(如 "anthropic"),传空串 "" 时跳过 user × platform quota 检查。
|
||
func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *User, apiKey *APIKey, group *Group, subscription *UserSubscription, platform string) error {
|
||
// 简易模式:跳过所有计费检查
|
||
if s.cfg.RunMode == config.RunModeSimple {
|
||
return nil
|
||
}
|
||
if s.circuitBreaker != nil && !s.circuitBreaker.Allow() {
|
||
return ErrBillingServiceUnavailable
|
||
}
|
||
|
||
// 判断计费模式
|
||
isSubscriptionMode := group != nil && group.IsSubscriptionType() && subscription != nil
|
||
|
||
if isSubscriptionMode {
|
||
if err := s.checkSubscriptionEligibility(ctx, user.ID, group, subscription); err != nil {
|
||
return err
|
||
}
|
||
} else {
|
||
if err := s.checkBalanceEligibility(ctx, user.ID); err != nil {
|
||
return err
|
||
}
|
||
}
|
||
|
||
// user × platform quota 仅在 standard(余额)模式生效;订阅模式豁免
|
||
if !isSubscriptionMode {
|
||
if err := s.checkUserPlatformQuotaEligibility(ctx, user.ID, platform); err != nil {
|
||
return err
|
||
}
|
||
}
|
||
|
||
// Check API Key rate limits (applies to both billing modes)
|
||
if apiKey != nil && apiKey.HasRateLimits() {
|
||
if err := s.checkAPIKeyRateLimits(ctx, apiKey); err != nil {
|
||
return err
|
||
}
|
||
}
|
||
|
||
// RPM 限流:级联回落(Override → Group → User),放在最后以避免为注定失败的请求增加计数。
|
||
if err := s.checkRPM(ctx, user, group); err != nil {
|
||
return err
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// checkRPM 执行并行 RPM 限流,所有适用的限制同时生效,任一超限即拒绝:
|
||
//
|
||
// 1. (用户, 分组) rpm_override — 最细粒度:管理员为特定用户在特定分组设定的专属限额。
|
||
// override=0 表示该用户在该分组免检(绿灯),但 user 级全局上限仍然生效。
|
||
// 2. group.rpm_limit — 分组级:该分组的统一 RPM 容量(仅当无 override 时生效)。
|
||
// 3. user.rpm_limit — 用户级全局硬上限:无论 override/group 如何配置,始终生效。
|
||
//
|
||
// 与旧版"级联互斥"设计不同,新版确保 user.rpm_limit 作为全局天花板不会被 group 或 override 覆盖。
|
||
// Redis 故障一律 fail-open(打 warning,不阻塞业务)。
|
||
func (s *BillingCacheService) checkRPM(ctx context.Context, user *User, group *Group) error {
|
||
if s == nil || s.userRPMCache == nil || user == nil {
|
||
return nil
|
||
}
|
||
|
||
// ── 第一层:分组级检查(override 或 group.rpm_limit) ──
|
||
if group != nil {
|
||
// 解析 override:优先从 auth cache snapshot,nil 时回退 DB。
|
||
var override *int
|
||
if user.UserGroupRPMOverride != nil {
|
||
override = user.UserGroupRPMOverride
|
||
} else if s.userGroupRateRepo != nil {
|
||
dbOverride, err := s.userGroupRateRepo.GetRPMOverrideByUserAndGroup(ctx, user.ID, group.ID)
|
||
if err != nil {
|
||
logger.LegacyPrintf(
|
||
"service.billing_cache",
|
||
"Warning: rpm override lookup failed for user=%d group=%d: %v",
|
||
user.ID, group.ID, err,
|
||
)
|
||
} else {
|
||
override = dbOverride
|
||
}
|
||
}
|
||
|
||
if override != nil {
|
||
// override=0 → 该用户在该分组免检(但 user 级仍会在下面检查)。
|
||
if *override > 0 {
|
||
count, incErr := s.userRPMCache.IncrementUserGroupRPM(ctx, user.ID, group.ID)
|
||
if incErr != nil {
|
||
logger.LegacyPrintf(
|
||
"service.billing_cache",
|
||
"Warning: rpm increment (override) failed for user=%d group=%d: %v",
|
||
user.ID, group.ID, incErr,
|
||
)
|
||
// fail-open
|
||
} else if count > *override {
|
||
return ErrGroupRPMExceeded
|
||
}
|
||
}
|
||
// override 命中后跳过 group.rpm_limit(override 替代 group),但不 return——继续检查 user 级。
|
||
} else if group.RPMLimit > 0 {
|
||
// 无 override,检查 group.rpm_limit。
|
||
count, err := s.userRPMCache.IncrementUserGroupRPM(ctx, user.ID, group.ID)
|
||
if err != nil {
|
||
logger.LegacyPrintf(
|
||
"service.billing_cache",
|
||
"Warning: rpm increment (group) failed for user=%d group=%d: %v",
|
||
user.ID, group.ID, err,
|
||
)
|
||
// fail-open
|
||
} else if count > group.RPMLimit {
|
||
return ErrGroupRPMExceeded
|
||
}
|
||
}
|
||
}
|
||
|
||
// ── 第二层:用户级全局硬上限(始终生效) ──
|
||
if user.RPMLimit > 0 {
|
||
count, err := s.userRPMCache.IncrementUserRPM(ctx, user.ID)
|
||
if err != nil {
|
||
logger.LegacyPrintf(
|
||
"service.billing_cache",
|
||
"Warning: rpm increment (user) failed for user=%d: %v",
|
||
user.ID, err,
|
||
)
|
||
return nil // fail-open
|
||
}
|
||
if count > user.RPMLimit {
|
||
return ErrUserRPMExceeded
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// checkBalanceEligibility 检查余额模式资格
|
||
func (s *BillingCacheService) checkBalanceEligibility(ctx context.Context, userID int64) error {
|
||
balance, err := s.GetUserBalance(ctx, userID)
|
||
if err != nil {
|
||
if s.circuitBreaker != nil {
|
||
s.circuitBreaker.OnFailure(err)
|
||
}
|
||
logger.LegacyPrintf("service.billing_cache", "ALERT: billing balance check failed for user %d: %v", userID, err)
|
||
return ErrBillingServiceUnavailable.WithCause(err)
|
||
}
|
||
if s.circuitBreaker != nil {
|
||
s.circuitBreaker.OnSuccess()
|
||
}
|
||
|
||
if balance <= 0 {
|
||
return ErrInsufficientBalance
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// checkSubscriptionEligibility 检查订阅模式资格
|
||
func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context, userID int64, group *Group, subscription *UserSubscription) error {
|
||
// 获取订阅缓存数据
|
||
subData, err := s.GetSubscriptionStatus(ctx, userID, group.ID)
|
||
if err != nil {
|
||
if s.circuitBreaker != nil {
|
||
s.circuitBreaker.OnFailure(err)
|
||
}
|
||
logger.LegacyPrintf("service.billing_cache", "ALERT: billing subscription check failed for user %d group %d: %v", userID, group.ID, err)
|
||
return ErrBillingServiceUnavailable.WithCause(err)
|
||
}
|
||
if s.circuitBreaker != nil {
|
||
s.circuitBreaker.OnSuccess()
|
||
}
|
||
|
||
// 检查订阅状态
|
||
if subData.Status != SubscriptionStatusActive {
|
||
return ErrSubscriptionInvalid
|
||
}
|
||
|
||
// 检查是否过期
|
||
if time.Now().After(subData.ExpiresAt) {
|
||
return ErrSubscriptionInvalid
|
||
}
|
||
|
||
// 检查限额(使用传入的Group限额配置)
|
||
if group.HasDailyLimit() && subData.DailyUsage >= *group.DailyLimitUSD {
|
||
return ErrDailyLimitExceeded
|
||
}
|
||
|
||
if group.HasWeeklyLimit() && subData.WeeklyUsage >= *group.WeeklyLimitUSD {
|
||
return ErrWeeklyLimitExceeded
|
||
}
|
||
|
||
if group.HasMonthlyLimit() && subData.MonthlyUsage >= *group.MonthlyLimitUSD {
|
||
return ErrMonthlyLimitExceeded
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// billingCircuitBreakerState is the state of the billing circuit breaker.
type billingCircuitBreakerState int

const (
	// billingCircuitClosed: every request is allowed.
	billingCircuitClosed billingCircuitBreakerState = iota
	// billingCircuitOpen: requests are rejected until resetTimeout has elapsed.
	billingCircuitOpen
	// billingCircuitHalfOpen: a limited number of probe requests is allowed.
	billingCircuitHalfOpen
)
|
||
|
||
type billingCircuitBreaker struct {
|
||
mu sync.Mutex
|
||
state billingCircuitBreakerState
|
||
failures int
|
||
openedAt time.Time
|
||
failureThreshold int
|
||
resetTimeout time.Duration
|
||
halfOpenRequests int
|
||
halfOpenRemaining int
|
||
}
|
||
|
||
func newBillingCircuitBreaker(cfg config.CircuitBreakerConfig) *billingCircuitBreaker {
|
||
if !cfg.Enabled {
|
||
return nil
|
||
}
|
||
resetTimeout := time.Duration(cfg.ResetTimeoutSeconds) * time.Second
|
||
if resetTimeout <= 0 {
|
||
resetTimeout = 30 * time.Second
|
||
}
|
||
halfOpen := cfg.HalfOpenRequests
|
||
if halfOpen <= 0 {
|
||
halfOpen = 1
|
||
}
|
||
threshold := cfg.FailureThreshold
|
||
if threshold <= 0 {
|
||
threshold = 5
|
||
}
|
||
return &billingCircuitBreaker{
|
||
state: billingCircuitClosed,
|
||
failureThreshold: threshold,
|
||
resetTimeout: resetTimeout,
|
||
halfOpenRequests: halfOpen,
|
||
}
|
||
}
|
||
|
||
func (b *billingCircuitBreaker) Allow() bool {
|
||
b.mu.Lock()
|
||
defer b.mu.Unlock()
|
||
|
||
switch b.state {
|
||
case billingCircuitClosed:
|
||
return true
|
||
case billingCircuitOpen:
|
||
if time.Since(b.openedAt) < b.resetTimeout {
|
||
return false
|
||
}
|
||
b.state = billingCircuitHalfOpen
|
||
b.halfOpenRemaining = b.halfOpenRequests
|
||
logger.LegacyPrintf("service.billing_cache", "ALERT: billing circuit breaker entering half-open state")
|
||
fallthrough
|
||
case billingCircuitHalfOpen:
|
||
if b.halfOpenRemaining <= 0 {
|
||
return false
|
||
}
|
||
b.halfOpenRemaining--
|
||
return true
|
||
default:
|
||
return false
|
||
}
|
||
}
|
||
|
||
func (b *billingCircuitBreaker) OnFailure(err error) {
|
||
if b == nil {
|
||
return
|
||
}
|
||
b.mu.Lock()
|
||
defer b.mu.Unlock()
|
||
|
||
switch b.state {
|
||
case billingCircuitOpen:
|
||
return
|
||
case billingCircuitHalfOpen:
|
||
b.state = billingCircuitOpen
|
||
b.openedAt = time.Now()
|
||
b.halfOpenRemaining = 0
|
||
logger.LegacyPrintf("service.billing_cache", "ALERT: billing circuit breaker opened after half-open failure: %v", err)
|
||
return
|
||
default:
|
||
b.failures++
|
||
if b.failures >= b.failureThreshold {
|
||
b.state = billingCircuitOpen
|
||
b.openedAt = time.Now()
|
||
b.halfOpenRemaining = 0
|
||
logger.LegacyPrintf("service.billing_cache", "ALERT: billing circuit breaker opened after %d failures: %v", b.failures, err)
|
||
}
|
||
}
|
||
}
|
||
|
||
func (b *billingCircuitBreaker) OnSuccess() {
|
||
if b == nil {
|
||
return
|
||
}
|
||
b.mu.Lock()
|
||
defer b.mu.Unlock()
|
||
|
||
previousState := b.state
|
||
previousFailures := b.failures
|
||
|
||
b.state = billingCircuitClosed
|
||
b.failures = 0
|
||
b.halfOpenRemaining = 0
|
||
|
||
// 只有状态真正发生变化时才记录日志
|
||
if previousState != billingCircuitClosed {
|
||
logger.LegacyPrintf("service.billing_cache", "ALERT: billing circuit breaker closed (was %s)", circuitStateString(previousState))
|
||
} else if previousFailures > 0 {
|
||
logger.LegacyPrintf("service.billing_cache", "INFO: billing circuit breaker failures reset from %d", previousFailures)
|
||
}
|
||
}
|
||
|
||
func circuitStateString(state billingCircuitBreakerState) string {
|
||
switch state {
|
||
case billingCircuitClosed:
|
||
return "closed"
|
||
case billingCircuitOpen:
|
||
return "open"
|
||
case billingCircuitHalfOpen:
|
||
return "half-open"
|
||
default:
|
||
return "unknown"
|
||
}
|
||
}
|
||
|
||
// checkUserPlatformQuotaEligibility 在 standard 模式下检查 user × platform 日/周/月 quota。
|
||
// 返回 nil = 允许;返回 ErrUserPlatform{Daily/Weekly/Monthly}QuotaExhausted = 拒绝(带 window_resets_at metadata)。
|
||
// checkUserPlatformQuotaEligibility 检查用户在指定平台的 USD 配额。
|
||
//
|
||
// 流程(Redis-first / DB-fallback):
|
||
// 1. 先读 Redis cache;若命中且 SchemaVersion==1,直接用 entry 中的 limits 和 window_start 做校验,
|
||
// 免除 DB 查询。
|
||
// 2. cache MISS 或旧版 entry(SchemaVersion==0)→ 查 DB 回填完整 entry(含 limits/window_start)。
|
||
// 3. Redis 故障(err != nil)→ fail-open,查 DB 做一次性检查,不回填。
|
||
func (s *BillingCacheService) checkUserPlatformQuotaEligibility(
|
||
ctx context.Context,
|
||
userID int64,
|
||
platform string,
|
||
) error {
|
||
if platform == "" || s.userPlatformQuotaRepo == nil {
|
||
return nil
|
||
}
|
||
|
||
// cache 未配置(如简化部署 / 单测路径)→ 直接走 DB 查询,避免 nil panic。
|
||
// 其他 check* 方法(balance/subscription/rate-limit)也有类似守卫。
|
||
var (
|
||
entry *UserPlatformQuotaCacheEntry
|
||
ok bool
|
||
cacheErr error
|
||
)
|
||
if s.cache != nil {
|
||
entry, ok, cacheErr = s.cache.GetUserPlatformQuotaCache(ctx, userID, platform)
|
||
} else {
|
||
// 标记为"cache 故障"分支:跳过 HIT 路径、不回填、走 DB 一次性检查
|
||
cacheErr = errBillingCacheUnavailable
|
||
}
|
||
|
||
// --- cache HIT with current schema → 直接用 entry,不查 DB ---
|
||
if cacheErr == nil && ok && entry != nil && entry.SchemaVersion == UserPlatformQuotaCacheSchemaV1 {
|
||
now := time.Now()
|
||
dailyUsage := entry.DailyUsageUSD
|
||
weeklyUsage := entry.WeeklyUsageUSD
|
||
monthlyUsage := entry.MonthlyUsageUSD
|
||
// 若窗口已更新(DB 已重置但 cache 尚未失效),将对应 usage 清零再做比较,
|
||
// 同时记录新窗口起点用于后续刷新 cache entry。
|
||
// 本次请求用本地清零值继续判断;DB 层 IncrementUsageWithReset 已有窗口自愈能力,
|
||
// 持久化数据始终正确。
|
||
windowExpired := false
|
||
newDailyStart := entry.DailyWindowStart
|
||
newWeeklyStart := entry.WeeklyWindowStart
|
||
newMonthlyStart := entry.MonthlyWindowStart
|
||
if quotaWindowExpired(entry.DailyWindowStart, timezone.StartOfDay(now)) {
|
||
dailyUsage = 0
|
||
windowExpired = true
|
||
dayStart := timezone.StartOfDay(now)
|
||
newDailyStart = &dayStart
|
||
}
|
||
if quotaWindowExpired(entry.WeeklyWindowStart, timezone.StartOfWeek(now)) {
|
||
weeklyUsage = 0
|
||
windowExpired = true
|
||
weekStart := timezone.StartOfWeek(now)
|
||
newWeeklyStart = &weekStart
|
||
}
|
||
if monthlyQuotaWindowExpired(entry.MonthlyWindowStart, now) {
|
||
monthlyUsage = 0
|
||
windowExpired = true
|
||
monthStart := now
|
||
newMonthlyStart = &monthStart
|
||
}
|
||
// 检测到任意窗口过期:用 reset 后的 entry 覆盖 Redis(而非 Delete)。
|
||
// 旧实现 Delete 后,期间到达的 IncrUserPlatformQuotaUsage 调用让 Lua 看到
|
||
// EXISTS=0 直接 return 0,并发请求的 cost 永久丢失,直到下次 cache MISS 回填。
|
||
// 改为 SetCache 原子覆盖:key 不断链,Lua INCR 可在新窗口 entry 上正确累加。
|
||
// 超时 50ms:覆盖正常路径与可接受抖动;Redis 异常时 hot path 不阻塞超过此值。
|
||
// 用 context.Background()+短超时,避免请求 ctx 取消导致刷新丢失。
|
||
// 显式 setCancel()(而非 defer):缩短 context 生命周期,避免 defer 延迟到函数返回。
|
||
if windowExpired && s.cache != nil {
|
||
refreshed := &UserPlatformQuotaCacheEntry{
|
||
DailyUsageUSD: dailyUsage,
|
||
WeeklyUsageUSD: weeklyUsage,
|
||
MonthlyUsageUSD: monthlyUsage,
|
||
SchemaVersion: UserPlatformQuotaCacheSchemaV1,
|
||
DailyLimitUSD: entry.DailyLimitUSD,
|
||
WeeklyLimitUSD: entry.WeeklyLimitUSD,
|
||
MonthlyLimitUSD: entry.MonthlyLimitUSD,
|
||
DailyWindowStart: newDailyStart,
|
||
WeeklyWindowStart: newWeeklyStart,
|
||
MonthlyWindowStart: newMonthlyStart,
|
||
}
|
||
ttl := time.Duration(s.cfg.Billing.UserPlatformQuotaCacheTTLSeconds) * time.Second
|
||
setCtx, setCancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||
if setErr := s.cache.SetUserPlatformQuotaCache(setCtx, userID, platform, refreshed, ttl); setErr != nil {
|
||
logger.LegacyPrintf("service.billing_cache",
|
||
"Warning: refresh expired user platform quota cache failed user=%d platform=%s: %v",
|
||
userID, platform, setErr)
|
||
}
|
||
setCancel()
|
||
}
|
||
if entry.DailyLimitUSD != nil && dailyUsage >= *entry.DailyLimitUSD {
|
||
return withWindowResetsMetadata(ErrUserPlatformDailyQuotaExhausted, nextDailyReset(now))
|
||
}
|
||
if entry.WeeklyLimitUSD != nil && weeklyUsage >= *entry.WeeklyLimitUSD {
|
||
return withWindowResetsMetadata(ErrUserPlatformWeeklyQuotaExhausted, nextWeeklyReset(now))
|
||
}
|
||
if entry.MonthlyLimitUSD != nil && monthlyUsage >= *entry.MonthlyLimitUSD {
|
||
return withWindowResetsMetadata(ErrUserPlatformMonthlyQuotaExhausted, nextMonthlyResetFrom(entry.MonthlyWindowStart, now))
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// --- cache MISS、旧版 entry 或 Redis 故障 → 查 DB(singleflight 合并并发回源)---
|
||
// 使用 DoChan 而非 Do:avoid sharing the first caller's ctx among all dedupe followers.
|
||
// 若第一个 caller 的 ctx 被取消(客户端断连),后续 caller 不受影响,仍由各自 ctx 控制超时。
|
||
sfKey := strconv.FormatInt(userID, 10) + ":" + platform
|
||
ch := s.quotaLoadSF.DoChan(sfKey, func() (any, error) {
|
||
// 子查询用 detached context + 短超时,独立于任何 caller 的请求 ctx,
|
||
// 防止"第一个 caller ctx 取消"使所有 follower 一起 fail。
|
||
bgCtx, bgCancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||
defer bgCancel()
|
||
return s.userPlatformQuotaRepo.GetByUserPlatform(bgCtx, userID, platform)
|
||
})
|
||
var (
|
||
v any
|
||
dbErr error
|
||
)
|
||
select {
|
||
case res := <-ch:
|
||
v, dbErr = res.Val, res.Err
|
||
case <-ctx.Done():
|
||
// 当前 caller 的 ctx 被取消:fail-open,不阻断 (此请求已无意义)。
|
||
logger.LegacyPrintf("service.billing_cache", "Warning: user platform quota check ctx cancelled user=%d platform=%s: %v (fail-open)", userID, platform, ctx.Err())
|
||
return nil
|
||
}
|
||
if dbErr != nil {
|
||
logger.LegacyPrintf("service.billing_cache", "Warning: load user platform quota failed user=%d platform=%s: %v (fail-open)", userID, platform, dbErr)
|
||
return nil
|
||
}
|
||
rec, _ := v.(*UserPlatformQuotaRecord)
|
||
if rec == nil {
|
||
return nil
|
||
}
|
||
|
||
now := time.Now()
|
||
dailyUsage := rec.DailyUsageUSD
|
||
weeklyUsage := rec.WeeklyUsageUSD
|
||
monthlyUsage := rec.MonthlyUsageUSD
|
||
if quotaWindowExpired(rec.DailyWindowStart, timezone.StartOfDay(now)) {
|
||
dailyUsage = 0
|
||
}
|
||
if quotaWindowExpired(rec.WeeklyWindowStart, timezone.StartOfWeek(now)) {
|
||
weeklyUsage = 0
|
||
}
|
||
if monthlyQuotaWindowExpired(rec.MonthlyWindowStart, now) {
|
||
monthlyUsage = 0
|
||
}
|
||
|
||
// Redis 故障时 fail-open:不回填,直接用 DB 数据做一次性检查
|
||
if cacheErr != nil {
|
||
if rec.DailyLimitUSD != nil && dailyUsage >= *rec.DailyLimitUSD {
|
||
return withWindowResetsMetadata(ErrUserPlatformDailyQuotaExhausted, nextDailyReset(now))
|
||
}
|
||
if rec.WeeklyLimitUSD != nil && weeklyUsage >= *rec.WeeklyLimitUSD {
|
||
return withWindowResetsMetadata(ErrUserPlatformWeeklyQuotaExhausted, nextWeeklyReset(now))
|
||
}
|
||
if rec.MonthlyLimitUSD != nil && monthlyUsage >= *rec.MonthlyLimitUSD {
|
||
return withWindowResetsMetadata(ErrUserPlatformMonthlyQuotaExhausted, nextMonthlyResetFrom(rec.MonthlyWindowStart, now))
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// cache MISS 或旧版 entry → 回填完整 entry(含 limits 和 window_start)
|
||
newEntry := &UserPlatformQuotaCacheEntry{
|
||
DailyUsageUSD: dailyUsage,
|
||
WeeklyUsageUSD: weeklyUsage,
|
||
MonthlyUsageUSD: monthlyUsage,
|
||
SchemaVersion: UserPlatformQuotaCacheSchemaV1,
|
||
DailyLimitUSD: rec.DailyLimitUSD,
|
||
WeeklyLimitUSD: rec.WeeklyLimitUSD,
|
||
MonthlyLimitUSD: rec.MonthlyLimitUSD,
|
||
DailyWindowStart: rec.DailyWindowStart,
|
||
WeeklyWindowStart: rec.WeeklyWindowStart,
|
||
MonthlyWindowStart: rec.MonthlyWindowStart,
|
||
}
|
||
if s.cache != nil {
|
||
ttl := time.Duration(s.cfg.Billing.UserPlatformQuotaCacheTTLSeconds) * time.Second
|
||
// 与 HIT 过期回填路径(上文 SetCache 调用)保持一致:用 context.Background()+50ms,
|
||
// 避免请求 ctx 提前取消(客户端断连/上游超时)导致 cache 回填失败,
|
||
// 让下一次 preflight 仍然 MISS 并击穿到 DB(高并发下增大 DB 压力)。
|
||
setCtx, setCancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||
if setErr := s.cache.SetUserPlatformQuotaCache(setCtx, userID, platform, newEntry, ttl); setErr != nil {
|
||
logger.LegacyPrintf("service.billing_cache", "Warning: set user platform quota cache failed user=%d platform=%s: %v", userID, platform, setErr)
|
||
}
|
||
setCancel()
|
||
}
|
||
|
||
if rec.DailyLimitUSD != nil && dailyUsage >= *rec.DailyLimitUSD {
|
||
return withWindowResetsMetadata(ErrUserPlatformDailyQuotaExhausted, nextDailyReset(now))
|
||
}
|
||
if rec.WeeklyLimitUSD != nil && weeklyUsage >= *rec.WeeklyLimitUSD {
|
||
return withWindowResetsMetadata(ErrUserPlatformWeeklyQuotaExhausted, nextWeeklyReset(now))
|
||
}
|
||
if rec.MonthlyLimitUSD != nil && monthlyUsage >= *rec.MonthlyLimitUSD {
|
||
return withWindowResetsMetadata(ErrUserPlatformMonthlyQuotaExhausted, nextMonthlyResetFrom(rec.MonthlyWindowStart, now))
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// withWindowResetsMetadata 给 quota error 附加 window_resets_at metadata(RFC3339)。
|
||
func withWindowResetsMetadata(err error, resetAt time.Time) error {
|
||
appErr, ok := err.(*infraerrors.ApplicationError)
|
||
if !ok || appErr == nil {
|
||
return err
|
||
}
|
||
return appErr.WithMetadata(map[string]string{
|
||
"window_resets_at": resetAt.Format(time.RFC3339),
|
||
})
|
||
}
|
||
|
||
// nextDailyReset 计算下一个日窗口起点(次日全局时区 0 点)。
|
||
// 必须与 timezone.StartOfDay 同口径,否则 Retry-After 会偏差。
|
||
func nextDailyReset(now time.Time) time.Time {
|
||
return timezone.StartOfDay(now).AddDate(0, 0, 1)
|
||
}
|
||
|
||
// nextWeeklyReset 计算下一个周窗口起点(下周一全局时区 0 点)。
|
||
// 必须与 timezone.StartOfWeek 同口径,否则 Retry-After 会偏差。
|
||
func nextWeeklyReset(now time.Time) time.Time {
|
||
return timezone.StartOfWeek(now).AddDate(0, 0, 7)
|
||
}
|
||
|
||
// nextMonthlyResetFrom returns the next reset time of the 30-day rolling
// monthly window.
//
// For a window that is still running (start set and now-start < 30 days) this
// is start + 30 days. When start is nil (never initialized) or the window has
// already expired (now-start >= 30 days, the same rule as
// monthlyQuotaWindowExpired) the result is now + 30 days: an expired window is
// restarted at "now" by the next usage increment, and computing from the old
// start would produce a time in the past.
func nextMonthlyResetFrom(start *time.Time, now time.Time) time.Time {
	const window = 30 * 24 * time.Hour

	if start != nil && now.Sub(*start) < window {
		return start.Add(window)
	}
	return now.Add(window)
}
|
||
|
||
// quotaWindowExpired reports whether a quota window has expired: start is nil
// (window never initialized) or lies strictly before currWindowStart.
func quotaWindowExpired(start *time.Time, currWindowStart time.Time) bool {
	return start == nil || start.Before(currWindowStart)
}
|
||
|
||
// monthlyQuotaWindowExpired reports whether the 30-day rolling monthly window
// has expired: start is nil (window never initialized) or now - start is at
// least 30×24h. Per the original comment this matches the subscription-mode
// NeedsMonthlyReset rule.
func monthlyQuotaWindowExpired(start *time.Time, now time.Time) bool {
	const window = 30 * 24 * time.Hour
	return start == nil || now.Sub(*start) >= window
}
|