532 lines
17 KiB
Go
532 lines
17 KiB
Go
package service
|
||
|
||
import (
|
||
"context"
|
||
"crypto/rand"
|
||
"crypto/sha256"
|
||
"encoding/binary"
|
||
"encoding/hex"
|
||
"os"
|
||
"strconv"
|
||
"sync"
|
||
"sync/atomic"
|
||
"time"
|
||
|
||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||
"golang.org/x/sync/singleflight"
|
||
)
|
||
|
||
// ConcurrencyCache 定义并发控制的缓存接口
|
||
// 使用有序集合存储槽位,按时间戳清理过期条目
|
||
type ConcurrencyCache interface {
|
||
// 账号槽位管理
|
||
// 键格式: concurrency:account:{accountID}(有序集合,成员为 requestID)
|
||
AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error)
|
||
ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error
|
||
GetAccountConcurrency(ctx context.Context, accountID int64) (int, error)
|
||
GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error)
|
||
|
||
// 账号等待队列(账号级)
|
||
IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error)
|
||
DecrementAccountWaitCount(ctx context.Context, accountID int64) error
|
||
GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error)
|
||
|
||
// 用户槽位管理
|
||
// 键格式: concurrency:user:{userID}(有序集合,成员为 requestID)
|
||
AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error)
|
||
ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error
|
||
GetUserConcurrency(ctx context.Context, userID int64) (int, error)
|
||
|
||
// 等待队列计数(只在首次创建时设置 TTL)
|
||
IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error)
|
||
DecrementWaitCount(ctx context.Context, userID int64) error
|
||
|
||
// 批量负载查询(只读)
|
||
GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error)
|
||
GetUsersLoadBatch(ctx context.Context, users []UserWithConcurrency) (map[int64]*UserLoadInfo, error)
|
||
|
||
// 清理过期槽位(后台任务)
|
||
CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error
|
||
|
||
// 启动时清理旧进程遗留槽位与等待计数
|
||
CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error
|
||
}
|
||
|
||
var (
|
||
requestIDPrefix = initRequestIDPrefix()
|
||
requestIDCounter atomic.Uint64
|
||
)
|
||
|
||
func initRequestIDPrefix() string {
|
||
b := make([]byte, 8)
|
||
if _, err := rand.Read(b); err == nil {
|
||
return "r" + strconv.FormatUint(binary.BigEndian.Uint64(b), 36)
|
||
}
|
||
fallback := uint64(time.Now().UnixNano()) ^ (uint64(os.Getpid()) << 16)
|
||
return "r" + strconv.FormatUint(fallback, 36)
|
||
}
|
||
|
||
func RequestIDPrefix() string {
|
||
return requestIDPrefix
|
||
}
|
||
|
||
func generateRequestID() string {
|
||
seq := requestIDCounter.Add(1)
|
||
return requestIDPrefix + "-" + strconv.FormatUint(seq, 36)
|
||
}
|
||
|
||
func (s *ConcurrencyService) CleanupStaleProcessSlots(ctx context.Context) error {
|
||
if s == nil || s.cache == nil {
|
||
return nil
|
||
}
|
||
return s.cache.CleanupStaleProcessSlots(ctx, RequestIDPrefix())
|
||
}
|
||
|
||
const (
|
||
// 默认等待队列额外槽位
|
||
defaultExtraWaitSlots = 20
|
||
|
||
defaultAccountLoadBatchCacheTTL = 200 * time.Millisecond
|
||
accountLoadBatchFetchTimeout = 3 * time.Second
|
||
maxAccountLoadBatchCacheEntries = 256
|
||
)
|
||
|
||
// ConcurrencyService 管理账号和用户的并发限制。
|
||
type ConcurrencyService struct {
|
||
cache ConcurrencyCache
|
||
|
||
accountLoadCacheTTL atomic.Int64
|
||
accountLoadCacheMu sync.RWMutex
|
||
accountLoadCache map[string]cachedAccountLoadBatch
|
||
accountLoadGroup singleflight.Group
|
||
}
|
||
|
||
type cachedAccountLoadBatch struct {
|
||
loadMap map[int64]*AccountLoadInfo
|
||
expiresAt time.Time
|
||
}
|
||
|
||
// NewConcurrencyService 创建并发控制服务。
|
||
func NewConcurrencyService(cache ConcurrencyCache) *ConcurrencyService {
|
||
svc := &ConcurrencyService{
|
||
cache: cache,
|
||
accountLoadCache: make(map[string]cachedAccountLoadBatch),
|
||
}
|
||
svc.SetAccountLoadBatchCacheTTL(defaultAccountLoadBatchCacheTTL)
|
||
return svc
|
||
}
|
||
|
||
// SetAccountLoadBatchCacheTTL 设置账号负载批量读取的极短 TTL 缓存;非正数表示禁用缓存。
|
||
func (s *ConcurrencyService) SetAccountLoadBatchCacheTTL(ttl time.Duration) {
|
||
if s == nil {
|
||
return
|
||
}
|
||
s.accountLoadCacheTTL.Store(int64(ttl))
|
||
if ttl <= 0 {
|
||
s.accountLoadCacheMu.Lock()
|
||
s.accountLoadCache = make(map[string]cachedAccountLoadBatch)
|
||
s.accountLoadCacheMu.Unlock()
|
||
}
|
||
}
|
||
|
||
// AcquireResult represents the result of acquiring a concurrency slot
|
||
type AcquireResult struct {
|
||
Acquired bool
|
||
ReleaseFunc func() // Must be called when done (typically via defer)
|
||
}
|
||
|
||
type AccountWithConcurrency struct {
|
||
ID int64
|
||
MaxConcurrency int
|
||
}
|
||
|
||
type UserWithConcurrency struct {
|
||
ID int64
|
||
MaxConcurrency int
|
||
}
|
||
|
||
type AccountLoadInfo struct {
|
||
AccountID int64
|
||
CurrentConcurrency int
|
||
WaitingCount int
|
||
LoadRate int // 0-100+ (percent)
|
||
}
|
||
|
||
type UserLoadInfo struct {
|
||
UserID int64
|
||
CurrentConcurrency int
|
||
WaitingCount int
|
||
LoadRate int // 0-100+ (percent)
|
||
}
|
||
|
||
// AcquireAccountSlot attempts to acquire a concurrency slot for an account.
|
||
// If the account is at max concurrency, it waits until a slot is available or timeout.
|
||
// Returns a release function that MUST be called when the request completes.
|
||
func (s *ConcurrencyService) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) {
|
||
// If maxConcurrency is 0 or negative, no limit
|
||
if maxConcurrency <= 0 {
|
||
return &AcquireResult{
|
||
Acquired: true,
|
||
ReleaseFunc: func() {}, // no-op
|
||
}, nil
|
||
}
|
||
|
||
// Generate unique request ID for this slot
|
||
requestID := generateRequestID()
|
||
|
||
acquired, err := s.cache.AcquireAccountSlot(ctx, accountID, maxConcurrency, requestID)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
if acquired {
|
||
return &AcquireResult{
|
||
Acquired: true,
|
||
ReleaseFunc: func() {
|
||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||
defer cancel()
|
||
if err := s.cache.ReleaseAccountSlot(bgCtx, accountID, requestID); err != nil {
|
||
logger.LegacyPrintf("service.concurrency", "Warning: failed to release account slot for %d (req=%s): %v", accountID, requestID, err)
|
||
}
|
||
},
|
||
}, nil
|
||
}
|
||
|
||
return &AcquireResult{
|
||
Acquired: false,
|
||
ReleaseFunc: nil,
|
||
}, nil
|
||
}
|
||
|
||
// AcquireUserSlot attempts to acquire a concurrency slot for a user.
|
||
// If the user is at max concurrency, it waits until a slot is available or timeout.
|
||
// Returns a release function that MUST be called when the request completes.
|
||
func (s *ConcurrencyService) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (*AcquireResult, error) {
|
||
// If maxConcurrency is 0 or negative, no limit
|
||
if maxConcurrency <= 0 {
|
||
return &AcquireResult{
|
||
Acquired: true,
|
||
ReleaseFunc: func() {}, // no-op
|
||
}, nil
|
||
}
|
||
|
||
// Generate unique request ID for this slot
|
||
requestID := generateRequestID()
|
||
|
||
acquired, err := s.cache.AcquireUserSlot(ctx, userID, maxConcurrency, requestID)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
if acquired {
|
||
return &AcquireResult{
|
||
Acquired: true,
|
||
ReleaseFunc: func() {
|
||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||
defer cancel()
|
||
if err := s.cache.ReleaseUserSlot(bgCtx, userID, requestID); err != nil {
|
||
logger.LegacyPrintf("service.concurrency", "Warning: failed to release user slot for %d (req=%s): %v", userID, requestID, err)
|
||
}
|
||
},
|
||
}, nil
|
||
}
|
||
|
||
return &AcquireResult{
|
||
Acquired: false,
|
||
ReleaseFunc: nil,
|
||
}, nil
|
||
}
|
||
|
||
// ============================================
|
||
// Wait Queue Count Methods
|
||
// ============================================
|
||
|
||
// IncrementWaitCount attempts to increment the wait queue counter for a user.
|
||
// Returns true if successful, false if the wait queue is full.
|
||
// maxWait should be user.Concurrency + defaultExtraWaitSlots
|
||
func (s *ConcurrencyService) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
|
||
if s.cache == nil {
|
||
// Redis not available, allow request
|
||
return true, nil
|
||
}
|
||
|
||
result, err := s.cache.IncrementWaitCount(ctx, userID, maxWait)
|
||
if err != nil {
|
||
// On error, allow the request to proceed (fail open)
|
||
logger.LegacyPrintf("service.concurrency", "Warning: increment wait count failed for user %d: %v", userID, err)
|
||
return true, nil
|
||
}
|
||
return result, nil
|
||
}
|
||
|
||
// DecrementWaitCount decrements the wait queue counter for a user.
|
||
// Should be called when a request completes or exits the wait queue.
|
||
func (s *ConcurrencyService) DecrementWaitCount(ctx context.Context, userID int64) {
|
||
if s.cache == nil {
|
||
return
|
||
}
|
||
|
||
// Use background context to ensure decrement even if original context is cancelled
|
||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||
defer cancel()
|
||
|
||
if err := s.cache.DecrementWaitCount(bgCtx, userID); err != nil {
|
||
logger.LegacyPrintf("service.concurrency", "Warning: decrement wait count failed for user %d: %v", userID, err)
|
||
}
|
||
}
|
||
|
||
// IncrementAccountWaitCount increments the wait queue counter for an account.
|
||
func (s *ConcurrencyService) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
|
||
if s.cache == nil {
|
||
return true, nil
|
||
}
|
||
|
||
result, err := s.cache.IncrementAccountWaitCount(ctx, accountID, maxWait)
|
||
if err != nil {
|
||
logger.LegacyPrintf("service.concurrency", "Warning: increment wait count failed for account %d: %v", accountID, err)
|
||
return true, nil
|
||
}
|
||
return result, nil
|
||
}
|
||
|
||
// DecrementAccountWaitCount decrements the wait queue counter for an account.
|
||
func (s *ConcurrencyService) DecrementAccountWaitCount(ctx context.Context, accountID int64) {
|
||
if s.cache == nil {
|
||
return
|
||
}
|
||
|
||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||
defer cancel()
|
||
|
||
if err := s.cache.DecrementAccountWaitCount(bgCtx, accountID); err != nil {
|
||
logger.LegacyPrintf("service.concurrency", "Warning: decrement wait count failed for account %d: %v", accountID, err)
|
||
}
|
||
}
|
||
|
||
// GetAccountWaitingCount gets current wait queue count for an account.
|
||
func (s *ConcurrencyService) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
|
||
if s.cache == nil {
|
||
return 0, nil
|
||
}
|
||
return s.cache.GetAccountWaitingCount(ctx, accountID)
|
||
}
|
||
|
||
// CalculateMaxWait calculates the maximum wait queue size for a user
|
||
// maxWait = userConcurrency + defaultExtraWaitSlots
|
||
func CalculateMaxWait(userConcurrency int) int {
|
||
if userConcurrency <= 0 {
|
||
userConcurrency = 1
|
||
}
|
||
return userConcurrency + defaultExtraWaitSlots
|
||
}
|
||
|
||
// GetAccountsLoadBatch 批量获取账号负载信息。
|
||
func (s *ConcurrencyService) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
|
||
return s.getAccountsLoadBatch(ctx, accounts, true)
|
||
}
|
||
|
||
// GetAccountsLoadBatchFresh 绕过极短 TTL 缓存,用于抢槽失败后的实时刷新兜底。
|
||
func (s *ConcurrencyService) GetAccountsLoadBatchFresh(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
|
||
return s.getAccountsLoadBatch(ctx, accounts, false)
|
||
}
|
||
|
||
func (s *ConcurrencyService) getAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency, allowCache bool) (map[int64]*AccountLoadInfo, error) {
|
||
if len(accounts) == 0 {
|
||
return map[int64]*AccountLoadInfo{}, nil
|
||
}
|
||
if s.cache == nil {
|
||
return map[int64]*AccountLoadInfo{}, nil
|
||
}
|
||
|
||
ttl := time.Duration(s.accountLoadCacheTTL.Load())
|
||
if !allowCache || ttl <= 0 {
|
||
return s.fetchAccountsLoadBatch(ctx, accounts)
|
||
}
|
||
|
||
key := accountLoadBatchCacheKey(accounts)
|
||
if cached, ok := s.getCachedAccountLoadBatch(key, time.Now()); ok {
|
||
return cached, nil
|
||
}
|
||
|
||
value, err, _ := s.accountLoadGroup.Do(key, func() (any, error) {
|
||
now := time.Now()
|
||
if cached, ok := s.getCachedAccountLoadBatch(key, now); ok {
|
||
return cached, nil
|
||
}
|
||
loadMap, fetchErr := s.fetchAccountsLoadBatch(ctx, accounts)
|
||
if fetchErr != nil {
|
||
return nil, fetchErr
|
||
}
|
||
cached := cloneAccountLoadMap(loadMap)
|
||
s.storeCachedAccountLoadBatch(key, cached, now.Add(ttl))
|
||
return cached, nil
|
||
})
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
loadMap, _ := value.(map[int64]*AccountLoadInfo)
|
||
if loadMap == nil {
|
||
return map[int64]*AccountLoadInfo{}, nil
|
||
}
|
||
return loadMap, nil
|
||
}
|
||
|
||
func (s *ConcurrencyService) fetchAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
|
||
if s.cache == nil {
|
||
return map[int64]*AccountLoadInfo{}, nil
|
||
}
|
||
baseCtx := context.Background()
|
||
if ctx != nil {
|
||
baseCtx = context.WithoutCancel(ctx)
|
||
}
|
||
redisCtx, cancel := context.WithTimeout(baseCtx, accountLoadBatchFetchTimeout)
|
||
defer cancel()
|
||
return s.cache.GetAccountsLoadBatch(redisCtx, accounts)
|
||
}
|
||
|
||
func (s *ConcurrencyService) getCachedAccountLoadBatch(key string, now time.Time) (map[int64]*AccountLoadInfo, bool) {
|
||
s.accountLoadCacheMu.RLock()
|
||
cached, ok := s.accountLoadCache[key]
|
||
s.accountLoadCacheMu.RUnlock()
|
||
if !ok {
|
||
return nil, false
|
||
}
|
||
if !now.Before(cached.expiresAt) {
|
||
s.accountLoadCacheMu.Lock()
|
||
if current, exists := s.accountLoadCache[key]; exists && !now.Before(current.expiresAt) {
|
||
delete(s.accountLoadCache, key)
|
||
}
|
||
s.accountLoadCacheMu.Unlock()
|
||
return nil, false
|
||
}
|
||
return cached.loadMap, true
|
||
}
|
||
|
||
func (s *ConcurrencyService) storeCachedAccountLoadBatch(key string, loadMap map[int64]*AccountLoadInfo, expiresAt time.Time) {
|
||
s.accountLoadCacheMu.Lock()
|
||
if s.accountLoadCache == nil {
|
||
s.accountLoadCache = make(map[string]cachedAccountLoadBatch)
|
||
}
|
||
if len(s.accountLoadCache) >= maxAccountLoadBatchCacheEntries {
|
||
now := time.Now()
|
||
for cacheKey, cached := range s.accountLoadCache {
|
||
if !now.Before(cached.expiresAt) {
|
||
delete(s.accountLoadCache, cacheKey)
|
||
}
|
||
}
|
||
for len(s.accountLoadCache) >= maxAccountLoadBatchCacheEntries {
|
||
for cacheKey := range s.accountLoadCache {
|
||
delete(s.accountLoadCache, cacheKey)
|
||
break
|
||
}
|
||
}
|
||
}
|
||
s.accountLoadCache[key] = cachedAccountLoadBatch{
|
||
loadMap: loadMap,
|
||
expiresAt: expiresAt,
|
||
}
|
||
s.accountLoadCacheMu.Unlock()
|
||
}
|
||
|
||
func accountLoadBatchCacheKey(accounts []AccountWithConcurrency) string {
|
||
hash := sha256.New()
|
||
var buf [16]byte
|
||
for _, account := range accounts {
|
||
binary.LittleEndian.PutUint64(buf[:8], uint64(account.ID))
|
||
binary.LittleEndian.PutUint64(buf[8:], uint64(int64(account.MaxConcurrency)))
|
||
_, _ = hash.Write(buf[:])
|
||
}
|
||
sum := hash.Sum(nil)
|
||
return strconv.Itoa(len(accounts)) + ":" + hex.EncodeToString(sum)
|
||
}
|
||
|
||
func cloneAccountLoadMap(loadMap map[int64]*AccountLoadInfo) map[int64]*AccountLoadInfo {
|
||
if len(loadMap) == 0 {
|
||
return map[int64]*AccountLoadInfo{}
|
||
}
|
||
clone := make(map[int64]*AccountLoadInfo, len(loadMap))
|
||
for accountID, loadInfo := range loadMap {
|
||
if loadInfo == nil {
|
||
clone[accountID] = nil
|
||
continue
|
||
}
|
||
copied := *loadInfo
|
||
clone[accountID] = &copied
|
||
}
|
||
return clone
|
||
}
|
||
|
||
// GetUsersLoadBatch returns load info for multiple users.
|
||
func (s *ConcurrencyService) GetUsersLoadBatch(ctx context.Context, users []UserWithConcurrency) (map[int64]*UserLoadInfo, error) {
|
||
if s.cache == nil {
|
||
return map[int64]*UserLoadInfo{}, nil
|
||
}
|
||
return s.cache.GetUsersLoadBatch(ctx, users)
|
||
}
|
||
|
||
// CleanupExpiredAccountSlots removes expired slots for one account (background task).
|
||
func (s *ConcurrencyService) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
|
||
if s.cache == nil {
|
||
return nil
|
||
}
|
||
return s.cache.CleanupExpiredAccountSlots(ctx, accountID)
|
||
}
|
||
|
||
// StartSlotCleanupWorker starts a background cleanup worker for expired account slots.
|
||
func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepository, interval time.Duration) {
|
||
if s == nil || s.cache == nil || accountRepo == nil || interval <= 0 {
|
||
return
|
||
}
|
||
|
||
runCleanup := func() {
|
||
listCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||
accounts, err := accountRepo.ListSchedulable(listCtx)
|
||
cancel()
|
||
if err != nil {
|
||
logger.LegacyPrintf("service.concurrency", "Warning: list schedulable accounts failed: %v", err)
|
||
return
|
||
}
|
||
for _, account := range accounts {
|
||
accountCtx, accountCancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||
err := s.cache.CleanupExpiredAccountSlots(accountCtx, account.ID)
|
||
accountCancel()
|
||
if err != nil {
|
||
logger.LegacyPrintf("service.concurrency", "Warning: cleanup expired slots failed for account %d: %v", account.ID, err)
|
||
}
|
||
}
|
||
}
|
||
|
||
go func() {
|
||
ticker := time.NewTicker(interval)
|
||
defer ticker.Stop()
|
||
|
||
runCleanup()
|
||
for range ticker.C {
|
||
runCleanup()
|
||
}
|
||
}()
|
||
}
|
||
|
||
// GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts.
|
||
// Uses a detached context with timeout to prevent HTTP request cancellation from
|
||
// causing the entire batch to fail (which would show all concurrency as 0).
|
||
func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
|
||
if len(accountIDs) == 0 {
|
||
return map[int64]int{}, nil
|
||
}
|
||
if s.cache == nil {
|
||
result := make(map[int64]int, len(accountIDs))
|
||
for _, accountID := range accountIDs {
|
||
result[accountID] = 0
|
||
}
|
||
return result, nil
|
||
}
|
||
|
||
// Use a detached context so that a cancelled HTTP request doesn't cause
|
||
// the Redis pipeline to fail and return all-zero concurrency counts.
|
||
redisCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||
defer cancel()
|
||
|
||
return s.cache.GetAccountConcurrencyBatch(redisCtx, accountIDs)
|
||
}
|