sub2api/backend/internal/service/billing_cache_service_singleflight_test.go
DaydreamCoding 6b39b344d8 feat(quota): 用户 × 平台 USD 配额
为用户在 anthropic/openai/gemini/antigravity 四个平台上提供日/周/月
三个窗口的 USD 配额管控。配额语义:未设置=不限制,0=禁用,>0=美元上限。

两层模型:
- 配置层:系统默认配额,以及 email/linuxdo/oidc/wechat/github/google/
  dingtalk 七个鉴权来源的默认配额,存于 settings,以嵌套 JSON 整体读写
  (系统 1 个 key + 每个来源 1 个 key),整体替换语义。
- 运行时层:user_platform_quota 表按用户记录实际配额,与配置层解耦。

后端:新增 ent schema 与 140_user_platform_quotas.sql 迁移、repository
与 service 端口、计费链路集成、管理端与用户端读写接口。
前端:管理端设置页配额编辑、用户配额管理 Modal、用户 Dashboard 展示、
中英文案。

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-26 10:49:20 +08:00

156 lines
4.2 KiB
Go

//go:build unit
package service
import (
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
type billingCacheMissStub struct {
setBalanceCalls atomic.Int64
}
func (s *billingCacheMissStub) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
return 0, errors.New("cache miss")
}
func (s *billingCacheMissStub) SetUserBalance(ctx context.Context, userID int64, balance float64) error {
s.setBalanceCalls.Add(1)
return nil
}
func (s *billingCacheMissStub) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
return nil
}
func (s *billingCacheMissStub) InvalidateUserBalance(ctx context.Context, userID int64) error {
return nil
}
func (s *billingCacheMissStub) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*SubscriptionCacheData, error) {
return nil, errors.New("cache miss")
}
func (s *billingCacheMissStub) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *SubscriptionCacheData) error {
return nil
}
func (s *billingCacheMissStub) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error {
return nil
}
func (s *billingCacheMissStub) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error {
return nil
}
func (s *billingCacheMissStub) GetAPIKeyRateLimit(ctx context.Context, keyID int64) (*APIKeyRateLimitCacheData, error) {
return nil, errors.New("cache miss")
}
func (s *billingCacheMissStub) SetAPIKeyRateLimit(ctx context.Context, keyID int64, data *APIKeyRateLimitCacheData) error {
return nil
}
func (s *billingCacheMissStub) UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error {
return nil
}
func (s *billingCacheMissStub) InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error {
return nil
}
func (s *billingCacheMissStub) GetUserPlatformQuotaCache(ctx context.Context, userID int64, platform string) (*UserPlatformQuotaCacheEntry, bool, error) {
return nil, false, nil
}
func (s *billingCacheMissStub) SetUserPlatformQuotaCache(ctx context.Context, userID int64, platform string, entry *UserPlatformQuotaCacheEntry, ttl time.Duration) error {
return nil
}
func (s *billingCacheMissStub) DeleteUserPlatformQuotaCache(ctx context.Context, userID int64, platform string) error {
return nil
}
func (s *billingCacheMissStub) IncrUserPlatformQuotaUsageCache(ctx context.Context, userID int64, platform string, cost float64, ttl time.Duration) error {
return nil
}
type balanceLoadUserRepoStub struct {
mockUserRepo
calls atomic.Int64
delay time.Duration
balance float64
}
func (s *balanceLoadUserRepoStub) GetByID(ctx context.Context, id int64) (*User, error) {
s.calls.Add(1)
if s.delay > 0 {
select {
case <-time.After(s.delay):
case <-ctx.Done():
return nil, ctx.Err()
}
}
return &User{ID: id, Balance: s.balance}, nil
}
func (s *balanceLoadUserRepoStub) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) {
return nil, nil
}
func (s *balanceLoadUserRepoStub) UnbindUserAuthProvider(context.Context, int64, string) error {
return nil
}
func TestBillingCacheServiceGetUserBalance_Singleflight(t *testing.T) {
cache := &billingCacheMissStub{}
userRepo := &balanceLoadUserRepoStub{
delay: 80 * time.Millisecond,
balance: 12.34,
}
svc := NewBillingCacheService(cache, userRepo, nil, nil, nil, nil, &config.Config{}, nil)
t.Cleanup(svc.Stop)
const goroutines = 16
start := make(chan struct{})
var wg sync.WaitGroup
errCh := make(chan error, goroutines)
balCh := make(chan float64, goroutines)
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func() {
defer wg.Done()
<-start
bal, err := svc.GetUserBalance(context.Background(), 99)
errCh <- err
balCh <- bal
}()
}
close(start)
wg.Wait()
close(errCh)
close(balCh)
for err := range errCh {
require.NoError(t, err)
}
for bal := range balCh {
require.Equal(t, 12.34, bal)
}
require.Equal(t, int64(1), userRepo.calls.Load(), "并发穿透应被 singleflight 合并")
require.Eventually(t, func() bool {
return cache.setBalanceCalls.Load() >= 1
}, time.Second, 10*time.Millisecond)
}