为用户在 anthropic/openai/gemini/antigravity 四个平台上提供日/周/月 三个窗口的 USD 配额管控。配额语义:未设置=不限制,0=禁用,>0=美元上限。 两层模型: - 配置层:系统默认配额,以及 email/linuxdo/oidc/wechat/github/google/ dingtalk 七个鉴权来源的默认配额,存于 settings,以嵌套 JSON 整体读写 (系统 1 个 key + 每个来源 1 个 key),整体替换语义。 - 运行时层:user_platform_quota 表按用户记录实际配额,与配置层解耦。 后端:新增 ent schema 与 140_user_platform_quotas.sql 迁移、repository 与 service 端口、计费链路集成、管理端与用户端读写接口。 前端:管理端设置页配额编辑、用户配额管理 Modal、用户 Dashboard 展示、 中英文案。 Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
451 lines
11 KiB
Go
451 lines
11 KiB
Go
//go:build unit
|
|
|
|
package service
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
type redeemCodeRepoStub struct {
|
|
codesByCode map[string]*RedeemCode
|
|
useCalls []struct {
|
|
id int64
|
|
userID int64
|
|
}
|
|
updateCalls []*RedeemCode
|
|
}
|
|
|
|
func (s *redeemCodeRepoStub) Create(context.Context, *RedeemCode) error {
|
|
panic("unexpected Create call")
|
|
}
|
|
|
|
func (s *redeemCodeRepoStub) CreateBatch(context.Context, []RedeemCode) error {
|
|
panic("unexpected CreateBatch call")
|
|
}
|
|
|
|
func (s *redeemCodeRepoStub) GetByID(context.Context, int64) (*RedeemCode, error) {
|
|
panic("unexpected GetByID call")
|
|
}
|
|
|
|
func (s *redeemCodeRepoStub) GetByCode(_ context.Context, code string) (*RedeemCode, error) {
|
|
if s.codesByCode == nil {
|
|
return nil, ErrRedeemCodeNotFound
|
|
}
|
|
redeemCode, ok := s.codesByCode[code]
|
|
if !ok {
|
|
return nil, ErrRedeemCodeNotFound
|
|
}
|
|
cloned := *redeemCode
|
|
return &cloned, nil
|
|
}
|
|
|
|
func (s *redeemCodeRepoStub) Update(_ context.Context, code *RedeemCode) error {
|
|
if code == nil {
|
|
return nil
|
|
}
|
|
cloned := *code
|
|
s.updateCalls = append(s.updateCalls, &cloned)
|
|
if s.codesByCode == nil {
|
|
s.codesByCode = make(map[string]*RedeemCode)
|
|
}
|
|
s.codesByCode[cloned.Code] = &cloned
|
|
return nil
|
|
}
|
|
|
|
func (s *redeemCodeRepoStub) BatchUpdate(context.Context, []int64, RedeemCodeBatchUpdateFields) (int64, error) {
|
|
panic("unexpected BatchUpdate call")
|
|
}
|
|
|
|
func (s *redeemCodeRepoStub) Delete(context.Context, int64) error {
|
|
panic("unexpected Delete call")
|
|
}
|
|
|
|
func (s *redeemCodeRepoStub) Use(_ context.Context, id, userID int64) error {
|
|
for code, redeemCode := range s.codesByCode {
|
|
if redeemCode.ID != id {
|
|
continue
|
|
}
|
|
now := time.Now().UTC()
|
|
redeemCode.Status = StatusUsed
|
|
redeemCode.UsedBy = &userID
|
|
redeemCode.UsedAt = &now
|
|
s.codesByCode[code] = redeemCode
|
|
s.useCalls = append(s.useCalls, struct {
|
|
id int64
|
|
userID int64
|
|
}{id: id, userID: userID})
|
|
return nil
|
|
}
|
|
return ErrRedeemCodeNotFound
|
|
}
|
|
|
|
func (s *redeemCodeRepoStub) List(context.Context, pagination.PaginationParams) ([]RedeemCode, *pagination.PaginationResult, error) {
|
|
panic("unexpected List call")
|
|
}
|
|
|
|
func (s *redeemCodeRepoStub) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string) ([]RedeemCode, *pagination.PaginationResult, error) {
|
|
panic("unexpected ListWithFilters call")
|
|
}
|
|
|
|
func (s *redeemCodeRepoStub) ListByUser(context.Context, int64, int) ([]RedeemCode, error) {
|
|
panic("unexpected ListByUser call")
|
|
}
|
|
|
|
func (s *redeemCodeRepoStub) ListByUserPaginated(context.Context, int64, pagination.PaginationParams, string) ([]RedeemCode, *pagination.PaginationResult, error) {
|
|
panic("unexpected ListByUserPaginated call")
|
|
}
|
|
|
|
func (s *redeemCodeRepoStub) SumPositiveBalanceByUser(context.Context, int64) (float64, error) {
|
|
panic("unexpected SumPositiveBalanceByUser call")
|
|
}
|
|
|
|
func newOAuthEmailFlowAuthService(
|
|
userRepo UserRepository,
|
|
redeemRepo RedeemCodeRepository,
|
|
refreshTokenCache RefreshTokenCache,
|
|
settings map[string]string,
|
|
emailCache EmailCache,
|
|
quotaRepo UserPlatformQuotaRepository, // 新增
|
|
) *AuthService {
|
|
cfg := &config.Config{
|
|
JWT: config.JWTConfig{
|
|
Secret: "test-secret",
|
|
ExpireHour: 1,
|
|
AccessTokenExpireMinutes: 60,
|
|
RefreshTokenExpireDays: 7,
|
|
},
|
|
Default: config.DefaultConfig{
|
|
UserBalance: 3.5,
|
|
UserConcurrency: 2,
|
|
},
|
|
}
|
|
|
|
settingService := NewSettingService(&settingRepoStub{values: settings}, cfg)
|
|
emailService := NewEmailService(&settingRepoStub{values: settings}, emailCache)
|
|
|
|
return NewAuthService(
|
|
nil,
|
|
userRepo,
|
|
redeemRepo,
|
|
refreshTokenCache,
|
|
cfg,
|
|
settingService,
|
|
emailService,
|
|
nil,
|
|
nil,
|
|
nil,
|
|
nil,
|
|
nil,
|
|
quotaRepo, // 替换原来的 nil
|
|
)
|
|
}
|
|
|
|
func TestRegisterOAuthEmailAccountRollsBackCreatedUserWhenTokenPairGenerationFails(t *testing.T) {
|
|
userRepo := &userRepoStub{nextID: 42}
|
|
redeemRepo := &redeemCodeRepoStub{
|
|
codesByCode: map[string]*RedeemCode{
|
|
"INVITE123": {
|
|
ID: 7,
|
|
Code: "INVITE123",
|
|
Type: RedeemTypeInvitation,
|
|
Status: StatusUnused,
|
|
},
|
|
},
|
|
}
|
|
emailCache := &emailCacheStub{
|
|
data: &VerificationCodeData{
|
|
Code: "246810",
|
|
Attempts: 0,
|
|
CreatedAt: time.Now().UTC(),
|
|
ExpiresAt: time.Now().UTC().Add(15 * time.Minute),
|
|
},
|
|
}
|
|
authService := newOAuthEmailFlowAuthService(
|
|
userRepo,
|
|
redeemRepo,
|
|
nil,
|
|
map[string]string{
|
|
SettingKeyRegistrationEnabled: "true",
|
|
SettingKeyInvitationCodeEnabled: "true",
|
|
SettingKeyEmailVerifyEnabled: "true",
|
|
},
|
|
emailCache,
|
|
nil,
|
|
)
|
|
|
|
tokenPair, user, err := authService.RegisterOAuthEmailAccount(
|
|
context.Background(),
|
|
"fresh@example.com",
|
|
"secret-123",
|
|
"246810",
|
|
"INVITE123",
|
|
"oidc",
|
|
)
|
|
|
|
require.Nil(t, tokenPair)
|
|
require.Nil(t, user)
|
|
require.Error(t, err)
|
|
require.Contains(t, err.Error(), "generate token pair")
|
|
require.Equal(t, []int64{42}, userRepo.deletedIDs)
|
|
require.Len(t, userRepo.created, 1)
|
|
require.Empty(t, redeemRepo.useCalls)
|
|
require.Empty(t, redeemRepo.updateCalls)
|
|
}
|
|
|
|
func TestRegisterOAuthEmailAccountSetsNormalizedSignupSourceOnCreatedUser(t *testing.T) {
|
|
userRepo := &userRepoStub{nextID: 42}
|
|
emailCache := &emailCacheStub{
|
|
data: &VerificationCodeData{
|
|
Code: "246810",
|
|
Attempts: 0,
|
|
CreatedAt: time.Now().UTC(),
|
|
ExpiresAt: time.Now().UTC().Add(15 * time.Minute),
|
|
},
|
|
}
|
|
authService := newOAuthEmailFlowAuthService(
|
|
userRepo,
|
|
&redeemCodeRepoStub{},
|
|
&refreshTokenCacheStub{},
|
|
map[string]string{
|
|
SettingKeyRegistrationEnabled: "true",
|
|
SettingKeyEmailVerifyEnabled: "true",
|
|
},
|
|
emailCache,
|
|
nil,
|
|
)
|
|
|
|
tokenPair, user, err := authService.RegisterOAuthEmailAccount(
|
|
context.Background(),
|
|
"fresh@example.com",
|
|
"secret-123",
|
|
"246810",
|
|
"",
|
|
" OIDC ",
|
|
)
|
|
|
|
require.NoError(t, err)
|
|
require.NotNil(t, tokenPair)
|
|
require.NotNil(t, user)
|
|
require.Len(t, userRepo.created, 1)
|
|
require.Equal(t, "oidc", userRepo.created[0].SignupSource)
|
|
}
|
|
|
|
func TestRegisterOAuthEmailAccountKeepsGitHubAndGoogleSignupSource(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
email string
|
|
signupSource string
|
|
want string
|
|
}{
|
|
{
|
|
name: "github",
|
|
email: "github@example.com",
|
|
signupSource: " GitHub ",
|
|
want: "github",
|
|
},
|
|
{
|
|
name: "google",
|
|
email: "google@example.com",
|
|
signupSource: " Google ",
|
|
want: "google",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
userRepo := &userRepoStub{nextID: 43}
|
|
emailCache := &emailCacheStub{
|
|
data: &VerificationCodeData{
|
|
Code: "246810",
|
|
Attempts: 0,
|
|
CreatedAt: time.Now().UTC(),
|
|
ExpiresAt: time.Now().UTC().Add(15 * time.Minute),
|
|
},
|
|
}
|
|
authService := newOAuthEmailFlowAuthService(
|
|
userRepo,
|
|
&redeemCodeRepoStub{},
|
|
&refreshTokenCacheStub{},
|
|
map[string]string{
|
|
SettingKeyRegistrationEnabled: "true",
|
|
SettingKeyEmailVerifyEnabled: "true",
|
|
},
|
|
emailCache,
|
|
nil,
|
|
)
|
|
|
|
tokenPair, user, err := authService.RegisterOAuthEmailAccount(
|
|
context.Background(),
|
|
tt.email,
|
|
"secret-123",
|
|
"246810",
|
|
"",
|
|
tt.signupSource,
|
|
)
|
|
|
|
require.NoError(t, err)
|
|
require.NotNil(t, tokenPair)
|
|
require.NotNil(t, user)
|
|
require.Len(t, userRepo.created, 1)
|
|
require.Equal(t, tt.want, userRepo.created[0].SignupSource)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRegisterOAuthEmailAccountFallsBackUnknownSignupSourceToEmail(t *testing.T) {
|
|
userRepo := &userRepoStub{nextID: 43}
|
|
emailCache := &emailCacheStub{
|
|
data: &VerificationCodeData{
|
|
Code: "246810",
|
|
Attempts: 0,
|
|
CreatedAt: time.Now().UTC(),
|
|
ExpiresAt: time.Now().UTC().Add(15 * time.Minute),
|
|
},
|
|
}
|
|
authService := newOAuthEmailFlowAuthService(
|
|
userRepo,
|
|
&redeemCodeRepoStub{},
|
|
&refreshTokenCacheStub{},
|
|
map[string]string{
|
|
SettingKeyRegistrationEnabled: "true",
|
|
SettingKeyEmailVerifyEnabled: "true",
|
|
},
|
|
emailCache,
|
|
nil,
|
|
)
|
|
|
|
tokenPair, user, err := authService.RegisterOAuthEmailAccount(
|
|
context.Background(),
|
|
"fallback@example.com",
|
|
"secret-123",
|
|
"246810",
|
|
"",
|
|
"unknown-provider",
|
|
)
|
|
|
|
require.NoError(t, err)
|
|
require.NotNil(t, tokenPair)
|
|
require.NotNil(t, user)
|
|
require.Len(t, userRepo.created, 1)
|
|
require.Equal(t, "email", userRepo.created[0].SignupSource)
|
|
}
|
|
|
|
func TestRollbackOAuthEmailAccountCreationRestoresInvitationUsage(t *testing.T) {
|
|
userRepo := &userRepoStub{}
|
|
redeemRepo := &redeemCodeRepoStub{
|
|
codesByCode: map[string]*RedeemCode{
|
|
"INVITE123": {
|
|
ID: 7,
|
|
Code: "INVITE123",
|
|
Type: RedeemTypeInvitation,
|
|
Status: StatusUsed,
|
|
UsedBy: func() *int64 {
|
|
v := int64(42)
|
|
return &v
|
|
}(),
|
|
UsedAt: func() *time.Time {
|
|
v := time.Now().UTC()
|
|
return &v
|
|
}(),
|
|
},
|
|
},
|
|
}
|
|
authService := newOAuthEmailFlowAuthService(
|
|
userRepo,
|
|
redeemRepo,
|
|
&refreshTokenCacheStub{},
|
|
map[string]string{
|
|
SettingKeyRegistrationEnabled: "true",
|
|
SettingKeyInvitationCodeEnabled: "true",
|
|
},
|
|
&emailCacheStub{},
|
|
nil,
|
|
)
|
|
|
|
err := authService.RollbackOAuthEmailAccountCreation(context.Background(), 42, "INVITE123")
|
|
|
|
require.NoError(t, err)
|
|
require.Equal(t, []int64{42}, userRepo.deletedIDs)
|
|
require.Len(t, redeemRepo.updateCalls, 1)
|
|
require.Equal(t, StatusUnused, redeemRepo.updateCalls[0].Status)
|
|
require.Nil(t, redeemRepo.updateCalls[0].UsedBy)
|
|
require.Nil(t, redeemRepo.updateCalls[0].UsedAt)
|
|
}
|
|
|
|
func TestRollbackOAuthEmailAccountCreationPropagatesDeleteError(t *testing.T) {
|
|
userRepo := &userRepoStub{deleteErr: errors.New("delete failed")}
|
|
authService := newOAuthEmailFlowAuthService(
|
|
userRepo,
|
|
&redeemCodeRepoStub{},
|
|
&refreshTokenCacheStub{},
|
|
map[string]string{
|
|
SettingKeyRegistrationEnabled: "true",
|
|
},
|
|
&emailCacheStub{},
|
|
nil,
|
|
)
|
|
|
|
err := authService.RollbackOAuthEmailAccountCreation(context.Background(), 42, "")
|
|
|
|
require.Error(t, err)
|
|
require.Contains(t, err.Error(), "delete created oauth user")
|
|
}
|
|
|
|
func TestFinalizeOAuthEmailAccount_SnapshotsPlatformQuotaDefaults(t *testing.T) {
|
|
userRepo := &userRepoStub{nextID: 99}
|
|
quotaRepo := &userPlatformQuotaRepoStub{}
|
|
|
|
authService := newOAuthEmailFlowAuthService(
|
|
userRepo,
|
|
nil,
|
|
&refreshTokenCacheStub{},
|
|
map[string]string{
|
|
SettingKeyRegistrationEnabled: "true",
|
|
SettingKeyEmailVerifyEnabled: "true",
|
|
SettingKeyDefaultPlatformQuotas: `{"anthropic": {"daily": 5.5}}`,
|
|
},
|
|
&emailCacheStub{},
|
|
quotaRepo,
|
|
)
|
|
|
|
user := &User{
|
|
ID: 99,
|
|
Email: "newuser@example.com",
|
|
Role: RoleUser,
|
|
Status: StatusActive,
|
|
SignupSource: "oidc",
|
|
}
|
|
|
|
err := authService.FinalizeOAuthEmailAccount(
|
|
context.Background(),
|
|
user,
|
|
"",
|
|
"oidc",
|
|
"",
|
|
)
|
|
|
|
require.NoError(t, err)
|
|
|
|
require.Len(t, quotaRepo.bulkInsertCalls, 1, "snapshotPlatformQuotaDefaults must call BulkInsertInitial once on successful OAuth signup")
|
|
|
|
records := quotaRepo.bulkInsertCalls[0]
|
|
var anthropicRecord *UserPlatformQuotaRecord
|
|
for i := range records {
|
|
if records[i].Platform == "anthropic" {
|
|
anthropicRecord = &records[i]
|
|
break
|
|
}
|
|
}
|
|
require.NotNil(t, anthropicRecord, "expected anthropic platform record")
|
|
require.Equal(t, int64(99), anthropicRecord.UserID)
|
|
require.NotNil(t, anthropicRecord.DailyLimitUSD)
|
|
require.InDelta(t, 5.5, *anthropicRecord.DailyLimitUSD, 0.0001)
|
|
}
|