package service

import (
	"context"
	"fmt"
	"strconv"
	"sync"
	"sync/atomic"
	"time"

	"github.com/Wei-Shaw/sub2api/internal/config"
	infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
	"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
	"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
	"golang.org/x/sync/singleflight"
)

// Error definitions.
// Note: ErrInsufficientBalance is defined in redeem_service.go.
// Note: ErrDailyLimitExceeded/ErrWeeklyLimitExceeded/ErrMonthlyLimitExceeded are defined in subscription_service.go.

// errBillingCacheUnavailable is an internal sentinel: when cache == nil, the
// quota check path uses it to take the same "fail-open + one-off DB check"
// branch as a Redis failure.
var errBillingCacheUnavailable = fmt.Errorf("billing cache unavailable")

var (
	ErrSubscriptionInvalid       = infraerrors.Forbidden("SUBSCRIPTION_INVALID", "subscription is invalid or expired")
	ErrBillingServiceUnavailable = infraerrors.ServiceUnavailable("BILLING_SERVICE_ERROR", "Billing service temporarily unavailable. Please retry later.")

	// RPM limit errors. gateway_handler is responsible for mapping them to HTTP 429.
	ErrGroupRPMExceeded = infraerrors.TooManyRequests("GROUP_RPM_EXCEEDED", "group requests-per-minute limit exceeded")
	ErrUserRPMExceeded  = infraerrors.TooManyRequests("USER_RPM_EXCEEDED", "user requests-per-minute limit exceeded")

	// user × platform quota (HTTP 429 Too Many Requests + Retry-After header).
	// 429 is chosen over 403: an exhausted quota is a "temporary resource
	// exhaustion, retrying can recover" situation (RFC 6585). Many SDKs (e.g.
	// OpenAI-compatible clients) only back off automatically and read
	// Retry-After on 429; a 403 would be treated as "insufficient permission,
	// retrying is pointless", so the client errors out without backing off.
	ErrUserPlatformDailyQuotaExhausted   = infraerrors.TooManyRequests("USER_PLATFORM_DAILY_QUOTA_EXHAUSTED", "Daily usage quota exhausted for this platform.")
	ErrUserPlatformWeeklyQuotaExhausted  = infraerrors.TooManyRequests("USER_PLATFORM_WEEKLY_QUOTA_EXHAUSTED", "Weekly usage quota exhausted for this platform.")
	ErrUserPlatformMonthlyQuotaExhausted = infraerrors.TooManyRequests("USER_PLATFORM_MONTHLY_QUOTA_EXHAUSTED", "Monthly usage quota exhausted for this platform.")
)

// subscriptionCacheData is the subscription cache data structure (internal use).
type subscriptionCacheData struct {
	Status       string
	ExpiresAt    time.Time
	DailyUsage   float64
	WeeklyUsage  float64
	MonthlyUsage float64
	Version      int64
}

// cacheWriteKind identifies the type of a cache write task.
type cacheWriteKind int

const (
	cacheWriteSetBalance cacheWriteKind = iota
	cacheWriteSetSubscription
	cacheWriteUpdateSubscriptionUsage
	cacheWriteDeductBalance
	cacheWriteUpdateRateLimitUsage
)

// Asynchronous cache write worker pool configuration.
//
// Performance notes:
// The original implementation updated the cache from a goroutine on the
// request hot path, which had the following problems:
//  1. Every request created a new goroutine, producing many short-lived
//     goroutines under high concurrency.
//  2. Concurrency could not be bounded, which could exhaust Redis connections.
//  3. Goroutine creation/teardown added extra overhead.
//
// The new implementation uses a fixed-size worker pool:
//  1. 10 worker goroutines are created up front, avoiding frequent
//     creation and teardown.
//  2. A buffered channel (1000) serves as the task queue and smooths write peaks.
//  3. Enqueueing is non-blocking; when the queue is full, critical tasks fall
//     back to a synchronous write and non-critical tasks are dropped with a warning.
//  4. A uniform timeout keeps slow operations from blocking the pool.
const (
	cacheWriteWorkerCount     = 10              // number of worker goroutines
	cacheWriteBufferSize      = 1000            // task queue buffer size
	cacheWriteTimeout         = 2 * time.Second // timeout for a single write operation
	cacheWriteDropLogInterval = 5 * time.Second // throttle interval for drop logs
	balanceLoadTimeout        = 3 * time.Second
)

// cacheWriteTask is a cache write task.
type cacheWriteTask struct {
	kind             cacheWriteKind
	userID           int64
	groupID          int64
	apiKeyID         int64
	balance          float64
	amount           float64
	subscriptionData *subscriptionCacheData
}

// apiKeyRateLimitLoader defines the interface for loading rate limit data from DB.
type apiKeyRateLimitLoader interface { GetRateLimitData(ctx context.Context, keyID int64) (*APIKeyRateLimitData, error) } // BillingCacheService 计费缓存服务 // 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查 type BillingCacheService struct { cache BillingCache userRepo UserRepository subRepo UserSubscriptionRepository apiKeyRateLimitLoader apiKeyRateLimitLoader userRPMCache UserRPMCache userGroupRateRepo UserGroupRateRepository cfg *config.Config circuitBreaker *billingCircuitBreaker userPlatformQuotaRepo UserPlatformQuotaRepository cacheWriteChan chan cacheWriteTask cacheWriteWg sync.WaitGroup cacheWriteStopOnce sync.Once cacheWriteMu sync.RWMutex stopped atomic.Bool balanceLoadSF singleflight.Group quotaLoadSF singleflight.Group // 丢弃日志节流计数器(减少高负载下日志噪音) cacheWriteDropFullCount uint64 cacheWriteDropFullLastLog int64 cacheWriteDropClosedCount uint64 cacheWriteDropClosedLastLog int64 } // NewBillingCacheService 创建计费缓存服务 func NewBillingCacheService( cache BillingCache, userRepo UserRepository, subRepo UserSubscriptionRepository, apiKeyRepo APIKeyRepository, userRPMCache UserRPMCache, userGroupRateRepo UserGroupRateRepository, cfg *config.Config, userPlatformQuotaRepo UserPlatformQuotaRepository, ) *BillingCacheService { svc := &BillingCacheService{ cache: cache, userRepo: userRepo, subRepo: subRepo, apiKeyRateLimitLoader: apiKeyRepo, userRPMCache: userRPMCache, userGroupRateRepo: userGroupRateRepo, cfg: cfg, userPlatformQuotaRepo: userPlatformQuotaRepo, } svc.circuitBreaker = newBillingCircuitBreaker(cfg.Billing.CircuitBreaker) svc.startCacheWriteWorkers() return svc } // Stop 关闭缓存写入工作池 func (s *BillingCacheService) Stop() { s.cacheWriteStopOnce.Do(func() { s.stopped.Store(true) s.cacheWriteMu.Lock() ch := s.cacheWriteChan if ch != nil { close(ch) } s.cacheWriteMu.Unlock() if ch == nil { return } s.cacheWriteWg.Wait() s.cacheWriteMu.Lock() if s.cacheWriteChan == ch { s.cacheWriteChan = nil } s.cacheWriteMu.Unlock() }) } func (s *BillingCacheService) startCacheWriteWorkers() { ch := make(chan 
cacheWriteTask, cacheWriteBufferSize) s.cacheWriteChan = ch for i := 0; i < cacheWriteWorkerCount; i++ { s.cacheWriteWg.Add(1) go s.cacheWriteWorker(ch) } } // enqueueCacheWrite 尝试将任务入队,队列满时返回 false(并记录告警)。 func (s *BillingCacheService) enqueueCacheWrite(task cacheWriteTask) (enqueued bool) { if s.stopped.Load() { s.logCacheWriteDrop(task, "closed") return false } s.cacheWriteMu.RLock() defer s.cacheWriteMu.RUnlock() if s.cacheWriteChan == nil { s.logCacheWriteDrop(task, "closed") return false } select { case s.cacheWriteChan <- task: return true default: // 队列满时不阻塞主流程,交由调用方决定是否同步回退。 s.logCacheWriteDrop(task, "full") return false } } func (s *BillingCacheService) cacheWriteWorker(ch <-chan cacheWriteTask) { defer s.cacheWriteWg.Done() for task := range ch { ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout) switch task.kind { case cacheWriteSetBalance: s.setBalanceCache(ctx, task.userID, task.balance) case cacheWriteSetSubscription: s.setSubscriptionCache(ctx, task.userID, task.groupID, task.subscriptionData) case cacheWriteUpdateSubscriptionUsage: if s.cache != nil { if err := s.cache.UpdateSubscriptionUsage(ctx, task.userID, task.groupID, task.amount); err != nil { logger.LegacyPrintf("service.billing_cache", "Warning: update subscription cache failed for user %d group %d: %v", task.userID, task.groupID, err) } } case cacheWriteDeductBalance: if s.cache != nil { if err := s.cache.DeductUserBalance(ctx, task.userID, task.amount); err != nil { logger.LegacyPrintf("service.billing_cache", "Warning: deduct balance cache failed for user %d: %v", task.userID, err) } } case cacheWriteUpdateRateLimitUsage: if s.cache != nil { if err := s.cache.UpdateAPIKeyRateLimitUsage(ctx, task.apiKeyID, task.amount); err != nil { logger.LegacyPrintf("service.billing_cache", "Warning: update rate limit usage cache failed for api key %d: %v", task.apiKeyID, err) } } } cancel() } } // cacheWriteKindName 用于日志中的任务类型标识,便于排查丢弃原因。 func cacheWriteKindName(kind 
cacheWriteKind) string { switch kind { case cacheWriteSetBalance: return "set_balance" case cacheWriteSetSubscription: return "set_subscription" case cacheWriteUpdateSubscriptionUsage: return "update_subscription_usage" case cacheWriteDeductBalance: return "deduct_balance" case cacheWriteUpdateRateLimitUsage: return "update_rate_limit_usage" default: return "unknown" } } // logCacheWriteDrop 使用节流方式记录丢弃情况,并汇总丢弃数量。 func (s *BillingCacheService) logCacheWriteDrop(task cacheWriteTask, reason string) { var ( countPtr *uint64 lastPtr *int64 ) switch reason { case "full": countPtr = &s.cacheWriteDropFullCount lastPtr = &s.cacheWriteDropFullLastLog case "closed": countPtr = &s.cacheWriteDropClosedCount lastPtr = &s.cacheWriteDropClosedLastLog default: return } atomic.AddUint64(countPtr, 1) now := time.Now().UnixNano() last := atomic.LoadInt64(lastPtr) if now-last < int64(cacheWriteDropLogInterval) { return } if !atomic.CompareAndSwapInt64(lastPtr, last, now) { return } dropped := atomic.SwapUint64(countPtr, 0) if dropped == 0 { return } logger.LegacyPrintf("service.billing_cache", "Warning: cache write queue %s, dropped %d tasks in last %s (latest kind=%s user %d group %d)", reason, dropped, cacheWriteDropLogInterval, cacheWriteKindName(task.kind), task.userID, task.groupID, ) } // ============================================ // 余额缓存方法 // ============================================ // GetUserBalance 获取用户余额(优先从缓存读取) func (s *BillingCacheService) GetUserBalance(ctx context.Context, userID int64) (float64, error) { if s.cache == nil { // Redis不可用,直接查询数据库 return s.getUserBalanceFromDB(ctx, userID) } // 尝试从缓存读取 balance, err := s.cache.GetUserBalance(ctx, userID) if err == nil { return balance, nil } // 缓存未命中:singleflight 合并同一 userID 的并发回源请求。 value, err, _ := s.balanceLoadSF.Do(strconv.FormatInt(userID, 10), func() (any, error) { loadCtx, cancel := context.WithTimeout(context.Background(), balanceLoadTimeout) defer cancel() balance, err := s.getUserBalanceFromDB(loadCtx, 
userID) if err != nil { return nil, err } // 异步建立缓存 _ = s.enqueueCacheWrite(cacheWriteTask{ kind: cacheWriteSetBalance, userID: userID, balance: balance, }) return balance, nil }) if err != nil { return 0, err } balance, ok := value.(float64) if !ok { return 0, fmt.Errorf("unexpected balance type: %T", value) } return balance, nil } // getUserBalanceFromDB 从数据库获取用户余额 func (s *BillingCacheService) getUserBalanceFromDB(ctx context.Context, userID int64) (float64, error) { user, err := s.userRepo.GetByID(ctx, userID) if err != nil { return 0, fmt.Errorf("get user balance: %w", err) } return user.Balance, nil } // setBalanceCache 设置余额缓存 func (s *BillingCacheService) setBalanceCache(ctx context.Context, userID int64, balance float64) { if s.cache == nil { return } if err := s.cache.SetUserBalance(ctx, userID, balance); err != nil { logger.LegacyPrintf("service.billing_cache", "Warning: set balance cache failed for user %d: %v", userID, err) } } // DeductBalanceCache 扣减余额缓存(同步调用) func (s *BillingCacheService) DeductBalanceCache(ctx context.Context, userID int64, amount float64) error { if s.cache == nil { return nil } return s.cache.DeductUserBalance(ctx, userID, amount) } // QueueDeductBalance 异步扣减余额缓存 func (s *BillingCacheService) QueueDeductBalance(userID int64, amount float64) { if s.cache == nil { return } // 队列满时同步回退,避免关键扣减被静默丢弃。 if s.enqueueCacheWrite(cacheWriteTask{ kind: cacheWriteDeductBalance, userID: userID, amount: amount, }) { return } ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout) defer cancel() if err := s.DeductBalanceCache(ctx, userID, amount); err != nil { logger.LegacyPrintf("service.billing_cache", "Warning: deduct balance cache fallback failed for user %d: %v", userID, err) } } // InvalidateUserBalance 失效用户余额缓存 func (s *BillingCacheService) InvalidateUserBalance(ctx context.Context, userID int64) error { if s.cache == nil { return nil } if err := s.cache.InvalidateUserBalance(ctx, userID); err != nil { 
logger.LegacyPrintf("service.billing_cache", "Warning: invalidate balance cache failed for user %d: %v", userID, err) return err } return nil } // ============================================ // 订阅缓存方法 // ============================================ // GetSubscriptionStatus 获取订阅状态(优先从缓存读取) func (s *BillingCacheService) GetSubscriptionStatus(ctx context.Context, userID, groupID int64) (*subscriptionCacheData, error) { if s.cache == nil { return s.getSubscriptionFromDB(ctx, userID, groupID) } // 尝试从缓存读取 cacheData, err := s.cache.GetSubscriptionCache(ctx, userID, groupID) if err == nil && cacheData != nil { return s.convertFromPortsData(cacheData), nil } // 缓存未命中,从数据库读取 data, err := s.getSubscriptionFromDB(ctx, userID, groupID) if err != nil { return nil, err } // 异步建立缓存 _ = s.enqueueCacheWrite(cacheWriteTask{ kind: cacheWriteSetSubscription, userID: userID, groupID: groupID, subscriptionData: data, }) return data, nil } func (s *BillingCacheService) convertFromPortsData(data *SubscriptionCacheData) *subscriptionCacheData { return &subscriptionCacheData{ Status: data.Status, ExpiresAt: data.ExpiresAt, DailyUsage: data.DailyUsage, WeeklyUsage: data.WeeklyUsage, MonthlyUsage: data.MonthlyUsage, Version: data.Version, } } func (s *BillingCacheService) convertToPortsData(data *subscriptionCacheData) *SubscriptionCacheData { return &SubscriptionCacheData{ Status: data.Status, ExpiresAt: data.ExpiresAt, DailyUsage: data.DailyUsage, WeeklyUsage: data.WeeklyUsage, MonthlyUsage: data.MonthlyUsage, Version: data.Version, } } // getSubscriptionFromDB 从数据库获取订阅数据 func (s *BillingCacheService) getSubscriptionFromDB(ctx context.Context, userID, groupID int64) (*subscriptionCacheData, error) { sub, err := s.subRepo.GetActiveByUserIDAndGroupID(ctx, userID, groupID) if err != nil { return nil, fmt.Errorf("get subscription: %w", err) } return &subscriptionCacheData{ Status: sub.Status, ExpiresAt: sub.ExpiresAt, DailyUsage: sub.DailyUsageUSD, WeeklyUsage: sub.WeeklyUsageUSD, 
MonthlyUsage: sub.MonthlyUsageUSD, Version: sub.UpdatedAt.Unix(), }, nil } // setSubscriptionCache 设置订阅缓存 func (s *BillingCacheService) setSubscriptionCache(ctx context.Context, userID, groupID int64, data *subscriptionCacheData) { if s.cache == nil || data == nil { return } if err := s.cache.SetSubscriptionCache(ctx, userID, groupID, s.convertToPortsData(data)); err != nil { logger.LegacyPrintf("service.billing_cache", "Warning: set subscription cache failed for user %d group %d: %v", userID, groupID, err) } } // UpdateSubscriptionUsage 更新订阅用量缓存(同步调用) func (s *BillingCacheService) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, costUSD float64) error { if s.cache == nil { return nil } return s.cache.UpdateSubscriptionUsage(ctx, userID, groupID, costUSD) } // QueueUpdateSubscriptionUsage 异步更新订阅用量缓存 func (s *BillingCacheService) QueueUpdateSubscriptionUsage(userID, groupID int64, costUSD float64) { if s.cache == nil { return } // 队列满时同步回退,确保订阅用量及时更新。 if s.enqueueCacheWrite(cacheWriteTask{ kind: cacheWriteUpdateSubscriptionUsage, userID: userID, groupID: groupID, amount: costUSD, }) { return } ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout) defer cancel() if err := s.UpdateSubscriptionUsage(ctx, userID, groupID, costUSD); err != nil { logger.LegacyPrintf("service.billing_cache", "Warning: update subscription cache fallback failed for user %d group %d: %v", userID, groupID, err) } } // InvalidateSubscription 失效指定订阅缓存 func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID, groupID int64) error { if s.cache == nil { return nil } if err := s.cache.InvalidateSubscriptionCache(ctx, userID, groupID); err != nil { logger.LegacyPrintf("service.billing_cache", "Warning: invalidate subscription cache failed for user %d group %d: %v", userID, groupID, err) return err } return nil } // InvalidateAPIKeyRateLimit invalidates the Redis rate-limit usage cache for an API key. 
func (s *BillingCacheService) InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error { if s.cache == nil { return nil } if err := s.cache.InvalidateAPIKeyRateLimit(ctx, keyID); err != nil { logger.LegacyPrintf("service.billing_cache", "Warning: invalidate api key rate limit cache failed for key %d: %v", keyID, err) return err } return nil } // ============================================ // API Key 限速缓存方法 // ============================================ // checkAPIKeyRateLimits checks rate limit windows for an API key. // It loads usage from Redis cache (falling back to DB on cache miss), // resets expired windows in-memory and triggers async DB reset, // and returns an error if any window limit is exceeded. func (s *BillingCacheService) checkAPIKeyRateLimits(ctx context.Context, apiKey *APIKey) error { if s.cache == nil { // No cache: fall back to reading from DB directly if s.apiKeyRateLimitLoader == nil { return nil } data, err := s.apiKeyRateLimitLoader.GetRateLimitData(ctx, apiKey.ID) if err != nil { return nil // Don't block requests on DB errors } return s.evaluateRateLimits(ctx, apiKey, data.Usage5h, data.Usage1d, data.Usage7d, data.Window5hStart, data.Window1dStart, data.Window7dStart) } cacheData, err := s.cache.GetAPIKeyRateLimit(ctx, apiKey.ID) if err != nil { // Cache miss: load from DB and populate cache if s.apiKeyRateLimitLoader == nil { return nil } dbData, dbErr := s.apiKeyRateLimitLoader.GetRateLimitData(ctx, apiKey.ID) if dbErr != nil { return nil // Don't block requests on DB errors } // Build cache entry from DB data cacheEntry := &APIKeyRateLimitCacheData{ Usage5h: dbData.Usage5h, Usage1d: dbData.Usage1d, Usage7d: dbData.Usage7d, } if dbData.Window5hStart != nil { cacheEntry.Window5h = dbData.Window5hStart.Unix() } if dbData.Window1dStart != nil { cacheEntry.Window1d = dbData.Window1dStart.Unix() } if dbData.Window7dStart != nil { cacheEntry.Window7d = dbData.Window7dStart.Unix() } _ = s.cache.SetAPIKeyRateLimit(ctx, apiKey.ID, 
cacheEntry) cacheData = cacheEntry } var w5h, w1d, w7d *time.Time if cacheData.Window5h > 0 { t := time.Unix(cacheData.Window5h, 0) w5h = &t } if cacheData.Window1d > 0 { t := time.Unix(cacheData.Window1d, 0) w1d = &t } if cacheData.Window7d > 0 { t := time.Unix(cacheData.Window7d, 0) w7d = &t } return s.evaluateRateLimits(ctx, apiKey, cacheData.Usage5h, cacheData.Usage1d, cacheData.Usage7d, w5h, w1d, w7d) } // evaluateRateLimits checks usage against limits, triggering async resets for expired windows. func (s *BillingCacheService) evaluateRateLimits(ctx context.Context, apiKey *APIKey, usage5h, usage1d, usage7d float64, w5h, w1d, w7d *time.Time) error { needsReset := false // Reset expired windows in-memory for check purposes if IsWindowExpired(w5h, RateLimitWindow5h) { usage5h = 0 needsReset = true } if IsWindowExpired(w1d, RateLimitWindow1d) { usage1d = 0 needsReset = true } if IsWindowExpired(w7d, RateLimitWindow7d) { usage7d = 0 needsReset = true } // Trigger async DB reset if any window expired if needsReset { keyID := apiKey.ID go func() { resetCtx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout) defer cancel() if s.apiKeyRateLimitLoader != nil { // Use the repo directly - reset then reload cache if loader, ok := s.apiKeyRateLimitLoader.(interface { ResetRateLimitWindows(ctx context.Context, id int64) error }); ok { if err := loader.ResetRateLimitWindows(resetCtx, keyID); err != nil { logger.LegacyPrintf("service.billing_cache", "Warning: reset rate limit windows failed for api key %d: %v", keyID, err) } } } // Invalidate cache so next request loads fresh data if s.cache != nil { if err := s.cache.InvalidateAPIKeyRateLimit(resetCtx, keyID); err != nil { logger.LegacyPrintf("service.billing_cache", "Warning: invalidate rate limit cache failed for api key %d: %v", keyID, err) } } }() } // Check limits if apiKey.RateLimit5h > 0 && usage5h >= apiKey.RateLimit5h { return ErrAPIKeyRateLimit5hExceeded } if apiKey.RateLimit1d > 0 && usage1d >= 
apiKey.RateLimit1d { return ErrAPIKeyRateLimit1dExceeded } if apiKey.RateLimit7d > 0 && usage7d >= apiKey.RateLimit7d { return ErrAPIKeyRateLimit7dExceeded } return nil } // QueueUpdateAPIKeyRateLimitUsage asynchronously updates rate limit usage in the cache. func (s *BillingCacheService) QueueUpdateAPIKeyRateLimitUsage(apiKeyID int64, cost float64) { if s.cache == nil { return } s.enqueueCacheWrite(cacheWriteTask{ kind: cacheWriteUpdateRateLimitUsage, apiKeyID: apiKeyID, amount: cost, }) } // IncrementUserPlatformQuotaUsage 同步累加 user × platform usage 到 Redis 缓存。 // // 设计:同步写入而非异步入队。同步写确保下次 preflight 立即看到最新 usage, // 把 TOCTOU 超支窗口限制在并发 in-flight 请求数量内(而非随时间无限累积)。 // 写延迟通常 < 1ms(本地 Redis),换取 quota 视图实时性的取舍合理。 // // Redis 写失败用 ALERT 级 log;DB 持久化由 caller 单独 goroutine 兜底(gateway_service.go)。 func (s *BillingCacheService) IncrementUserPlatformQuotaUsage(userID int64, platform string, cost float64) { if s.cache == nil { return } if platform == "" || cost <= 0 { return } ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout) defer cancel() ttl := time.Duration(s.cfg.Billing.UserPlatformQuotaCacheTTLSeconds) * time.Second if err := s.cache.IncrUserPlatformQuotaUsageCache(ctx, userID, platform, cost, ttl); err != nil { logger.LegacyPrintf("service.billing_cache", "ALERT: incr user platform quota cache failed user=%d platform=%s cost=%f: %v", userID, platform, cost, err) } } // ============================================ // 统一检查方法 // ============================================ // CheckBillingEligibility 检查用户是否有资格发起请求 // 余额模式:检查缓存余额 > 0 // 订阅模式:检查缓存用量未超过限额(Group限额从参数传入) // platform 为请求的目标平台(如 "anthropic"),传空串 "" 时跳过 user × platform quota 检查。 func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *User, apiKey *APIKey, group *Group, subscription *UserSubscription, platform string) error { // 简易模式:跳过所有计费检查 if s.cfg.RunMode == config.RunModeSimple { return nil } if s.circuitBreaker != nil && !s.circuitBreaker.Allow() { return 
ErrBillingServiceUnavailable } // 判断计费模式 isSubscriptionMode := group != nil && group.IsSubscriptionType() && subscription != nil if isSubscriptionMode { if err := s.checkSubscriptionEligibility(ctx, user.ID, group, subscription); err != nil { return err } } else { if err := s.checkBalanceEligibility(ctx, user.ID); err != nil { return err } } // user × platform quota 仅在 standard(余额)模式生效;订阅模式豁免 if !isSubscriptionMode { if err := s.checkUserPlatformQuotaEligibility(ctx, user.ID, platform); err != nil { return err } } // Check API Key rate limits (applies to both billing modes) if apiKey != nil && apiKey.HasRateLimits() { if err := s.checkAPIKeyRateLimits(ctx, apiKey); err != nil { return err } } // RPM 限流:级联回落(Override → Group → User),放在最后以避免为注定失败的请求增加计数。 if err := s.checkRPM(ctx, user, group); err != nil { return err } return nil } // checkRPM 执行并行 RPM 限流,所有适用的限制同时生效,任一超限即拒绝: // // 1. (用户, 分组) rpm_override — 最细粒度:管理员为特定用户在特定分组设定的专属限额。 // override=0 表示该用户在该分组免检(绿灯),但 user 级全局上限仍然生效。 // 2. group.rpm_limit — 分组级:该分组的统一 RPM 容量(仅当无 override 时生效)。 // 3. 
user.rpm_limit — 用户级全局硬上限:无论 override/group 如何配置,始终生效。 // // 与旧版"级联互斥"设计不同,新版确保 user.rpm_limit 作为全局天花板不会被 group 或 override 覆盖。 // Redis 故障一律 fail-open(打 warning,不阻塞业务)。 func (s *BillingCacheService) checkRPM(ctx context.Context, user *User, group *Group) error { if s == nil || s.userRPMCache == nil || user == nil { return nil } // ── 第一层:分组级检查(override 或 group.rpm_limit) ── if group != nil { // 解析 override:优先从 auth cache snapshot,nil 时回退 DB。 var override *int if user.UserGroupRPMOverride != nil { override = user.UserGroupRPMOverride } else if s.userGroupRateRepo != nil { dbOverride, err := s.userGroupRateRepo.GetRPMOverrideByUserAndGroup(ctx, user.ID, group.ID) if err != nil { logger.LegacyPrintf( "service.billing_cache", "Warning: rpm override lookup failed for user=%d group=%d: %v", user.ID, group.ID, err, ) } else { override = dbOverride } } if override != nil { // override=0 → 该用户在该分组免检(但 user 级仍会在下面检查)。 if *override > 0 { count, incErr := s.userRPMCache.IncrementUserGroupRPM(ctx, user.ID, group.ID) if incErr != nil { logger.LegacyPrintf( "service.billing_cache", "Warning: rpm increment (override) failed for user=%d group=%d: %v", user.ID, group.ID, incErr, ) // fail-open } else if count > *override { return ErrGroupRPMExceeded } } // override 命中后跳过 group.rpm_limit(override 替代 group),但不 return——继续检查 user 级。 } else if group.RPMLimit > 0 { // 无 override,检查 group.rpm_limit。 count, err := s.userRPMCache.IncrementUserGroupRPM(ctx, user.ID, group.ID) if err != nil { logger.LegacyPrintf( "service.billing_cache", "Warning: rpm increment (group) failed for user=%d group=%d: %v", user.ID, group.ID, err, ) // fail-open } else if count > group.RPMLimit { return ErrGroupRPMExceeded } } } // ── 第二层:用户级全局硬上限(始终生效) ── if user.RPMLimit > 0 { count, err := s.userRPMCache.IncrementUserRPM(ctx, user.ID) if err != nil { logger.LegacyPrintf( "service.billing_cache", "Warning: rpm increment (user) failed for user=%d: %v", user.ID, err, ) return nil // fail-open } if count > 
user.RPMLimit { return ErrUserRPMExceeded } } return nil } // checkBalanceEligibility 检查余额模式资格 func (s *BillingCacheService) checkBalanceEligibility(ctx context.Context, userID int64) error { balance, err := s.GetUserBalance(ctx, userID) if err != nil { if s.circuitBreaker != nil { s.circuitBreaker.OnFailure(err) } logger.LegacyPrintf("service.billing_cache", "ALERT: billing balance check failed for user %d: %v", userID, err) return ErrBillingServiceUnavailable.WithCause(err) } if s.circuitBreaker != nil { s.circuitBreaker.OnSuccess() } if balance <= 0 { return ErrInsufficientBalance } return nil } // checkSubscriptionEligibility 检查订阅模式资格 func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context, userID int64, group *Group, subscription *UserSubscription) error { // 获取订阅缓存数据 subData, err := s.GetSubscriptionStatus(ctx, userID, group.ID) if err != nil { if s.circuitBreaker != nil { s.circuitBreaker.OnFailure(err) } logger.LegacyPrintf("service.billing_cache", "ALERT: billing subscription check failed for user %d group %d: %v", userID, group.ID, err) return ErrBillingServiceUnavailable.WithCause(err) } if s.circuitBreaker != nil { s.circuitBreaker.OnSuccess() } // 检查订阅状态 if subData.Status != SubscriptionStatusActive { return ErrSubscriptionInvalid } // 检查是否过期 if time.Now().After(subData.ExpiresAt) { return ErrSubscriptionInvalid } // 检查限额(使用传入的Group限额配置) if group.HasDailyLimit() && subData.DailyUsage >= *group.DailyLimitUSD { return ErrDailyLimitExceeded } if group.HasWeeklyLimit() && subData.WeeklyUsage >= *group.WeeklyLimitUSD { return ErrWeeklyLimitExceeded } if group.HasMonthlyLimit() && subData.MonthlyUsage >= *group.MonthlyLimitUSD { return ErrMonthlyLimitExceeded } return nil } type billingCircuitBreakerState int const ( billingCircuitClosed billingCircuitBreakerState = iota billingCircuitOpen billingCircuitHalfOpen ) type billingCircuitBreaker struct { mu sync.Mutex state billingCircuitBreakerState failures int openedAt time.Time 
failureThreshold int resetTimeout time.Duration halfOpenRequests int halfOpenRemaining int } func newBillingCircuitBreaker(cfg config.CircuitBreakerConfig) *billingCircuitBreaker { if !cfg.Enabled { return nil } resetTimeout := time.Duration(cfg.ResetTimeoutSeconds) * time.Second if resetTimeout <= 0 { resetTimeout = 30 * time.Second } halfOpen := cfg.HalfOpenRequests if halfOpen <= 0 { halfOpen = 1 } threshold := cfg.FailureThreshold if threshold <= 0 { threshold = 5 } return &billingCircuitBreaker{ state: billingCircuitClosed, failureThreshold: threshold, resetTimeout: resetTimeout, halfOpenRequests: halfOpen, } } func (b *billingCircuitBreaker) Allow() bool { b.mu.Lock() defer b.mu.Unlock() switch b.state { case billingCircuitClosed: return true case billingCircuitOpen: if time.Since(b.openedAt) < b.resetTimeout { return false } b.state = billingCircuitHalfOpen b.halfOpenRemaining = b.halfOpenRequests logger.LegacyPrintf("service.billing_cache", "ALERT: billing circuit breaker entering half-open state") fallthrough case billingCircuitHalfOpen: if b.halfOpenRemaining <= 0 { return false } b.halfOpenRemaining-- return true default: return false } } func (b *billingCircuitBreaker) OnFailure(err error) { if b == nil { return } b.mu.Lock() defer b.mu.Unlock() switch b.state { case billingCircuitOpen: return case billingCircuitHalfOpen: b.state = billingCircuitOpen b.openedAt = time.Now() b.halfOpenRemaining = 0 logger.LegacyPrintf("service.billing_cache", "ALERT: billing circuit breaker opened after half-open failure: %v", err) return default: b.failures++ if b.failures >= b.failureThreshold { b.state = billingCircuitOpen b.openedAt = time.Now() b.halfOpenRemaining = 0 logger.LegacyPrintf("service.billing_cache", "ALERT: billing circuit breaker opened after %d failures: %v", b.failures, err) } } } func (b *billingCircuitBreaker) OnSuccess() { if b == nil { return } b.mu.Lock() defer b.mu.Unlock() previousState := b.state previousFailures := b.failures b.state = 
billingCircuitClosed b.failures = 0 b.halfOpenRemaining = 0 // 只有状态真正发生变化时才记录日志 if previousState != billingCircuitClosed { logger.LegacyPrintf("service.billing_cache", "ALERT: billing circuit breaker closed (was %s)", circuitStateString(previousState)) } else if previousFailures > 0 { logger.LegacyPrintf("service.billing_cache", "INFO: billing circuit breaker failures reset from %d", previousFailures) } } func circuitStateString(state billingCircuitBreakerState) string { switch state { case billingCircuitClosed: return "closed" case billingCircuitOpen: return "open" case billingCircuitHalfOpen: return "half-open" default: return "unknown" } } // checkUserPlatformQuotaEligibility 在 standard 模式下检查 user × platform 日/周/月 quota。 // 返回 nil = 允许;返回 ErrUserPlatform{Daily/Weekly/Monthly}QuotaExhausted = 拒绝(带 window_resets_at metadata)。 // checkUserPlatformQuotaEligibility 检查用户在指定平台的 USD 配额。 // // 流程(Redis-first / DB-fallback): // 1. 先读 Redis cache;若命中且 SchemaVersion==1,直接用 entry 中的 limits 和 window_start 做校验, // 免除 DB 查询。 // 2. cache MISS 或旧版 entry(SchemaVersion==0)→ 查 DB 回填完整 entry(含 limits/window_start)。 // 3. 
Redis 故障(err != nil)→ fail-open,查 DB 做一次性检查,不回填。 func (s *BillingCacheService) checkUserPlatformQuotaEligibility( ctx context.Context, userID int64, platform string, ) error { if platform == "" || s.userPlatformQuotaRepo == nil { return nil } // cache 未配置(如简化部署 / 单测路径)→ 直接走 DB 查询,避免 nil panic。 // 其他 check* 方法(balance/subscription/rate-limit)也有类似守卫。 var ( entry *UserPlatformQuotaCacheEntry ok bool cacheErr error ) if s.cache != nil { entry, ok, cacheErr = s.cache.GetUserPlatformQuotaCache(ctx, userID, platform) } else { // 标记为"cache 故障"分支:跳过 HIT 路径、不回填、走 DB 一次性检查 cacheErr = errBillingCacheUnavailable } // --- cache HIT with current schema → 直接用 entry,不查 DB --- if cacheErr == nil && ok && entry != nil && entry.SchemaVersion == UserPlatformQuotaCacheSchemaV1 { now := time.Now() dailyUsage := entry.DailyUsageUSD weeklyUsage := entry.WeeklyUsageUSD monthlyUsage := entry.MonthlyUsageUSD // 若窗口已更新(DB 已重置但 cache 尚未失效),将对应 usage 清零再做比较, // 同时记录新窗口起点用于后续刷新 cache entry。 // 本次请求用本地清零值继续判断;DB 层 IncrementUsageWithReset 已有窗口自愈能力, // 持久化数据始终正确。 windowExpired := false newDailyStart := entry.DailyWindowStart newWeeklyStart := entry.WeeklyWindowStart newMonthlyStart := entry.MonthlyWindowStart if quotaWindowExpired(entry.DailyWindowStart, timezone.StartOfDay(now)) { dailyUsage = 0 windowExpired = true dayStart := timezone.StartOfDay(now) newDailyStart = &dayStart } if quotaWindowExpired(entry.WeeklyWindowStart, timezone.StartOfWeek(now)) { weeklyUsage = 0 windowExpired = true weekStart := timezone.StartOfWeek(now) newWeeklyStart = &weekStart } if monthlyQuotaWindowExpired(entry.MonthlyWindowStart, now) { monthlyUsage = 0 windowExpired = true monthStart := now newMonthlyStart = &monthStart } // 检测到任意窗口过期:用 reset 后的 entry 覆盖 Redis(而非 Delete)。 // 旧实现 Delete 后,期间到达的 IncrUserPlatformQuotaUsage 调用让 Lua 看到 // EXISTS=0 直接 return 0,并发请求的 cost 永久丢失,直到下次 cache MISS 回填。 // 改为 SetCache 原子覆盖:key 不断链,Lua INCR 可在新窗口 entry 上正确累加。 // 超时 50ms:覆盖正常路径与可接受抖动;Redis 异常时 hot path 不阻塞超过此值。 // 用 
context.Background()+短超时,避免请求 ctx 取消导致刷新丢失。 // 显式 setCancel()(而非 defer):缩短 context 生命周期,避免 defer 延迟到函数返回。 if windowExpired && s.cache != nil { refreshed := &UserPlatformQuotaCacheEntry{ DailyUsageUSD: dailyUsage, WeeklyUsageUSD: weeklyUsage, MonthlyUsageUSD: monthlyUsage, SchemaVersion: UserPlatformQuotaCacheSchemaV1, DailyLimitUSD: entry.DailyLimitUSD, WeeklyLimitUSD: entry.WeeklyLimitUSD, MonthlyLimitUSD: entry.MonthlyLimitUSD, DailyWindowStart: newDailyStart, WeeklyWindowStart: newWeeklyStart, MonthlyWindowStart: newMonthlyStart, } ttl := time.Duration(s.cfg.Billing.UserPlatformQuotaCacheTTLSeconds) * time.Second setCtx, setCancel := context.WithTimeout(context.Background(), 50*time.Millisecond) if setErr := s.cache.SetUserPlatformQuotaCache(setCtx, userID, platform, refreshed, ttl); setErr != nil { logger.LegacyPrintf("service.billing_cache", "Warning: refresh expired user platform quota cache failed user=%d platform=%s: %v", userID, platform, setErr) } setCancel() } if entry.DailyLimitUSD != nil && dailyUsage >= *entry.DailyLimitUSD { return withWindowResetsMetadata(ErrUserPlatformDailyQuotaExhausted, nextDailyReset(now)) } if entry.WeeklyLimitUSD != nil && weeklyUsage >= *entry.WeeklyLimitUSD { return withWindowResetsMetadata(ErrUserPlatformWeeklyQuotaExhausted, nextWeeklyReset(now)) } if entry.MonthlyLimitUSD != nil && monthlyUsage >= *entry.MonthlyLimitUSD { return withWindowResetsMetadata(ErrUserPlatformMonthlyQuotaExhausted, nextMonthlyResetFrom(entry.MonthlyWindowStart, now)) } return nil } // --- cache MISS、旧版 entry 或 Redis 故障 → 查 DB(singleflight 合并并发回源)--- // 使用 DoChan 而非 Do:avoid sharing the first caller's ctx among all dedupe followers. 
// 若第一个 caller 的 ctx 被取消(客户端断连),后续 caller 不受影响,仍由各自 ctx 控制超时。 sfKey := strconv.FormatInt(userID, 10) + ":" + platform ch := s.quotaLoadSF.DoChan(sfKey, func() (any, error) { // 子查询用 detached context + 短超时,独立于任何 caller 的请求 ctx, // 防止"第一个 caller ctx 取消"使所有 follower 一起 fail。 bgCtx, bgCancel := context.WithTimeout(context.Background(), 3*time.Second) defer bgCancel() return s.userPlatformQuotaRepo.GetByUserPlatform(bgCtx, userID, platform) }) var ( v any dbErr error ) select { case res := <-ch: v, dbErr = res.Val, res.Err case <-ctx.Done(): // 当前 caller 的 ctx 被取消:fail-open,不阻断 (此请求已无意义)。 logger.LegacyPrintf("service.billing_cache", "Warning: user platform quota check ctx cancelled user=%d platform=%s: %v (fail-open)", userID, platform, ctx.Err()) return nil } if dbErr != nil { logger.LegacyPrintf("service.billing_cache", "Warning: load user platform quota failed user=%d platform=%s: %v (fail-open)", userID, platform, dbErr) return nil } rec, _ := v.(*UserPlatformQuotaRecord) if rec == nil { return nil } now := time.Now() dailyUsage := rec.DailyUsageUSD weeklyUsage := rec.WeeklyUsageUSD monthlyUsage := rec.MonthlyUsageUSD if quotaWindowExpired(rec.DailyWindowStart, timezone.StartOfDay(now)) { dailyUsage = 0 } if quotaWindowExpired(rec.WeeklyWindowStart, timezone.StartOfWeek(now)) { weeklyUsage = 0 } if monthlyQuotaWindowExpired(rec.MonthlyWindowStart, now) { monthlyUsage = 0 } // Redis 故障时 fail-open:不回填,直接用 DB 数据做一次性检查 if cacheErr != nil { if rec.DailyLimitUSD != nil && dailyUsage >= *rec.DailyLimitUSD { return withWindowResetsMetadata(ErrUserPlatformDailyQuotaExhausted, nextDailyReset(now)) } if rec.WeeklyLimitUSD != nil && weeklyUsage >= *rec.WeeklyLimitUSD { return withWindowResetsMetadata(ErrUserPlatformWeeklyQuotaExhausted, nextWeeklyReset(now)) } if rec.MonthlyLimitUSD != nil && monthlyUsage >= *rec.MonthlyLimitUSD { return withWindowResetsMetadata(ErrUserPlatformMonthlyQuotaExhausted, nextMonthlyResetFrom(rec.MonthlyWindowStart, now)) } return nil } // cache MISS 
或旧版 entry → 回填完整 entry(含 limits 和 window_start) newEntry := &UserPlatformQuotaCacheEntry{ DailyUsageUSD: dailyUsage, WeeklyUsageUSD: weeklyUsage, MonthlyUsageUSD: monthlyUsage, SchemaVersion: UserPlatformQuotaCacheSchemaV1, DailyLimitUSD: rec.DailyLimitUSD, WeeklyLimitUSD: rec.WeeklyLimitUSD, MonthlyLimitUSD: rec.MonthlyLimitUSD, DailyWindowStart: rec.DailyWindowStart, WeeklyWindowStart: rec.WeeklyWindowStart, MonthlyWindowStart: rec.MonthlyWindowStart, } if s.cache != nil { ttl := time.Duration(s.cfg.Billing.UserPlatformQuotaCacheTTLSeconds) * time.Second // 与 HIT 过期回填路径(上文 SetCache 调用)保持一致:用 context.Background()+50ms, // 避免请求 ctx 提前取消(客户端断连/上游超时)导致 cache 回填失败, // 让下一次 preflight 仍然 MISS 并击穿到 DB(高并发下增大 DB 压力)。 setCtx, setCancel := context.WithTimeout(context.Background(), 50*time.Millisecond) if setErr := s.cache.SetUserPlatformQuotaCache(setCtx, userID, platform, newEntry, ttl); setErr != nil { logger.LegacyPrintf("service.billing_cache", "Warning: set user platform quota cache failed user=%d platform=%s: %v", userID, platform, setErr) } setCancel() } if rec.DailyLimitUSD != nil && dailyUsage >= *rec.DailyLimitUSD { return withWindowResetsMetadata(ErrUserPlatformDailyQuotaExhausted, nextDailyReset(now)) } if rec.WeeklyLimitUSD != nil && weeklyUsage >= *rec.WeeklyLimitUSD { return withWindowResetsMetadata(ErrUserPlatformWeeklyQuotaExhausted, nextWeeklyReset(now)) } if rec.MonthlyLimitUSD != nil && monthlyUsage >= *rec.MonthlyLimitUSD { return withWindowResetsMetadata(ErrUserPlatformMonthlyQuotaExhausted, nextMonthlyResetFrom(rec.MonthlyWindowStart, now)) } return nil } // withWindowResetsMetadata 给 quota error 附加 window_resets_at metadata(RFC3339)。 func withWindowResetsMetadata(err error, resetAt time.Time) error { appErr, ok := err.(*infraerrors.ApplicationError) if !ok || appErr == nil { return err } return appErr.WithMetadata(map[string]string{ "window_resets_at": resetAt.Format(time.RFC3339), }) } // nextDailyReset 计算下一个日窗口起点(次日全局时区 0 点)。 // 必须与 
timezone.StartOfDay 同口径,否则 Retry-After 会偏差。 func nextDailyReset(now time.Time) time.Time { return timezone.StartOfDay(now).AddDate(0, 0, 1) } // nextWeeklyReset 计算下一个周窗口起点(下周一全局时区 0 点)。 // 必须与 timezone.StartOfWeek 同口径,否则 Retry-After 会偏差。 func nextWeeklyReset(now time.Time) time.Time { return timezone.StartOfWeek(now).AddDate(0, 0, 7) } // nextMonthlyResetFrom 返回 30 天滚动窗口的下次重置时间(start + 30d)。 // start 为 nil(未初始化)或已过期(now-start >= 30d,与 monthlyQuotaWindowExpired 同口径)时 // 退化为 now+30d:过期窗口会在下次 increment 时重置为 now,下次重置即 now+30d; // 否则按 start 计算会得到一个过去的时间,使 Retry-After 落回 fallback 并触发客户端紧凑重试。 func nextMonthlyResetFrom(start *time.Time, now time.Time) time.Time { if start == nil || now.Sub(*start) >= 30*24*time.Hour { return now.Add(30 * 24 * time.Hour) } return start.Add(30 * 24 * time.Hour) } // quotaWindowExpired 判断窗口是否已过期:start 为 nil(未初始化)或在 currWindowStart 之前视为已过期。 func quotaWindowExpired(start *time.Time, currWindowStart time.Time) bool { if start == nil { return true } return start.Before(currWindowStart) } // monthlyQuotaWindowExpired 判断 30 天滚动月度窗口是否已过期。 // 过期条件:now - start >= 30×24h(与订阅模式 NeedsMonthlyReset 语义一致)。 // start 为 nil 时视为已过期(未初始化窗口)。 func monthlyQuotaWindowExpired(start *time.Time, now time.Time) bool { if start == nil { return true } return now.Sub(*start) >= 30*24*time.Hour }