feat(quota): 用户 × 平台 USD 配额

为用户在 anthropic/openai/gemini/antigravity 四个平台上提供日/周/月
三个窗口的 USD 配额管控。配额语义:未设置=不限制,0=禁用,>0=美元上限。

两层模型:
- 配置层:系统默认配额,以及 email/linuxdo/oidc/wechat/github/google/
  dingtalk 七个鉴权来源的默认配额,存于 settings,以嵌套 JSON 整体读写
  (系统 1 个 key + 每个来源 1 个 key),整体替换语义。
- 运行时层:user_platform_quota 表按用户记录实际配额,与配置层解耦。

后端:新增 ent schema 与 140_user_platform_quotas.sql 迁移、repository
与 service 端口、计费链路集成、管理端与用户端读写接口。
前端:管理端设置页配额编辑、用户配额管理 Modal、用户 Dashboard 展示、
中英文案。

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
DaydreamCoding 2026-05-25 14:26:11 +08:00 committed by QTom
parent 2f70d965bf
commit 6b39b344d8
123 changed files with 14220 additions and 232 deletions

View File

@ -33,7 +33,7 @@ func main() {
}()
userRepo := repository.NewUserRepository(client, sqlDB)
authService := service.NewAuthService(client, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil)
authService := service.NewAuthService(client, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil, nil)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

View File

@ -63,7 +63,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
apiKeyRepository := repository.NewAPIKeyRepository(client, db)
userRPMCache := repository.NewUserRPMCache(redisClient)
userGroupRateRepository := repository.NewUserGroupRateRepository(db)
billingCacheService := service.ProvideBillingCacheService(billingCache, userRepository, userSubscriptionRepository, apiKeyRepository, userRPMCache, userGroupRateRepository, configConfig)
userPlatformQuotaRepository := repository.NewUserPlatformQuotaRepository(client)
serviceUserPlatformQuotaRepository := repository.NewUserPlatformQuotaServiceAdapter(userPlatformQuotaRepository)
billingCacheService := service.ProvideBillingCacheService(billingCache, userRepository, userSubscriptionRepository, apiKeyRepository, userRPMCache, userGroupRateRepository, configConfig, serviceUserPlatformQuotaRepository)
apiKeyCache := repository.NewAPIKeyCache(redisClient)
apiKeyService := service.ProvideAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, userGroupRateRepository, apiKeyCache, configConfig, billingCacheService)
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
@ -71,7 +73,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig)
affiliateRepository := repository.NewAffiliateRepository(client, db)
affiliateService := service.NewAffiliateService(affiliateRepository, settingService, apiKeyAuthCacheInvalidator, billingCacheService)
authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService, affiliateService)
authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService, affiliateService, serviceUserPlatformQuotaRepository)
userService := service.NewUserService(userRepository, settingRepository, apiKeyAuthCacheInvalidator, billingCache)
redeemCache := repository.NewRedeemCache(redisClient)
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator, affiliateService)
@ -85,7 +87,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
userAttributeValueRepository := repository.NewUserAttributeValueRepository(client)
userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository)
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, redeemService, totpService, userAttributeService)
userHandler := handler.NewUserHandler(userService, authService, emailService, emailCache, affiliateService)
userHandler := handler.NewUserHandler(userService, authService, emailService, emailCache, affiliateService, serviceUserPlatformQuotaRepository)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageLogRepository := repository.NewUsageLogRepository(client, db)
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
@ -143,9 +145,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
modelPricingResolver := service.NewModelPricingResolver(channelService, billingService)
notificationEmailService := service.NewNotificationEmailService(settingRepository, emailService)
balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository, notificationEmailService)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService, settingService)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService, settingService, serviceUserPlatformQuotaRepository)
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, userRPMCache, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository, privacyClientFactory, openAIGatewayService)
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService, serviceUserPlatformQuotaRepository, billingCache)
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
rpmCache := repository.NewRPMCache(redisClient)
groupCapacityService := service.NewGroupCapacityService(accountRepository, groupRepository, concurrencyService, sessionLimitCache, rpmCache)
@ -190,7 +192,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
opsRepository := repository.NewOpsRepository(db)
identityService := service.NewIdentityService(identityCache)
digestSessionStore := service.NewDigestSessionStore()
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService, serviceUserPlatformQuotaRepository)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)

View File

@ -43,7 +43,7 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
subscriptionExpirySvc := service.NewSubscriptionExpiryService(nil, time.Second)
pricingSvc := service.NewPricingService(cfg, nil)
emailQueueSvc := service.NewEmailQueueService(nil, 1)
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg)
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg, nil)
idempotencyCleanupSvc := service.NewIdempotencyCleanupService(nil, cfg)
schedulerSnapshotSvc := service.NewSchedulerSnapshotService(nil, nil, nil, nil, cfg)
opsSystemLogSinkSvc := service.NewOpsSystemLogSink(nil)

View File

@ -48,6 +48,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
"github.com/Wei-Shaw/sub2api/ent/userplatformquota"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
stdsql "database/sql"
@ -124,6 +125,8 @@ type Client struct {
UserAttributeDefinition *UserAttributeDefinitionClient
// UserAttributeValue is the client for interacting with the UserAttributeValue builders.
UserAttributeValue *UserAttributeValueClient
// UserPlatformQuota is the client for interacting with the UserPlatformQuota builders.
UserPlatformQuota *UserPlatformQuotaClient
// UserSubscription is the client for interacting with the UserSubscription builders.
UserSubscription *UserSubscriptionClient
}
@ -170,6 +173,7 @@ func (c *Client) init() {
c.UserAllowedGroup = NewUserAllowedGroupClient(c.config)
c.UserAttributeDefinition = NewUserAttributeDefinitionClient(c.config)
c.UserAttributeValue = NewUserAttributeValueClient(c.config)
c.UserPlatformQuota = NewUserPlatformQuotaClient(c.config)
c.UserSubscription = NewUserSubscriptionClient(c.config)
}
@ -296,6 +300,7 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) {
UserAllowedGroup: NewUserAllowedGroupClient(cfg),
UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg),
UserAttributeValue: NewUserAttributeValueClient(cfg),
UserPlatformQuota: NewUserPlatformQuotaClient(cfg),
UserSubscription: NewUserSubscriptionClient(cfg),
}, nil
}
@ -349,6 +354,7 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error)
UserAllowedGroup: NewUserAllowedGroupClient(cfg),
UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg),
UserAttributeValue: NewUserAttributeValueClient(cfg),
UserPlatformQuota: NewUserPlatformQuotaClient(cfg),
UserSubscription: NewUserSubscriptionClient(cfg),
}, nil
}
@ -388,7 +394,7 @@ func (c *Client) Use(hooks ...Hook) {
c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting,
c.SubscriptionPlan, c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog,
c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
c.UserSubscription,
c.UserPlatformQuota, c.UserSubscription,
} {
n.Use(hooks...)
}
@ -407,7 +413,7 @@ func (c *Client) Intercept(interceptors ...Interceptor) {
c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting,
c.SubscriptionPlan, c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog,
c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
c.UserSubscription,
c.UserPlatformQuota, c.UserSubscription,
} {
n.Intercept(interceptors...)
}
@ -482,6 +488,8 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) {
return c.UserAttributeDefinition.mutate(ctx, m)
case *UserAttributeValueMutation:
return c.UserAttributeValue.mutate(ctx, m)
case *UserPlatformQuotaMutation:
return c.UserPlatformQuota.mutate(ctx, m)
case *UserSubscriptionMutation:
return c.UserSubscription.mutate(ctx, m)
default:
@ -5341,6 +5349,22 @@ func (c *UserClient) QueryPendingAuthSessions(_m *User) *PendingAuthSessionQuery
return query
}
// QueryPlatformQuotas queries the platform_quotas edge of a User.
func (c *UserClient) QueryPlatformQuotas(_m *User) *UserPlatformQuotaQuery {
query := (&UserPlatformQuotaClient{config: c.config}).Query()
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
id := _m.ID
step := sqlgraph.NewStep(
sqlgraph.From(user.Table, user.FieldID, id),
sqlgraph.To(userplatformquota.Table, userplatformquota.FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, user.PlatformQuotasTable, user.PlatformQuotasColumn),
)
fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
return fromV, nil
}
return query
}
// QueryUserAllowedGroups queries the user_allowed_groups edge of a User.
func (c *UserClient) QueryUserAllowedGroups(_m *User) *UserAllowedGroupQuery {
query := (&UserAllowedGroupClient{config: c.config}).Query()
@ -5816,6 +5840,157 @@ func (c *UserAttributeValueClient) mutate(ctx context.Context, m *UserAttributeV
}
}
// UserPlatformQuotaClient is a client for the UserPlatformQuota schema.
type UserPlatformQuotaClient struct {
config
}
// NewUserPlatformQuotaClient returns a client for the UserPlatformQuota from the given config.
func NewUserPlatformQuotaClient(c config) *UserPlatformQuotaClient {
return &UserPlatformQuotaClient{config: c}
}
// Use adds a list of mutation hooks to the hooks stack.
// A call to `Use(f, g, h)` equals to `userplatformquota.Hooks(f(g(h())))`.
func (c *UserPlatformQuotaClient) Use(hooks ...Hook) {
c.hooks.UserPlatformQuota = append(c.hooks.UserPlatformQuota, hooks...)
}
// Intercept adds a list of query interceptors to the interceptors stack.
// A call to `Intercept(f, g, h)` equals to `userplatformquota.Intercept(f(g(h())))`.
func (c *UserPlatformQuotaClient) Intercept(interceptors ...Interceptor) {
c.inters.UserPlatformQuota = append(c.inters.UserPlatformQuota, interceptors...)
}
// Create returns a builder for creating a UserPlatformQuota entity.
func (c *UserPlatformQuotaClient) Create() *UserPlatformQuotaCreate {
mutation := newUserPlatformQuotaMutation(c.config, OpCreate)
return &UserPlatformQuotaCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// CreateBulk returns a builder for creating a bulk of UserPlatformQuota entities.
func (c *UserPlatformQuotaClient) CreateBulk(builders ...*UserPlatformQuotaCreate) *UserPlatformQuotaCreateBulk {
return &UserPlatformQuotaCreateBulk{config: c.config, builders: builders}
}
// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
// a builder and applies setFunc on it.
func (c *UserPlatformQuotaClient) MapCreateBulk(slice any, setFunc func(*UserPlatformQuotaCreate, int)) *UserPlatformQuotaCreateBulk {
rv := reflect.ValueOf(slice)
if rv.Kind() != reflect.Slice {
return &UserPlatformQuotaCreateBulk{err: fmt.Errorf("calling to UserPlatformQuotaClient.MapCreateBulk with wrong type %T, need slice", slice)}
}
builders := make([]*UserPlatformQuotaCreate, rv.Len())
for i := 0; i < rv.Len(); i++ {
builders[i] = c.Create()
setFunc(builders[i], i)
}
return &UserPlatformQuotaCreateBulk{config: c.config, builders: builders}
}
// Update returns an update builder for UserPlatformQuota.
func (c *UserPlatformQuotaClient) Update() *UserPlatformQuotaUpdate {
mutation := newUserPlatformQuotaMutation(c.config, OpUpdate)
return &UserPlatformQuotaUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// UpdateOne returns an update builder for the given entity.
func (c *UserPlatformQuotaClient) UpdateOne(_m *UserPlatformQuota) *UserPlatformQuotaUpdateOne {
mutation := newUserPlatformQuotaMutation(c.config, OpUpdateOne, withUserPlatformQuota(_m))
return &UserPlatformQuotaUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// UpdateOneID returns an update builder for the given id.
func (c *UserPlatformQuotaClient) UpdateOneID(id int64) *UserPlatformQuotaUpdateOne {
mutation := newUserPlatformQuotaMutation(c.config, OpUpdateOne, withUserPlatformQuotaID(id))
return &UserPlatformQuotaUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// Delete returns a delete builder for UserPlatformQuota.
func (c *UserPlatformQuotaClient) Delete() *UserPlatformQuotaDelete {
mutation := newUserPlatformQuotaMutation(c.config, OpDelete)
return &UserPlatformQuotaDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// DeleteOne returns a builder for deleting the given entity.
func (c *UserPlatformQuotaClient) DeleteOne(_m *UserPlatformQuota) *UserPlatformQuotaDeleteOne {
return c.DeleteOneID(_m.ID)
}
// DeleteOneID returns a builder for deleting the given entity by its id.
func (c *UserPlatformQuotaClient) DeleteOneID(id int64) *UserPlatformQuotaDeleteOne {
builder := c.Delete().Where(userplatformquota.ID(id))
builder.mutation.id = &id
builder.mutation.op = OpDeleteOne
return &UserPlatformQuotaDeleteOne{builder}
}
// Query returns a query builder for UserPlatformQuota.
func (c *UserPlatformQuotaClient) Query() *UserPlatformQuotaQuery {
return &UserPlatformQuotaQuery{
config: c.config,
ctx: &QueryContext{Type: TypeUserPlatformQuota},
inters: c.Interceptors(),
}
}
// Get returns a UserPlatformQuota entity by its id.
func (c *UserPlatformQuotaClient) Get(ctx context.Context, id int64) (*UserPlatformQuota, error) {
return c.Query().Where(userplatformquota.ID(id)).Only(ctx)
}
// GetX is like Get, but panics if an error occurs.
func (c *UserPlatformQuotaClient) GetX(ctx context.Context, id int64) *UserPlatformQuota {
obj, err := c.Get(ctx, id)
if err != nil {
panic(err)
}
return obj
}
// QueryUser queries the user edge of a UserPlatformQuota.
func (c *UserPlatformQuotaClient) QueryUser(_m *UserPlatformQuota) *UserQuery {
query := (&UserClient{config: c.config}).Query()
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
id := _m.ID
step := sqlgraph.NewStep(
sqlgraph.From(userplatformquota.Table, userplatformquota.FieldID, id),
sqlgraph.To(user.Table, user.FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, userplatformquota.UserTable, userplatformquota.UserColumn),
)
fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
return fromV, nil
}
return query
}
// Hooks returns the client hooks.
func (c *UserPlatformQuotaClient) Hooks() []Hook {
hooks := c.hooks.UserPlatformQuota
return append(hooks[:len(hooks):len(hooks)], userplatformquota.Hooks[:]...)
}
// Interceptors returns the client interceptors.
func (c *UserPlatformQuotaClient) Interceptors() []Interceptor {
inters := c.inters.UserPlatformQuota
return append(inters[:len(inters):len(inters)], userplatformquota.Interceptors[:]...)
}
func (c *UserPlatformQuotaClient) mutate(ctx context.Context, m *UserPlatformQuotaMutation) (Value, error) {
switch m.Op() {
case OpCreate:
return (&UserPlatformQuotaCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdate:
return (&UserPlatformQuotaUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdateOne:
return (&UserPlatformQuotaUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpDelete, OpDeleteOne:
return (&UserPlatformQuotaDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
default:
return nil, fmt.Errorf("ent: unknown UserPlatformQuota mutation op: %q", m.Op())
}
}
// UserSubscriptionClient is a client for the UserSubscription schema.
type UserSubscriptionClient struct {
config
@ -6025,7 +6200,8 @@ type (
PaymentOrder, PaymentProviderInstance, PendingAuthSession, PromoCode,
PromoCodeUsage, Proxy, RedeemCode, SecuritySecret, Setting, SubscriptionPlan,
TLSFingerprintProfile, UsageCleanupTask, UsageLog, User, UserAllowedGroup,
UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Hook
UserAttributeDefinition, UserAttributeValue, UserPlatformQuota,
UserSubscription []ent.Hook
}
inters struct {
APIKey, Account, AccountGroup, Announcement, AnnouncementRead, AuthIdentity,
@ -6035,7 +6211,8 @@ type (
PaymentOrder, PaymentProviderInstance, PendingAuthSession, PromoCode,
PromoCodeUsage, Proxy, RedeemCode, SecuritySecret, Setting, SubscriptionPlan,
TLSFingerprintProfile, UsageCleanupTask, UsageLog, User, UserAllowedGroup,
UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Interceptor
UserAttributeDefinition, UserAttributeValue, UserPlatformQuota,
UserSubscription []ent.Interceptor
}
)

View File

@ -45,6 +45,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
"github.com/Wei-Shaw/sub2api/ent/userplatformquota"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
)
@ -139,6 +140,7 @@ func checkColumn(t, c string) error {
userallowedgroup.Table: userallowedgroup.ValidColumn,
userattributedefinition.Table: userattributedefinition.ValidColumn,
userattributevalue.Table: userattributevalue.ValidColumn,
userplatformquota.Table: userplatformquota.ValidColumn,
usersubscription.Table: usersubscription.ValidColumn,
})
})

View File

@ -405,6 +405,18 @@ func (f UserAttributeValueFunc) Mutate(ctx context.Context, m ent.Mutation) (ent
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.UserAttributeValueMutation", m)
}
// The UserPlatformQuotaFunc type is an adapter to allow the use of ordinary
// function as UserPlatformQuota mutator.
type UserPlatformQuotaFunc func(context.Context, *ent.UserPlatformQuotaMutation) (ent.Value, error)
// Mutate calls f(ctx, m).
func (f UserPlatformQuotaFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
if mv, ok := m.(*ent.UserPlatformQuotaMutation); ok {
return f(ctx, mv)
}
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.UserPlatformQuotaMutation", m)
}
// The UserSubscriptionFunc type is an adapter to allow the use of ordinary
// function as UserSubscription mutator.
type UserSubscriptionFunc func(context.Context, *ent.UserSubscriptionMutation) (ent.Value, error)

View File

@ -42,6 +42,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
"github.com/Wei-Shaw/sub2api/ent/userplatformquota"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
)
@ -992,6 +993,33 @@ func (f TraverseUserAttributeValue) Traverse(ctx context.Context, q ent.Query) e
return fmt.Errorf("unexpected query type %T. expect *ent.UserAttributeValueQuery", q)
}
// The UserPlatformQuotaFunc type is an adapter to allow the use of ordinary function as a Querier.
type UserPlatformQuotaFunc func(context.Context, *ent.UserPlatformQuotaQuery) (ent.Value, error)
// Query calls f(ctx, q).
func (f UserPlatformQuotaFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
if q, ok := q.(*ent.UserPlatformQuotaQuery); ok {
return f(ctx, q)
}
return nil, fmt.Errorf("unexpected query type %T. expect *ent.UserPlatformQuotaQuery", q)
}
// The TraverseUserPlatformQuota type is an adapter to allow the use of ordinary function as Traverser.
type TraverseUserPlatformQuota func(context.Context, *ent.UserPlatformQuotaQuery) error
// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
func (f TraverseUserPlatformQuota) Intercept(next ent.Querier) ent.Querier {
return next
}
// Traverse calls f(ctx, q).
func (f TraverseUserPlatformQuota) Traverse(ctx context.Context, q ent.Query) error {
if q, ok := q.(*ent.UserPlatformQuotaQuery); ok {
return f(ctx, q)
}
return fmt.Errorf("unexpected query type %T. expect *ent.UserPlatformQuotaQuery", q)
}
// The UserSubscriptionFunc type is an adapter to allow the use of ordinary function as a Querier.
type UserSubscriptionFunc func(context.Context, *ent.UserSubscriptionQuery) (ent.Value, error)
@ -1088,6 +1116,8 @@ func NewQuery(q ent.Query) (Query, error) {
return &query[*ent.UserAttributeDefinitionQuery, predicate.UserAttributeDefinition, userattributedefinition.OrderOption]{typ: ent.TypeUserAttributeDefinition, tq: q}, nil
case *ent.UserAttributeValueQuery:
return &query[*ent.UserAttributeValueQuery, predicate.UserAttributeValue, userattributevalue.OrderOption]{typ: ent.TypeUserAttributeValue, tq: q}, nil
case *ent.UserPlatformQuotaQuery:
return &query[*ent.UserPlatformQuotaQuery, predicate.UserPlatformQuota, userplatformquota.OrderOption]{typ: ent.TypeUserPlatformQuota, tq: q}, nil
case *ent.UserSubscriptionQuery:
return &query[*ent.UserSubscriptionQuery, predicate.UserSubscription, usersubscription.OrderOption]{typ: ent.TypeUserSubscription, tq: q}, nil
default:

View File

@ -1612,6 +1612,53 @@ var (
},
},
}
// UserPlatformQuotasColumns holds the columns for the "user_platform_quotas" table.
UserPlatformQuotasColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "deleted_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "platform", Type: field.TypeString, Size: 32},
{Name: "daily_limit_usd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,10)"}},
{Name: "weekly_limit_usd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,10)"}},
{Name: "monthly_limit_usd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,10)"}},
{Name: "daily_usage_usd", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,10)"}},
{Name: "weekly_usage_usd", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,10)"}},
{Name: "monthly_usage_usd", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,10)"}},
{Name: "daily_window_start", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "weekly_window_start", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "monthly_window_start", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "user_id", Type: field.TypeInt64},
}
// UserPlatformQuotasTable holds the schema information for the "user_platform_quotas" table.
UserPlatformQuotasTable = &schema.Table{
Name: "user_platform_quotas",
Columns: UserPlatformQuotasColumns,
PrimaryKey: []*schema.Column{UserPlatformQuotasColumns[0]},
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "user_platform_quotas_users_platform_quotas",
Columns: []*schema.Column{UserPlatformQuotasColumns[14]},
RefColumns: []*schema.Column{UsersColumns[0]},
OnDelete: schema.NoAction,
},
},
Indexes: []*schema.Index{
{
Name: "userplatformquota_user_id_platform",
Unique: true,
Columns: []*schema.Column{UserPlatformQuotasColumns[14], UserPlatformQuotasColumns[4]},
Annotation: &entsql.IndexAnnotation{
Where: "deleted_at IS NULL",
},
},
{
Name: "userplatformquota_user_id",
Unique: false,
Columns: []*schema.Column{UserPlatformQuotasColumns[14]},
},
},
}
// UserSubscriptionsColumns holds the columns for the "user_subscriptions" table.
UserSubscriptionsColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
@ -1736,6 +1783,7 @@ var (
UserAllowedGroupsTable,
UserAttributeDefinitionsTable,
UserAttributeValuesTable,
UserPlatformQuotasTable,
UserSubscriptionsTable,
}
)
@ -1869,6 +1917,10 @@ func init() {
UserAttributeValuesTable.Annotation = &entsql.Annotation{
Table: "user_attribute_values",
}
UserPlatformQuotasTable.ForeignKeys[0].RefTable = UsersTable
UserPlatformQuotasTable.Annotation = &entsql.Annotation{
Table: "user_platform_quotas",
}
UserSubscriptionsTable.ForeignKeys[0].RefTable = GroupsTable
UserSubscriptionsTable.ForeignKeys[1].RefTable = UsersTable
UserSubscriptionsTable.ForeignKeys[2].RefTable = UsersTable

File diff suppressed because it is too large Load Diff

View File

@ -105,5 +105,8 @@ type UserAttributeDefinition func(*sql.Selector)
// UserAttributeValue is the predicate function for userattributevalue builders.
type UserAttributeValue func(*sql.Selector)
// UserPlatformQuota is the predicate function for userplatformquota builders.
type UserPlatformQuota func(*sql.Selector)
// UserSubscription is the predicate function for usersubscription builders.
type UserSubscription func(*sql.Selector)

View File

@ -39,6 +39,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
"github.com/Wei-Shaw/sub2api/ent/userplatformquota"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/domain"
)
@ -1997,6 +1998,56 @@ func init() {
userattributevalueDescValue := userattributevalueFields[2].Descriptor()
// userattributevalue.DefaultValue holds the default value on creation for the value field.
userattributevalue.DefaultValue = userattributevalueDescValue.Default.(string)
userplatformquotaMixin := schema.UserPlatformQuota{}.Mixin()
userplatformquotaMixinHooks1 := userplatformquotaMixin[1].Hooks()
userplatformquota.Hooks[0] = userplatformquotaMixinHooks1[0]
userplatformquotaMixinInters1 := userplatformquotaMixin[1].Interceptors()
userplatformquota.Interceptors[0] = userplatformquotaMixinInters1[0]
userplatformquotaMixinFields0 := userplatformquotaMixin[0].Fields()
_ = userplatformquotaMixinFields0
userplatformquotaFields := schema.UserPlatformQuota{}.Fields()
_ = userplatformquotaFields
// userplatformquotaDescCreatedAt is the schema descriptor for created_at field.
userplatformquotaDescCreatedAt := userplatformquotaMixinFields0[0].Descriptor()
// userplatformquota.DefaultCreatedAt holds the default value on creation for the created_at field.
userplatformquota.DefaultCreatedAt = userplatformquotaDescCreatedAt.Default.(func() time.Time)
// userplatformquotaDescUpdatedAt is the schema descriptor for updated_at field.
userplatformquotaDescUpdatedAt := userplatformquotaMixinFields0[1].Descriptor()
// userplatformquota.DefaultUpdatedAt holds the default value on creation for the updated_at field.
userplatformquota.DefaultUpdatedAt = userplatformquotaDescUpdatedAt.Default.(func() time.Time)
// userplatformquota.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
userplatformquota.UpdateDefaultUpdatedAt = userplatformquotaDescUpdatedAt.UpdateDefault.(func() time.Time)
// userplatformquotaDescPlatform is the schema descriptor for platform field.
userplatformquotaDescPlatform := userplatformquotaFields[1].Descriptor()
// userplatformquota.PlatformValidator is a validator for the "platform" field. It is called by the builders before save.
userplatformquota.PlatformValidator = func() func(string) error {
validators := userplatformquotaDescPlatform.Validators
fns := [...]func(string) error{
validators[0].(func(string) error),
validators[1].(func(string) error),
validators[2].(func(string) error),
}
return func(platform string) error {
for _, fn := range fns {
if err := fn(platform); err != nil {
return err
}
}
return nil
}
}()
// userplatformquotaDescDailyUsageUsd is the schema descriptor for daily_usage_usd field.
userplatformquotaDescDailyUsageUsd := userplatformquotaFields[5].Descriptor()
// userplatformquota.DefaultDailyUsageUsd holds the default value on creation for the daily_usage_usd field.
userplatformquota.DefaultDailyUsageUsd = userplatformquotaDescDailyUsageUsd.Default.(float64)
// userplatformquotaDescWeeklyUsageUsd is the schema descriptor for weekly_usage_usd field.
userplatformquotaDescWeeklyUsageUsd := userplatformquotaFields[6].Descriptor()
// userplatformquota.DefaultWeeklyUsageUsd holds the default value on creation for the weekly_usage_usd field.
userplatformquota.DefaultWeeklyUsageUsd = userplatformquotaDescWeeklyUsageUsd.Default.(float64)
// userplatformquotaDescMonthlyUsageUsd is the schema descriptor for monthly_usage_usd field.
userplatformquotaDescMonthlyUsageUsd := userplatformquotaFields[7].Descriptor()
// userplatformquota.DefaultMonthlyUsageUsd holds the default value on creation for the monthly_usage_usd field.
userplatformquota.DefaultMonthlyUsageUsd = userplatformquotaDescMonthlyUsageUsd.Default.(float64)
usersubscriptionMixin := schema.UserSubscription{}.Mixin()
usersubscriptionMixinHooks1 := usersubscriptionMixin[1].Hooks()
usersubscription.Hooks[0] = usersubscriptionMixinHooks1[0]

View File

@ -131,6 +131,7 @@ func (User) Edges() []ent.Edge {
edge.To("auth_identities", AuthIdentity.Type).
Annotations(entsql.OnDelete(entsql.Cascade)),
edge.To("pending_auth_sessions", PendingAuthSession.Type),
edge.To("platform_quotas", UserPlatformQuota.Type),
}
}

View File

@ -0,0 +1,113 @@
package schema
import (
"fmt"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/entsql"
"entgo.io/ent/schema"
"entgo.io/ent/schema/edge"
"entgo.io/ent/schema/field"
"entgo.io/ent/schema/index"
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
)
// UserPlatformQuota holds the schema definition for per-user per-platform quota.
type UserPlatformQuota struct {
ent.Schema
}
func (UserPlatformQuota) Annotations() []schema.Annotation {
return []schema.Annotation{
entsql.Annotation{Table: "user_platform_quotas"},
}
}
func (UserPlatformQuota) Mixin() []ent.Mixin {
return []ent.Mixin{
mixins.TimeMixin{},
mixins.SoftDeleteMixin{},
}
}
func (UserPlatformQuota) Fields() []ent.Field {
return []ent.Field{
field.Int64("user_id"),
field.String("platform").
MaxLen(32).
NotEmpty().
Validate(func(s string) error {
// 注意:平台列表的单一权威源为 service.AllowedQuotaPlatforms
// 此处为 ent 构建期约束,需与 service.AllowedQuotaPlatforms 保持同步。
switch s {
case "anthropic", "openai", "gemini", "antigravity":
return nil
default:
return fmt.Errorf("platform %q is not allowed", s)
}
}),
// 日 / 周 / 月 USD 上限:
// nil / not set → 无限额(完全放行)
// 0 → 完全禁用(任何请求都会被拒绝,因为 usage >= 0 恒成立)
// > 0 → USD 限额上限
field.Float("daily_limit_usd").
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "decimal(20,10)"}),
field.Float("weekly_limit_usd").
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "decimal(20,10)"}),
field.Float("monthly_limit_usd").
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "decimal(20,10)"}),
// 当前窗口已用量USDpreflight 时与 limit 比较)
field.Float("daily_usage_usd").
Default(0).
SchemaType(map[string]string{dialect.Postgres: "decimal(20,10)"}),
field.Float("weekly_usage_usd").
Default(0).
SchemaType(map[string]string{dialect.Postgres: "decimal(20,10)"}),
field.Float("monthly_usage_usd").
Default(0).
SchemaType(map[string]string{dialect.Postgres: "decimal(20,10)"}),
// 窗口起点NULL = 首次还未初始化,由 InitWindowStarts 用 COALESCE 兜底)
field.Time("daily_window_start").
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
field.Time("weekly_window_start").
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
field.Time("monthly_window_start").
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
}
}
func (UserPlatformQuota) Edges() []ent.Edge {
return []ent.Edge{
edge.From("user", User.Type).
Ref("platform_quotas").
Field("user_id").
Unique().
Required(),
}
}
func (UserPlatformQuota) Indexes() []ent.Index {
return []ent.Index{
// 软删除友好:只对未删记录唯一
index.Fields("user_id", "platform").
Unique().
Annotations(entsql.IndexWhere("deleted_at IS NULL")),
index.Fields("user_id"),
}
}

View File

@ -80,6 +80,8 @@ type Tx struct {
UserAttributeDefinition *UserAttributeDefinitionClient
// UserAttributeValue is the client for interacting with the UserAttributeValue builders.
UserAttributeValue *UserAttributeValueClient
// UserPlatformQuota is the client for interacting with the UserPlatformQuota builders.
UserPlatformQuota *UserPlatformQuotaClient
// UserSubscription is the client for interacting with the UserSubscription builders.
UserSubscription *UserSubscriptionClient
@ -246,6 +248,7 @@ func (tx *Tx) init() {
tx.UserAllowedGroup = NewUserAllowedGroupClient(tx.config)
tx.UserAttributeDefinition = NewUserAttributeDefinitionClient(tx.config)
tx.UserAttributeValue = NewUserAttributeValueClient(tx.config)
tx.UserPlatformQuota = NewUserPlatformQuotaClient(tx.config)
tx.UserSubscription = NewUserSubscriptionClient(tx.config)
}

View File

@ -95,11 +95,13 @@ type UserEdges struct {
AuthIdentities []*AuthIdentity `json:"auth_identities,omitempty"`
// PendingAuthSessions holds the value of the pending_auth_sessions edge.
PendingAuthSessions []*PendingAuthSession `json:"pending_auth_sessions,omitempty"`
// PlatformQuotas holds the value of the platform_quotas edge.
PlatformQuotas []*UserPlatformQuota `json:"platform_quotas,omitempty"`
// UserAllowedGroups holds the value of the user_allowed_groups edge.
UserAllowedGroups []*UserAllowedGroup `json:"user_allowed_groups,omitempty"`
// loadedTypes holds the information for reporting if a
// type was loaded (or requested) in eager-loading or not.
loadedTypes [13]bool
loadedTypes [14]bool
}
// APIKeysOrErr returns the APIKeys value or an error if the edge
@ -210,10 +212,19 @@ func (e UserEdges) PendingAuthSessionsOrErr() ([]*PendingAuthSession, error) {
return nil, &NotLoadedError{edge: "pending_auth_sessions"}
}
// PlatformQuotasOrErr returns the PlatformQuotas value or an error if the edge
// was not loaded in eager-loading.
func (e UserEdges) PlatformQuotasOrErr() ([]*UserPlatformQuota, error) {
if e.loadedTypes[12] {
return e.PlatformQuotas, nil
}
return nil, &NotLoadedError{edge: "platform_quotas"}
}
// UserAllowedGroupsOrErr returns the UserAllowedGroups value or an error if the edge
// was not loaded in eager-loading.
func (e UserEdges) UserAllowedGroupsOrErr() ([]*UserAllowedGroup, error) {
if e.loadedTypes[12] {
if e.loadedTypes[13] {
return e.UserAllowedGroups, nil
}
return nil, &NotLoadedError{edge: "user_allowed_groups"}
@ -472,6 +483,11 @@ func (_m *User) QueryPendingAuthSessions() *PendingAuthSessionQuery {
return NewUserClient(_m.config).QueryPendingAuthSessions(_m)
}
// QueryPlatformQuotas queries the "platform_quotas" edge of the User entity.
func (_m *User) QueryPlatformQuotas() *UserPlatformQuotaQuery {
return NewUserClient(_m.config).QueryPlatformQuotas(_m)
}
// QueryUserAllowedGroups queries the "user_allowed_groups" edge of the User entity.
func (_m *User) QueryUserAllowedGroups() *UserAllowedGroupQuery {
return NewUserClient(_m.config).QueryUserAllowedGroups(_m)

View File

@ -85,6 +85,8 @@ const (
EdgeAuthIdentities = "auth_identities"
// EdgePendingAuthSessions holds the string denoting the pending_auth_sessions edge name in mutations.
EdgePendingAuthSessions = "pending_auth_sessions"
// EdgePlatformQuotas holds the string denoting the platform_quotas edge name in mutations.
EdgePlatformQuotas = "platform_quotas"
// EdgeUserAllowedGroups holds the string denoting the user_allowed_groups edge name in mutations.
EdgeUserAllowedGroups = "user_allowed_groups"
// Table holds the table name of the user in the database.
@ -171,6 +173,13 @@ const (
PendingAuthSessionsInverseTable = "pending_auth_sessions"
// PendingAuthSessionsColumn is the table column denoting the pending_auth_sessions relation/edge.
PendingAuthSessionsColumn = "target_user_id"
// PlatformQuotasTable is the table that holds the platform_quotas relation/edge.
PlatformQuotasTable = "user_platform_quotas"
// PlatformQuotasInverseTable is the table name for the UserPlatformQuota entity.
// It exists in this package in order to avoid circular dependency with the "userplatformquota" package.
PlatformQuotasInverseTable = "user_platform_quotas"
// PlatformQuotasColumn is the table column denoting the platform_quotas relation/edge.
PlatformQuotasColumn = "user_id"
// UserAllowedGroupsTable is the table that holds the user_allowed_groups relation/edge.
UserAllowedGroupsTable = "user_allowed_groups"
// UserAllowedGroupsInverseTable is the table name for the UserAllowedGroup entity.
@ -569,6 +578,20 @@ func ByPendingAuthSessions(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOpti
}
}
// ByPlatformQuotasCount orders the results by platform_quotas count.
func ByPlatformQuotasCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborsCount(s, newPlatformQuotasStep(), opts...)
}
}
// ByPlatformQuotas orders the results by platform_quotas terms.
func ByPlatformQuotas(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newPlatformQuotasStep(), append([]sql.OrderTerm{term}, terms...)...)
}
}
// ByUserAllowedGroupsCount orders the results by user_allowed_groups count.
func ByUserAllowedGroupsCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
@ -666,6 +689,13 @@ func newPendingAuthSessionsStep() *sqlgraph.Step {
sqlgraph.Edge(sqlgraph.O2M, false, PendingAuthSessionsTable, PendingAuthSessionsColumn),
)
}
func newPlatformQuotasStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(PlatformQuotasInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, PlatformQuotasTable, PlatformQuotasColumn),
)
}
func newUserAllowedGroupsStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),

View File

@ -1616,6 +1616,29 @@ func HasPendingAuthSessionsWith(preds ...predicate.PendingAuthSession) predicate
})
}
// HasPlatformQuotas applies the HasEdge predicate on the "platform_quotas" edge.
func HasPlatformQuotas() predicate.User {
return predicate.User(func(s *sql.Selector) {
step := sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, PlatformQuotasTable, PlatformQuotasColumn),
)
sqlgraph.HasNeighbors(s, step)
})
}
// HasPlatformQuotasWith applies the HasEdge predicate on the "platform_quotas" edge with a given conditions (other predicates).
func HasPlatformQuotasWith(preds ...predicate.UserPlatformQuota) predicate.User {
return predicate.User(func(s *sql.Selector) {
step := newPlatformQuotasStep()
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
for _, p := range preds {
p(s)
}
})
})
}
// HasUserAllowedGroups applies the HasEdge predicate on the "user_allowed_groups" edge.
func HasUserAllowedGroups() predicate.User {
return predicate.User(func(s *sql.Selector) {

View File

@ -22,6 +22,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
"github.com/Wei-Shaw/sub2api/ent/userplatformquota"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
)
@ -519,6 +520,21 @@ func (_c *UserCreate) AddPendingAuthSessions(v ...*PendingAuthSession) *UserCrea
return _c.AddPendingAuthSessionIDs(ids...)
}
// AddPlatformQuotaIDs adds the "platform_quotas" edge to the UserPlatformQuota entity by IDs.
func (_c *UserCreate) AddPlatformQuotaIDs(ids ...int64) *UserCreate {
_c.mutation.AddPlatformQuotaIDs(ids...)
return _c
}
// AddPlatformQuotas adds the "platform_quotas" edges to the UserPlatformQuota entity.
func (_c *UserCreate) AddPlatformQuotas(v ...*UserPlatformQuota) *UserCreate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _c.AddPlatformQuotaIDs(ids...)
}
// Mutation returns the UserMutation object of the builder.
func (_c *UserCreate) Mutation() *UserMutation {
return _c.mutation
@ -1023,6 +1039,22 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
}
_spec.Edges = append(_spec.Edges, edge)
}
if nodes := _c.mutation.PlatformQuotasIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PlatformQuotasTable,
Columns: []string{user.PlatformQuotasColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges = append(_spec.Edges, edge)
}
return _node, _spec
}

View File

@ -26,6 +26,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
"github.com/Wei-Shaw/sub2api/ent/userplatformquota"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
)
@ -48,6 +49,7 @@ type UserQuery struct {
withPaymentOrders *PaymentOrderQuery
withAuthIdentities *AuthIdentityQuery
withPendingAuthSessions *PendingAuthSessionQuery
withPlatformQuotas *UserPlatformQuotaQuery
withUserAllowedGroups *UserAllowedGroupQuery
modifiers []func(*sql.Selector)
// intermediate query (i.e. traversal path).
@ -350,6 +352,28 @@ func (_q *UserQuery) QueryPendingAuthSessions() *PendingAuthSessionQuery {
return query
}
// QueryPlatformQuotas chains the current query on the "platform_quotas" edge.
func (_q *UserQuery) QueryPlatformQuotas() *UserPlatformQuotaQuery {
query := (&UserPlatformQuotaClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(user.Table, user.FieldID, selector),
sqlgraph.To(userplatformquota.Table, userplatformquota.FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, user.PlatformQuotasTable, user.PlatformQuotasColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// QueryUserAllowedGroups chains the current query on the "user_allowed_groups" edge.
func (_q *UserQuery) QueryUserAllowedGroups() *UserAllowedGroupQuery {
query := (&UserAllowedGroupClient{config: _q.config}).Query()
@ -576,6 +600,7 @@ func (_q *UserQuery) Clone() *UserQuery {
withPaymentOrders: _q.withPaymentOrders.Clone(),
withAuthIdentities: _q.withAuthIdentities.Clone(),
withPendingAuthSessions: _q.withPendingAuthSessions.Clone(),
withPlatformQuotas: _q.withPlatformQuotas.Clone(),
withUserAllowedGroups: _q.withUserAllowedGroups.Clone(),
// clone intermediate query.
sql: _q.sql.Clone(),
@ -715,6 +740,17 @@ func (_q *UserQuery) WithPendingAuthSessions(opts ...func(*PendingAuthSessionQue
return _q
}
// WithPlatformQuotas tells the query-builder to eager-load the nodes that are connected to
// the "platform_quotas" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserQuery) WithPlatformQuotas(opts ...func(*UserPlatformQuotaQuery)) *UserQuery {
query := (&UserPlatformQuotaClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withPlatformQuotas = query
return _q
}
// WithUserAllowedGroups tells the query-builder to eager-load the nodes that are connected to
// the "user_allowed_groups" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserQuery) WithUserAllowedGroups(opts ...func(*UserAllowedGroupQuery)) *UserQuery {
@ -804,7 +840,7 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
var (
nodes = []*User{}
_spec = _q.querySpec()
loadedTypes = [13]bool{
loadedTypes = [14]bool{
_q.withAPIKeys != nil,
_q.withRedeemCodes != nil,
_q.withSubscriptions != nil,
@ -817,6 +853,7 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
_q.withPaymentOrders != nil,
_q.withAuthIdentities != nil,
_q.withPendingAuthSessions != nil,
_q.withPlatformQuotas != nil,
_q.withUserAllowedGroups != nil,
}
)
@ -929,6 +966,13 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
return nil, err
}
}
if query := _q.withPlatformQuotas; query != nil {
if err := _q.loadPlatformQuotas(ctx, query, nodes,
func(n *User) { n.Edges.PlatformQuotas = []*UserPlatformQuota{} },
func(n *User, e *UserPlatformQuota) { n.Edges.PlatformQuotas = append(n.Edges.PlatformQuotas, e) }); err != nil {
return nil, err
}
}
if query := _q.withUserAllowedGroups; query != nil {
if err := _q.loadUserAllowedGroups(ctx, query, nodes,
func(n *User) { n.Edges.UserAllowedGroups = []*UserAllowedGroup{} },
@ -1339,6 +1383,36 @@ func (_q *UserQuery) loadPendingAuthSessions(ctx context.Context, query *Pending
}
return nil
}
func (_q *UserQuery) loadPlatformQuotas(ctx context.Context, query *UserPlatformQuotaQuery, nodes []*User, init func(*User), assign func(*User, *UserPlatformQuota)) error {
fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[int64]*User)
for i := range nodes {
fks = append(fks, nodes[i].ID)
nodeids[nodes[i].ID] = nodes[i]
if init != nil {
init(nodes[i])
}
}
if len(query.ctx.Fields) > 0 {
query.ctx.AppendFieldOnce(userplatformquota.FieldUserID)
}
query.Where(predicate.UserPlatformQuota(func(s *sql.Selector) {
s.Where(sql.InValues(s.C(user.PlatformQuotasColumn), fks...))
}))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
fk := n.UserID
node, ok := nodeids[fk]
if !ok {
return fmt.Errorf(`unexpected referenced foreign-key "user_id" returned %v for node %v`, fk, n.ID)
}
assign(node, n)
}
return nil
}
func (_q *UserQuery) loadUserAllowedGroups(ctx context.Context, query *UserAllowedGroupQuery, nodes []*User, init func(*User), assign func(*User, *UserAllowedGroup)) error {
fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[int64]*User)

View File

@ -23,6 +23,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
"github.com/Wei-Shaw/sub2api/ent/userplatformquota"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
)
@ -590,6 +591,21 @@ func (_u *UserUpdate) AddPendingAuthSessions(v ...*PendingAuthSession) *UserUpda
return _u.AddPendingAuthSessionIDs(ids...)
}
// AddPlatformQuotaIDs adds the "platform_quotas" edge to the UserPlatformQuota entity by IDs.
func (_u *UserUpdate) AddPlatformQuotaIDs(ids ...int64) *UserUpdate {
_u.mutation.AddPlatformQuotaIDs(ids...)
return _u
}
// AddPlatformQuotas adds the "platform_quotas" edges to the UserPlatformQuota entity.
func (_u *UserUpdate) AddPlatformQuotas(v ...*UserPlatformQuota) *UserUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddPlatformQuotaIDs(ids...)
}
// Mutation returns the UserMutation object of the builder.
func (_u *UserUpdate) Mutation() *UserMutation {
return _u.mutation
@ -847,6 +863,27 @@ func (_u *UserUpdate) RemovePendingAuthSessions(v ...*PendingAuthSession) *UserU
return _u.RemovePendingAuthSessionIDs(ids...)
}
// ClearPlatformQuotas clears all "platform_quotas" edges to the UserPlatformQuota entity.
func (_u *UserUpdate) ClearPlatformQuotas() *UserUpdate {
_u.mutation.ClearPlatformQuotas()
return _u
}
// RemovePlatformQuotaIDs removes the "platform_quotas" edge to UserPlatformQuota entities by IDs.
func (_u *UserUpdate) RemovePlatformQuotaIDs(ids ...int64) *UserUpdate {
_u.mutation.RemovePlatformQuotaIDs(ids...)
return _u
}
// RemovePlatformQuotas removes "platform_quotas" edges to UserPlatformQuota entities.
func (_u *UserUpdate) RemovePlatformQuotas(v ...*UserPlatformQuota) *UserUpdate {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemovePlatformQuotaIDs(ids...)
}
// Save executes the query and returns the number of nodes affected by the update operation.
func (_u *UserUpdate) Save(ctx context.Context) (int, error) {
if err := _u.defaults(); err != nil {
@ -1587,6 +1624,51 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.PlatformQuotasCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PlatformQuotasTable,
Columns: []string{user.PlatformQuotasColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedPlatformQuotasIDs(); len(nodes) > 0 && !_u.mutation.PlatformQuotasCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PlatformQuotasTable,
Columns: []string{user.PlatformQuotasColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.PlatformQuotasIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PlatformQuotasTable,
Columns: []string{user.PlatformQuotasColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{user.Label}
@ -2158,6 +2240,21 @@ func (_u *UserUpdateOne) AddPendingAuthSessions(v ...*PendingAuthSession) *UserU
return _u.AddPendingAuthSessionIDs(ids...)
}
// AddPlatformQuotaIDs adds the "platform_quotas" edge to the UserPlatformQuota entity by IDs.
func (_u *UserUpdateOne) AddPlatformQuotaIDs(ids ...int64) *UserUpdateOne {
_u.mutation.AddPlatformQuotaIDs(ids...)
return _u
}
// AddPlatformQuotas adds the "platform_quotas" edges to the UserPlatformQuota entity.
func (_u *UserUpdateOne) AddPlatformQuotas(v ...*UserPlatformQuota) *UserUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddPlatformQuotaIDs(ids...)
}
// Mutation returns the UserMutation object of the builder.
func (_u *UserUpdateOne) Mutation() *UserMutation {
return _u.mutation
@ -2415,6 +2512,27 @@ func (_u *UserUpdateOne) RemovePendingAuthSessions(v ...*PendingAuthSession) *Us
return _u.RemovePendingAuthSessionIDs(ids...)
}
// ClearPlatformQuotas clears all "platform_quotas" edges to the UserPlatformQuota entity.
func (_u *UserUpdateOne) ClearPlatformQuotas() *UserUpdateOne {
_u.mutation.ClearPlatformQuotas()
return _u
}
// RemovePlatformQuotaIDs removes the "platform_quotas" edge to UserPlatformQuota entities by IDs.
func (_u *UserUpdateOne) RemovePlatformQuotaIDs(ids ...int64) *UserUpdateOne {
_u.mutation.RemovePlatformQuotaIDs(ids...)
return _u
}
// RemovePlatformQuotas removes "platform_quotas" edges to UserPlatformQuota entities.
func (_u *UserUpdateOne) RemovePlatformQuotas(v ...*UserPlatformQuota) *UserUpdateOne {
ids := make([]int64, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemovePlatformQuotaIDs(ids...)
}
// Where appends a list predicates to the UserUpdate builder.
func (_u *UserUpdateOne) Where(ps ...predicate.User) *UserUpdateOne {
_u.mutation.Where(ps...)
@ -3185,6 +3303,51 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _u.mutation.PlatformQuotasCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PlatformQuotasTable,
Columns: []string{user.PlatformQuotasColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.RemovedPlatformQuotasIDs(); len(nodes) > 0 && !_u.mutation.PlatformQuotasCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PlatformQuotasTable,
Columns: []string{user.PlatformQuotasColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.PlatformQuotasIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.PlatformQuotasTable,
Columns: []string{user.PlatformQuotasColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
_node = &User{config: _u.config}
_spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues

View File

@ -0,0 +1,301 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"fmt"
"strings"
"time"
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userplatformquota"
)
// UserPlatformQuota is the model entity for the UserPlatformQuota schema.
type UserPlatformQuota struct {
config `json:"-"`
// ID of the ent.
ID int64 `json:"id,omitempty"`
// CreatedAt holds the value of the "created_at" field.
CreatedAt time.Time `json:"created_at,omitempty"`
// UpdatedAt holds the value of the "updated_at" field.
UpdatedAt time.Time `json:"updated_at,omitempty"`
// DeletedAt holds the value of the "deleted_at" field.
DeletedAt *time.Time `json:"deleted_at,omitempty"`
// UserID holds the value of the "user_id" field.
UserID int64 `json:"user_id,omitempty"`
// Platform holds the value of the "platform" field.
Platform string `json:"platform,omitempty"`
// DailyLimitUsd holds the value of the "daily_limit_usd" field.
DailyLimitUsd *float64 `json:"daily_limit_usd,omitempty"`
// WeeklyLimitUsd holds the value of the "weekly_limit_usd" field.
WeeklyLimitUsd *float64 `json:"weekly_limit_usd,omitempty"`
// MonthlyLimitUsd holds the value of the "monthly_limit_usd" field.
MonthlyLimitUsd *float64 `json:"monthly_limit_usd,omitempty"`
// DailyUsageUsd holds the value of the "daily_usage_usd" field.
DailyUsageUsd float64 `json:"daily_usage_usd,omitempty"`
// WeeklyUsageUsd holds the value of the "weekly_usage_usd" field.
WeeklyUsageUsd float64 `json:"weekly_usage_usd,omitempty"`
// MonthlyUsageUsd holds the value of the "monthly_usage_usd" field.
MonthlyUsageUsd float64 `json:"monthly_usage_usd,omitempty"`
// DailyWindowStart holds the value of the "daily_window_start" field.
DailyWindowStart *time.Time `json:"daily_window_start,omitempty"`
// WeeklyWindowStart holds the value of the "weekly_window_start" field.
WeeklyWindowStart *time.Time `json:"weekly_window_start,omitempty"`
// MonthlyWindowStart holds the value of the "monthly_window_start" field.
MonthlyWindowStart *time.Time `json:"monthly_window_start,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the UserPlatformQuotaQuery when eager-loading is set.
Edges UserPlatformQuotaEdges `json:"edges"`
selectValues sql.SelectValues
}
// UserPlatformQuotaEdges holds the relations/edges for other nodes in the graph.
type UserPlatformQuotaEdges struct {
// User holds the value of the user edge.
User *User `json:"user,omitempty"`
// loadedTypes holds the information for reporting if a
// type was loaded (or requested) in eager-loading or not.
loadedTypes [1]bool
}
// UserOrErr returns the User value or an error if the edge
// was not loaded in eager-loading, or loaded but was not found.
func (e UserPlatformQuotaEdges) UserOrErr() (*User, error) {
if e.User != nil {
return e.User, nil
} else if e.loadedTypes[0] {
return nil, &NotFoundError{label: user.Label}
}
return nil, &NotLoadedError{edge: "user"}
}
// scanValues returns the types for scanning values from sql.Rows.
func (*UserPlatformQuota) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
case userplatformquota.FieldDailyLimitUsd, userplatformquota.FieldWeeklyLimitUsd, userplatformquota.FieldMonthlyLimitUsd, userplatformquota.FieldDailyUsageUsd, userplatformquota.FieldWeeklyUsageUsd, userplatformquota.FieldMonthlyUsageUsd:
values[i] = new(sql.NullFloat64)
case userplatformquota.FieldID, userplatformquota.FieldUserID:
values[i] = new(sql.NullInt64)
case userplatformquota.FieldPlatform:
values[i] = new(sql.NullString)
case userplatformquota.FieldCreatedAt, userplatformquota.FieldUpdatedAt, userplatformquota.FieldDeletedAt, userplatformquota.FieldDailyWindowStart, userplatformquota.FieldWeeklyWindowStart, userplatformquota.FieldMonthlyWindowStart:
values[i] = new(sql.NullTime)
default:
values[i] = new(sql.UnknownType)
}
}
return values, nil
}
// assignValues assigns the values that were returned from sql.Rows (after scanning)
// to the UserPlatformQuota fields.
func (_m *UserPlatformQuota) assignValues(columns []string, values []any) error {
if m, n := len(values), len(columns); m < n {
return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
}
for i := range columns {
switch columns[i] {
case userplatformquota.FieldID:
value, ok := values[i].(*sql.NullInt64)
if !ok {
return fmt.Errorf("unexpected type %T for field id", value)
}
_m.ID = int64(value.Int64)
case userplatformquota.FieldCreatedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field created_at", values[i])
} else if value.Valid {
_m.CreatedAt = value.Time
}
case userplatformquota.FieldUpdatedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field updated_at", values[i])
} else if value.Valid {
_m.UpdatedAt = value.Time
}
case userplatformquota.FieldDeletedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field deleted_at", values[i])
} else if value.Valid {
_m.DeletedAt = new(time.Time)
*_m.DeletedAt = value.Time
}
case userplatformquota.FieldUserID:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field user_id", values[i])
} else if value.Valid {
_m.UserID = value.Int64
}
case userplatformquota.FieldPlatform:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field platform", values[i])
} else if value.Valid {
_m.Platform = value.String
}
case userplatformquota.FieldDailyLimitUsd:
if value, ok := values[i].(*sql.NullFloat64); !ok {
return fmt.Errorf("unexpected type %T for field daily_limit_usd", values[i])
} else if value.Valid {
_m.DailyLimitUsd = new(float64)
*_m.DailyLimitUsd = value.Float64
}
case userplatformquota.FieldWeeklyLimitUsd:
if value, ok := values[i].(*sql.NullFloat64); !ok {
return fmt.Errorf("unexpected type %T for field weekly_limit_usd", values[i])
} else if value.Valid {
_m.WeeklyLimitUsd = new(float64)
*_m.WeeklyLimitUsd = value.Float64
}
case userplatformquota.FieldMonthlyLimitUsd:
if value, ok := values[i].(*sql.NullFloat64); !ok {
return fmt.Errorf("unexpected type %T for field monthly_limit_usd", values[i])
} else if value.Valid {
_m.MonthlyLimitUsd = new(float64)
*_m.MonthlyLimitUsd = value.Float64
}
case userplatformquota.FieldDailyUsageUsd:
if value, ok := values[i].(*sql.NullFloat64); !ok {
return fmt.Errorf("unexpected type %T for field daily_usage_usd", values[i])
} else if value.Valid {
_m.DailyUsageUsd = value.Float64
}
case userplatformquota.FieldWeeklyUsageUsd:
if value, ok := values[i].(*sql.NullFloat64); !ok {
return fmt.Errorf("unexpected type %T for field weekly_usage_usd", values[i])
} else if value.Valid {
_m.WeeklyUsageUsd = value.Float64
}
case userplatformquota.FieldMonthlyUsageUsd:
if value, ok := values[i].(*sql.NullFloat64); !ok {
return fmt.Errorf("unexpected type %T for field monthly_usage_usd", values[i])
} else if value.Valid {
_m.MonthlyUsageUsd = value.Float64
}
case userplatformquota.FieldDailyWindowStart:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field daily_window_start", values[i])
} else if value.Valid {
_m.DailyWindowStart = new(time.Time)
*_m.DailyWindowStart = value.Time
}
case userplatformquota.FieldWeeklyWindowStart:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field weekly_window_start", values[i])
} else if value.Valid {
_m.WeeklyWindowStart = new(time.Time)
*_m.WeeklyWindowStart = value.Time
}
case userplatformquota.FieldMonthlyWindowStart:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field monthly_window_start", values[i])
} else if value.Valid {
_m.MonthlyWindowStart = new(time.Time)
*_m.MonthlyWindowStart = value.Time
}
default:
_m.selectValues.Set(columns[i], values[i])
}
}
return nil
}
// Value returns the ent.Value that was dynamically selected and assigned to the UserPlatformQuota.
// This includes values selected through modifiers, order, etc.
func (_m *UserPlatformQuota) Value(name string) (ent.Value, error) {
return _m.selectValues.Get(name)
}
// QueryUser queries the "user" edge of the UserPlatformQuota entity.
func (_m *UserPlatformQuota) QueryUser() *UserQuery {
return NewUserPlatformQuotaClient(_m.config).QueryUser(_m)
}
// Update returns a builder for updating this UserPlatformQuota.
// Note that you need to call UserPlatformQuota.Unwrap() before calling this method if this UserPlatformQuota
// was returned from a transaction, and the transaction was committed or rolled back.
func (_m *UserPlatformQuota) Update() *UserPlatformQuotaUpdateOne {
return NewUserPlatformQuotaClient(_m.config).UpdateOne(_m)
}
// Unwrap unwraps the UserPlatformQuota entity that was returned from a transaction after it was closed,
// so that all future queries will be executed through the driver which created the transaction.
func (_m *UserPlatformQuota) Unwrap() *UserPlatformQuota {
_tx, ok := _m.config.driver.(*txDriver)
if !ok {
panic("ent: UserPlatformQuota is not a transactional entity")
}
_m.config.driver = _tx.drv
return _m
}
// String implements the fmt.Stringer.
func (_m *UserPlatformQuota) String() string {
var builder strings.Builder
builder.WriteString("UserPlatformQuota(")
builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
builder.WriteString("created_at=")
builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
builder.WriteString(", ")
builder.WriteString("updated_at=")
builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
builder.WriteString(", ")
if v := _m.DeletedAt; v != nil {
builder.WriteString("deleted_at=")
builder.WriteString(v.Format(time.ANSIC))
}
builder.WriteString(", ")
builder.WriteString("user_id=")
builder.WriteString(fmt.Sprintf("%v", _m.UserID))
builder.WriteString(", ")
builder.WriteString("platform=")
builder.WriteString(_m.Platform)
builder.WriteString(", ")
if v := _m.DailyLimitUsd; v != nil {
builder.WriteString("daily_limit_usd=")
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteString(", ")
if v := _m.WeeklyLimitUsd; v != nil {
builder.WriteString("weekly_limit_usd=")
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteString(", ")
if v := _m.MonthlyLimitUsd; v != nil {
builder.WriteString("monthly_limit_usd=")
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteString(", ")
builder.WriteString("daily_usage_usd=")
builder.WriteString(fmt.Sprintf("%v", _m.DailyUsageUsd))
builder.WriteString(", ")
builder.WriteString("weekly_usage_usd=")
builder.WriteString(fmt.Sprintf("%v", _m.WeeklyUsageUsd))
builder.WriteString(", ")
builder.WriteString("monthly_usage_usd=")
builder.WriteString(fmt.Sprintf("%v", _m.MonthlyUsageUsd))
builder.WriteString(", ")
if v := _m.DailyWindowStart; v != nil {
builder.WriteString("daily_window_start=")
builder.WriteString(v.Format(time.ANSIC))
}
builder.WriteString(", ")
if v := _m.WeeklyWindowStart; v != nil {
builder.WriteString("weekly_window_start=")
builder.WriteString(v.Format(time.ANSIC))
}
builder.WriteString(", ")
if v := _m.MonthlyWindowStart; v != nil {
builder.WriteString("monthly_window_start=")
builder.WriteString(v.Format(time.ANSIC))
}
builder.WriteByte(')')
return builder.String()
}
// UserPlatformQuotaSlice is a parsable slice of UserPlatformQuota.
type UserPlatformQuotaSlice []*UserPlatformQuota

View File

@ -0,0 +1,202 @@
// Code generated by ent, DO NOT EDIT.
package userplatformquota
import (
"time"
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
)
const (
// Label holds the string label denoting the userplatformquota type in the database.
Label = "user_platform_quota"
// FieldID holds the string denoting the id field in the database.
FieldID = "id"
// FieldCreatedAt holds the string denoting the created_at field in the database.
FieldCreatedAt = "created_at"
// FieldUpdatedAt holds the string denoting the updated_at field in the database.
FieldUpdatedAt = "updated_at"
// FieldDeletedAt holds the string denoting the deleted_at field in the database.
FieldDeletedAt = "deleted_at"
// FieldUserID holds the string denoting the user_id field in the database.
FieldUserID = "user_id"
// FieldPlatform holds the string denoting the platform field in the database.
FieldPlatform = "platform"
// FieldDailyLimitUsd holds the string denoting the daily_limit_usd field in the database.
FieldDailyLimitUsd = "daily_limit_usd"
// FieldWeeklyLimitUsd holds the string denoting the weekly_limit_usd field in the database.
FieldWeeklyLimitUsd = "weekly_limit_usd"
// FieldMonthlyLimitUsd holds the string denoting the monthly_limit_usd field in the database.
FieldMonthlyLimitUsd = "monthly_limit_usd"
// FieldDailyUsageUsd holds the string denoting the daily_usage_usd field in the database.
FieldDailyUsageUsd = "daily_usage_usd"
// FieldWeeklyUsageUsd holds the string denoting the weekly_usage_usd field in the database.
FieldWeeklyUsageUsd = "weekly_usage_usd"
// FieldMonthlyUsageUsd holds the string denoting the monthly_usage_usd field in the database.
FieldMonthlyUsageUsd = "monthly_usage_usd"
// FieldDailyWindowStart holds the string denoting the daily_window_start field in the database.
FieldDailyWindowStart = "daily_window_start"
// FieldWeeklyWindowStart holds the string denoting the weekly_window_start field in the database.
FieldWeeklyWindowStart = "weekly_window_start"
// FieldMonthlyWindowStart holds the string denoting the monthly_window_start field in the database.
FieldMonthlyWindowStart = "monthly_window_start"
// EdgeUser holds the string denoting the user edge name in mutations.
EdgeUser = "user"
// Table holds the table name of the userplatformquota in the database.
Table = "user_platform_quotas"
// UserTable is the table that holds the user relation/edge.
UserTable = "user_platform_quotas"
// UserInverseTable is the table name for the User entity.
// It exists in this package in order to avoid circular dependency with the "user" package.
UserInverseTable = "users"
// UserColumn is the table column denoting the user relation/edge.
UserColumn = "user_id"
)
// Columns holds all SQL columns for userplatformquota fields.
var Columns = []string{
FieldID,
FieldCreatedAt,
FieldUpdatedAt,
FieldDeletedAt,
FieldUserID,
FieldPlatform,
FieldDailyLimitUsd,
FieldWeeklyLimitUsd,
FieldMonthlyLimitUsd,
FieldDailyUsageUsd,
FieldWeeklyUsageUsd,
FieldMonthlyUsageUsd,
FieldDailyWindowStart,
FieldWeeklyWindowStart,
FieldMonthlyWindowStart,
}
// ValidColumn reports if the column name is valid (part of the table columns).
func ValidColumn(column string) bool {
for i := range Columns {
if column == Columns[i] {
return true
}
}
return false
}
// Note that the variables below are initialized by the runtime
// package on the initialization of the application. Therefore,
// it should be imported in the main as follows:
//
// import _ "github.com/Wei-Shaw/sub2api/ent/runtime"
var (
Hooks [1]ent.Hook
Interceptors [1]ent.Interceptor
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
DefaultCreatedAt func() time.Time
// DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
DefaultUpdatedAt func() time.Time
// UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
UpdateDefaultUpdatedAt func() time.Time
// PlatformValidator is a validator for the "platform" field. It is called by the builders before save.
PlatformValidator func(string) error
// DefaultDailyUsageUsd holds the default value on creation for the "daily_usage_usd" field.
DefaultDailyUsageUsd float64
// DefaultWeeklyUsageUsd holds the default value on creation for the "weekly_usage_usd" field.
DefaultWeeklyUsageUsd float64
// DefaultMonthlyUsageUsd holds the default value on creation for the "monthly_usage_usd" field.
DefaultMonthlyUsageUsd float64
)
// OrderOption defines the ordering options for the UserPlatformQuota queries.
type OrderOption func(*sql.Selector)
// ByID orders the results by the id field.
func ByID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldID, opts...).ToFunc()
}
// ByCreatedAt orders the results by the created_at field.
func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
}
// ByUpdatedAt orders the results by the updated_at field.
func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
}
// ByDeletedAt orders the results by the deleted_at field.
func ByDeletedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldDeletedAt, opts...).ToFunc()
}
// ByUserID orders the results by the user_id field.
func ByUserID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUserID, opts...).ToFunc()
}
// ByPlatform orders the results by the platform field.
func ByPlatform(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldPlatform, opts...).ToFunc()
}
// ByDailyLimitUsd orders the results by the daily_limit_usd field.
func ByDailyLimitUsd(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldDailyLimitUsd, opts...).ToFunc()
}
// ByWeeklyLimitUsd orders the results by the weekly_limit_usd field.
func ByWeeklyLimitUsd(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldWeeklyLimitUsd, opts...).ToFunc()
}
// ByMonthlyLimitUsd orders the results by the monthly_limit_usd field.
func ByMonthlyLimitUsd(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldMonthlyLimitUsd, opts...).ToFunc()
}
// ByDailyUsageUsd orders the results by the daily_usage_usd field.
func ByDailyUsageUsd(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldDailyUsageUsd, opts...).ToFunc()
}
// ByWeeklyUsageUsd orders the results by the weekly_usage_usd field.
func ByWeeklyUsageUsd(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldWeeklyUsageUsd, opts...).ToFunc()
}
// ByMonthlyUsageUsd orders the results by the monthly_usage_usd field.
func ByMonthlyUsageUsd(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldMonthlyUsageUsd, opts...).ToFunc()
}
// ByDailyWindowStart orders the results by the daily_window_start field.
func ByDailyWindowStart(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldDailyWindowStart, opts...).ToFunc()
}
// ByWeeklyWindowStart orders the results by the weekly_window_start field.
func ByWeeklyWindowStart(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldWeeklyWindowStart, opts...).ToFunc()
}
// ByMonthlyWindowStart orders the results by the monthly_window_start field.
func ByMonthlyWindowStart(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldMonthlyWindowStart, opts...).ToFunc()
}
// ByUserField orders the results by user field.
func ByUserField(field string, opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newUserStep(), sql.OrderByField(field, opts...))
}
}
func newUserStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(UserInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn),
)
}

View File

@ -0,0 +1,799 @@
// Code generated by ent, DO NOT EDIT.
package userplatformquota
import (
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"github.com/Wei-Shaw/sub2api/ent/predicate"
)
// ID filters vertices based on their ID field.
func ID(id int64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldID, id))
}
// IDEQ applies the EQ predicate on the ID field.
func IDEQ(id int64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldID, id))
}
// IDNEQ applies the NEQ predicate on the ID field.
func IDNEQ(id int64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldID, id))
}
// IDIn applies the In predicate on the ID field.
func IDIn(ids ...int64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldIn(FieldID, ids...))
}
// IDNotIn applies the NotIn predicate on the ID field.
func IDNotIn(ids ...int64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldID, ids...))
}
// IDGT applies the GT predicate on the ID field.
func IDGT(id int64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldGT(FieldID, id))
}
// IDGTE applies the GTE predicate on the ID field.
func IDGTE(id int64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldGTE(FieldID, id))
}
// IDLT applies the LT predicate on the ID field.
func IDLT(id int64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldLT(FieldID, id))
}
// IDLTE applies the LTE predicate on the ID field.
func IDLTE(id int64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldLTE(FieldID, id))
}
// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
func CreatedAt(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldCreatedAt, v))
}
// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
func UpdatedAt(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldUpdatedAt, v))
}
// DeletedAt applies equality check predicate on the "deleted_at" field. It's identical to DeletedAtEQ.
func DeletedAt(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldDeletedAt, v))
}
// UserID applies equality check predicate on the "user_id" field. It's identical to UserIDEQ.
func UserID(v int64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldUserID, v))
}
// Platform applies equality check predicate on the "platform" field. It's identical to PlatformEQ.
func Platform(v string) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldPlatform, v))
}
// DailyLimitUsd applies equality check predicate on the "daily_limit_usd" field. It's identical to DailyLimitUsdEQ.
func DailyLimitUsd(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldDailyLimitUsd, v))
}
// WeeklyLimitUsd applies equality check predicate on the "weekly_limit_usd" field. It's identical to WeeklyLimitUsdEQ.
func WeeklyLimitUsd(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldWeeklyLimitUsd, v))
}
// MonthlyLimitUsd applies equality check predicate on the "monthly_limit_usd" field. It's identical to MonthlyLimitUsdEQ.
func MonthlyLimitUsd(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldMonthlyLimitUsd, v))
}
// DailyUsageUsd applies equality check predicate on the "daily_usage_usd" field. It's identical to DailyUsageUsdEQ.
func DailyUsageUsd(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldDailyUsageUsd, v))
}
// WeeklyUsageUsd applies equality check predicate on the "weekly_usage_usd" field. It's identical to WeeklyUsageUsdEQ.
func WeeklyUsageUsd(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldWeeklyUsageUsd, v))
}
// MonthlyUsageUsd applies equality check predicate on the "monthly_usage_usd" field. It's identical to MonthlyUsageUsdEQ.
func MonthlyUsageUsd(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldMonthlyUsageUsd, v))
}
// DailyWindowStart applies equality check predicate on the "daily_window_start" field. It's identical to DailyWindowStartEQ.
func DailyWindowStart(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldDailyWindowStart, v))
}
// WeeklyWindowStart applies equality check predicate on the "weekly_window_start" field. It's identical to WeeklyWindowStartEQ.
func WeeklyWindowStart(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldWeeklyWindowStart, v))
}
// MonthlyWindowStart applies equality check predicate on the "monthly_window_start" field. It's identical to MonthlyWindowStartEQ.
func MonthlyWindowStart(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldMonthlyWindowStart, v))
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldCreatedAt, v))
}
// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
func CreatedAtNEQ(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldCreatedAt, v))
}
// CreatedAtIn applies the In predicate on the "created_at" field.
func CreatedAtIn(vs ...time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldIn(FieldCreatedAt, vs...))
}
// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
func CreatedAtNotIn(vs ...time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldCreatedAt, vs...))
}
// CreatedAtGT applies the GT predicate on the "created_at" field.
func CreatedAtGT(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldGT(FieldCreatedAt, v))
}
// CreatedAtGTE applies the GTE predicate on the "created_at" field.
func CreatedAtGTE(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldGTE(FieldCreatedAt, v))
}
// CreatedAtLT applies the LT predicate on the "created_at" field.
func CreatedAtLT(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldLT(FieldCreatedAt, v))
}
// CreatedAtLTE applies the LTE predicate on the "created_at" field.
func CreatedAtLTE(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldLTE(FieldCreatedAt, v))
}
// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
func UpdatedAtEQ(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldUpdatedAt, v))
}
// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
func UpdatedAtNEQ(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldUpdatedAt, v))
}
// UpdatedAtIn applies the In predicate on the "updated_at" field.
func UpdatedAtIn(vs ...time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldIn(FieldUpdatedAt, vs...))
}
// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
func UpdatedAtNotIn(vs ...time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldUpdatedAt, vs...))
}
// UpdatedAtGT applies the GT predicate on the "updated_at" field.
func UpdatedAtGT(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldGT(FieldUpdatedAt, v))
}
// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
func UpdatedAtGTE(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldGTE(FieldUpdatedAt, v))
}
// UpdatedAtLT applies the LT predicate on the "updated_at" field.
func UpdatedAtLT(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldLT(FieldUpdatedAt, v))
}
// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
func UpdatedAtLTE(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldLTE(FieldUpdatedAt, v))
}
// DeletedAtEQ applies the EQ predicate on the "deleted_at" field.
func DeletedAtEQ(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldDeletedAt, v))
}
// DeletedAtNEQ applies the NEQ predicate on the "deleted_at" field.
func DeletedAtNEQ(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldDeletedAt, v))
}
// DeletedAtIn applies the In predicate on the "deleted_at" field.
func DeletedAtIn(vs ...time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldIn(FieldDeletedAt, vs...))
}
// DeletedAtNotIn applies the NotIn predicate on the "deleted_at" field.
func DeletedAtNotIn(vs ...time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldDeletedAt, vs...))
}
// DeletedAtGT applies the GT predicate on the "deleted_at" field.
func DeletedAtGT(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldGT(FieldDeletedAt, v))
}
// DeletedAtGTE applies the GTE predicate on the "deleted_at" field.
func DeletedAtGTE(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldGTE(FieldDeletedAt, v))
}
// DeletedAtLT applies the LT predicate on the "deleted_at" field.
func DeletedAtLT(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldLT(FieldDeletedAt, v))
}
// DeletedAtLTE applies the LTE predicate on the "deleted_at" field.
func DeletedAtLTE(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldLTE(FieldDeletedAt, v))
}
// DeletedAtIsNil applies the IsNil predicate on the "deleted_at" field.
func DeletedAtIsNil() predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldIsNull(FieldDeletedAt))
}
// DeletedAtNotNil applies the NotNil predicate on the "deleted_at" field.
func DeletedAtNotNil() predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNotNull(FieldDeletedAt))
}
// UserIDEQ applies the EQ predicate on the "user_id" field.
func UserIDEQ(v int64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldUserID, v))
}
// UserIDNEQ applies the NEQ predicate on the "user_id" field.
func UserIDNEQ(v int64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldUserID, v))
}
// UserIDIn applies the In predicate on the "user_id" field.
func UserIDIn(vs ...int64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldIn(FieldUserID, vs...))
}
// UserIDNotIn applies the NotIn predicate on the "user_id" field.
func UserIDNotIn(vs ...int64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldUserID, vs...))
}
// PlatformEQ applies the EQ predicate on the "platform" field.
func PlatformEQ(v string) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldPlatform, v))
}
// PlatformNEQ applies the NEQ predicate on the "platform" field.
func PlatformNEQ(v string) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldPlatform, v))
}
// PlatformIn applies the In predicate on the "platform" field.
func PlatformIn(vs ...string) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldIn(FieldPlatform, vs...))
}
// PlatformNotIn applies the NotIn predicate on the "platform" field.
func PlatformNotIn(vs ...string) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldPlatform, vs...))
}
// PlatformGT applies the GT predicate on the "platform" field.
func PlatformGT(v string) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldGT(FieldPlatform, v))
}
// PlatformGTE applies the GTE predicate on the "platform" field.
func PlatformGTE(v string) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldGTE(FieldPlatform, v))
}
// PlatformLT applies the LT predicate on the "platform" field.
func PlatformLT(v string) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldLT(FieldPlatform, v))
}
// PlatformLTE applies the LTE predicate on the "platform" field.
func PlatformLTE(v string) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldLTE(FieldPlatform, v))
}
// PlatformContains applies the Contains predicate on the "platform" field.
func PlatformContains(v string) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldContains(FieldPlatform, v))
}
// PlatformHasPrefix applies the HasPrefix predicate on the "platform" field.
func PlatformHasPrefix(v string) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldHasPrefix(FieldPlatform, v))
}
// PlatformHasSuffix applies the HasSuffix predicate on the "platform" field.
func PlatformHasSuffix(v string) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldHasSuffix(FieldPlatform, v))
}
// PlatformEqualFold applies the EqualFold predicate on the "platform" field.
func PlatformEqualFold(v string) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEqualFold(FieldPlatform, v))
}
// PlatformContainsFold applies the ContainsFold predicate on the "platform" field.
func PlatformContainsFold(v string) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldContainsFold(FieldPlatform, v))
}
// DailyLimitUsdEQ applies the EQ predicate on the "daily_limit_usd" field.
func DailyLimitUsdEQ(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldDailyLimitUsd, v))
}
// DailyLimitUsdNEQ applies the NEQ predicate on the "daily_limit_usd" field.
func DailyLimitUsdNEQ(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldDailyLimitUsd, v))
}
// DailyLimitUsdIn applies the In predicate on the "daily_limit_usd" field.
func DailyLimitUsdIn(vs ...float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldIn(FieldDailyLimitUsd, vs...))
}
// DailyLimitUsdNotIn applies the NotIn predicate on the "daily_limit_usd" field.
func DailyLimitUsdNotIn(vs ...float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldDailyLimitUsd, vs...))
}
// DailyLimitUsdGT applies the GT predicate on the "daily_limit_usd" field.
func DailyLimitUsdGT(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldGT(FieldDailyLimitUsd, v))
}
// DailyLimitUsdGTE applies the GTE predicate on the "daily_limit_usd" field.
func DailyLimitUsdGTE(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldGTE(FieldDailyLimitUsd, v))
}
// DailyLimitUsdLT applies the LT predicate on the "daily_limit_usd" field.
func DailyLimitUsdLT(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldLT(FieldDailyLimitUsd, v))
}
// DailyLimitUsdLTE applies the LTE predicate on the "daily_limit_usd" field.
func DailyLimitUsdLTE(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldLTE(FieldDailyLimitUsd, v))
}
// DailyLimitUsdIsNil applies the IsNil predicate on the "daily_limit_usd" field.
func DailyLimitUsdIsNil() predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldIsNull(FieldDailyLimitUsd))
}
// DailyLimitUsdNotNil applies the NotNil predicate on the "daily_limit_usd" field.
func DailyLimitUsdNotNil() predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNotNull(FieldDailyLimitUsd))
}
// WeeklyLimitUsdEQ applies the EQ predicate on the "weekly_limit_usd" field.
func WeeklyLimitUsdEQ(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldWeeklyLimitUsd, v))
}
// WeeklyLimitUsdNEQ applies the NEQ predicate on the "weekly_limit_usd" field.
func WeeklyLimitUsdNEQ(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldWeeklyLimitUsd, v))
}
// WeeklyLimitUsdIn applies the In predicate on the "weekly_limit_usd" field.
func WeeklyLimitUsdIn(vs ...float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldIn(FieldWeeklyLimitUsd, vs...))
}
// WeeklyLimitUsdNotIn applies the NotIn predicate on the "weekly_limit_usd" field.
func WeeklyLimitUsdNotIn(vs ...float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldWeeklyLimitUsd, vs...))
}
// WeeklyLimitUsdGT applies the GT predicate on the "weekly_limit_usd" field.
func WeeklyLimitUsdGT(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldGT(FieldWeeklyLimitUsd, v))
}
// WeeklyLimitUsdGTE applies the GTE predicate on the "weekly_limit_usd" field.
func WeeklyLimitUsdGTE(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldGTE(FieldWeeklyLimitUsd, v))
}
// WeeklyLimitUsdLT applies the LT predicate on the "weekly_limit_usd" field.
func WeeklyLimitUsdLT(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldLT(FieldWeeklyLimitUsd, v))
}
// WeeklyLimitUsdLTE applies the LTE predicate on the "weekly_limit_usd" field.
func WeeklyLimitUsdLTE(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldLTE(FieldWeeklyLimitUsd, v))
}
// WeeklyLimitUsdIsNil applies the IsNil predicate on the "weekly_limit_usd" field.
func WeeklyLimitUsdIsNil() predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldIsNull(FieldWeeklyLimitUsd))
}
// WeeklyLimitUsdNotNil applies the NotNil predicate on the "weekly_limit_usd" field.
func WeeklyLimitUsdNotNil() predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNotNull(FieldWeeklyLimitUsd))
}
// MonthlyLimitUsdEQ applies the EQ predicate on the "monthly_limit_usd" field.
func MonthlyLimitUsdEQ(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldMonthlyLimitUsd, v))
}
// MonthlyLimitUsdNEQ applies the NEQ predicate on the "monthly_limit_usd" field.
func MonthlyLimitUsdNEQ(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldMonthlyLimitUsd, v))
}
// MonthlyLimitUsdIn applies the In predicate on the "monthly_limit_usd" field.
func MonthlyLimitUsdIn(vs ...float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldIn(FieldMonthlyLimitUsd, vs...))
}
// MonthlyLimitUsdNotIn applies the NotIn predicate on the "monthly_limit_usd" field.
func MonthlyLimitUsdNotIn(vs ...float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldMonthlyLimitUsd, vs...))
}
// MonthlyLimitUsdGT applies the GT predicate on the "monthly_limit_usd" field.
func MonthlyLimitUsdGT(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldGT(FieldMonthlyLimitUsd, v))
}
// MonthlyLimitUsdGTE applies the GTE predicate on the "monthly_limit_usd" field.
func MonthlyLimitUsdGTE(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldGTE(FieldMonthlyLimitUsd, v))
}
// MonthlyLimitUsdLT applies the LT predicate on the "monthly_limit_usd" field.
func MonthlyLimitUsdLT(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldLT(FieldMonthlyLimitUsd, v))
}
// MonthlyLimitUsdLTE applies the LTE predicate on the "monthly_limit_usd" field.
func MonthlyLimitUsdLTE(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldLTE(FieldMonthlyLimitUsd, v))
}
// MonthlyLimitUsdIsNil applies the IsNil predicate on the "monthly_limit_usd" field.
func MonthlyLimitUsdIsNil() predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldIsNull(FieldMonthlyLimitUsd))
}
// MonthlyLimitUsdNotNil applies the NotNil predicate on the "monthly_limit_usd" field.
func MonthlyLimitUsdNotNil() predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNotNull(FieldMonthlyLimitUsd))
}
// DailyUsageUsdEQ applies the EQ predicate on the "daily_usage_usd" field.
func DailyUsageUsdEQ(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldDailyUsageUsd, v))
}
// DailyUsageUsdNEQ applies the NEQ predicate on the "daily_usage_usd" field.
func DailyUsageUsdNEQ(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldDailyUsageUsd, v))
}
// DailyUsageUsdIn applies the In predicate on the "daily_usage_usd" field.
func DailyUsageUsdIn(vs ...float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldIn(FieldDailyUsageUsd, vs...))
}
// DailyUsageUsdNotIn applies the NotIn predicate on the "daily_usage_usd" field.
func DailyUsageUsdNotIn(vs ...float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldDailyUsageUsd, vs...))
}
// DailyUsageUsdGT applies the GT predicate on the "daily_usage_usd" field.
func DailyUsageUsdGT(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldGT(FieldDailyUsageUsd, v))
}
// DailyUsageUsdGTE applies the GTE predicate on the "daily_usage_usd" field.
func DailyUsageUsdGTE(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldGTE(FieldDailyUsageUsd, v))
}
// DailyUsageUsdLT applies the LT predicate on the "daily_usage_usd" field.
func DailyUsageUsdLT(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldLT(FieldDailyUsageUsd, v))
}
// DailyUsageUsdLTE applies the LTE predicate on the "daily_usage_usd" field.
func DailyUsageUsdLTE(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldLTE(FieldDailyUsageUsd, v))
}
// WeeklyUsageUsdEQ applies the EQ predicate on the "weekly_usage_usd" field.
func WeeklyUsageUsdEQ(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldWeeklyUsageUsd, v))
}
// WeeklyUsageUsdNEQ applies the NEQ predicate on the "weekly_usage_usd" field.
func WeeklyUsageUsdNEQ(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldWeeklyUsageUsd, v))
}
// WeeklyUsageUsdIn applies the In predicate on the "weekly_usage_usd" field.
func WeeklyUsageUsdIn(vs ...float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldIn(FieldWeeklyUsageUsd, vs...))
}
// WeeklyUsageUsdNotIn applies the NotIn predicate on the "weekly_usage_usd" field.
func WeeklyUsageUsdNotIn(vs ...float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldWeeklyUsageUsd, vs...))
}
// WeeklyUsageUsdGT applies the GT predicate on the "weekly_usage_usd" field.
func WeeklyUsageUsdGT(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldGT(FieldWeeklyUsageUsd, v))
}
// WeeklyUsageUsdGTE applies the GTE predicate on the "weekly_usage_usd" field.
func WeeklyUsageUsdGTE(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldGTE(FieldWeeklyUsageUsd, v))
}
// WeeklyUsageUsdLT applies the LT predicate on the "weekly_usage_usd" field.
func WeeklyUsageUsdLT(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldLT(FieldWeeklyUsageUsd, v))
}
// WeeklyUsageUsdLTE applies the LTE predicate on the "weekly_usage_usd" field.
func WeeklyUsageUsdLTE(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldLTE(FieldWeeklyUsageUsd, v))
}
// MonthlyUsageUsdEQ applies the EQ predicate on the "monthly_usage_usd" field.
func MonthlyUsageUsdEQ(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldMonthlyUsageUsd, v))
}
// MonthlyUsageUsdNEQ applies the NEQ predicate on the "monthly_usage_usd" field.
func MonthlyUsageUsdNEQ(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldMonthlyUsageUsd, v))
}
// MonthlyUsageUsdIn applies the In predicate on the "monthly_usage_usd" field.
func MonthlyUsageUsdIn(vs ...float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldIn(FieldMonthlyUsageUsd, vs...))
}
// MonthlyUsageUsdNotIn applies the NotIn predicate on the "monthly_usage_usd" field.
func MonthlyUsageUsdNotIn(vs ...float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldMonthlyUsageUsd, vs...))
}
// MonthlyUsageUsdGT applies the GT predicate on the "monthly_usage_usd" field.
func MonthlyUsageUsdGT(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldGT(FieldMonthlyUsageUsd, v))
}
// MonthlyUsageUsdGTE applies the GTE predicate on the "monthly_usage_usd" field.
func MonthlyUsageUsdGTE(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldGTE(FieldMonthlyUsageUsd, v))
}
// MonthlyUsageUsdLT applies the LT predicate on the "monthly_usage_usd" field.
func MonthlyUsageUsdLT(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldLT(FieldMonthlyUsageUsd, v))
}
// MonthlyUsageUsdLTE applies the LTE predicate on the "monthly_usage_usd" field.
func MonthlyUsageUsdLTE(v float64) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldLTE(FieldMonthlyUsageUsd, v))
}
// DailyWindowStartEQ applies the EQ predicate on the "daily_window_start" field.
func DailyWindowStartEQ(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldDailyWindowStart, v))
}
// DailyWindowStartNEQ applies the NEQ predicate on the "daily_window_start" field.
func DailyWindowStartNEQ(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldDailyWindowStart, v))
}
// DailyWindowStartIn applies the In predicate on the "daily_window_start" field.
func DailyWindowStartIn(vs ...time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldIn(FieldDailyWindowStart, vs...))
}
// DailyWindowStartNotIn applies the NotIn predicate on the "daily_window_start" field.
func DailyWindowStartNotIn(vs ...time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldDailyWindowStart, vs...))
}
// DailyWindowStartGT applies the GT predicate on the "daily_window_start" field.
func DailyWindowStartGT(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldGT(FieldDailyWindowStart, v))
}
// DailyWindowStartGTE applies the GTE predicate on the "daily_window_start" field.
func DailyWindowStartGTE(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldGTE(FieldDailyWindowStart, v))
}
// DailyWindowStartLT applies the LT predicate on the "daily_window_start" field.
func DailyWindowStartLT(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldLT(FieldDailyWindowStart, v))
}
// DailyWindowStartLTE applies the LTE predicate on the "daily_window_start" field.
func DailyWindowStartLTE(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldLTE(FieldDailyWindowStart, v))
}
// DailyWindowStartIsNil applies the IsNil predicate on the "daily_window_start" field.
func DailyWindowStartIsNil() predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldIsNull(FieldDailyWindowStart))
}
// DailyWindowStartNotNil applies the NotNil predicate on the "daily_window_start" field.
func DailyWindowStartNotNil() predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNotNull(FieldDailyWindowStart))
}
// WeeklyWindowStartEQ applies the EQ predicate on the "weekly_window_start" field.
func WeeklyWindowStartEQ(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldWeeklyWindowStart, v))
}
// WeeklyWindowStartNEQ applies the NEQ predicate on the "weekly_window_start" field.
func WeeklyWindowStartNEQ(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldWeeklyWindowStart, v))
}
// WeeklyWindowStartIn applies the In predicate on the "weekly_window_start" field.
func WeeklyWindowStartIn(vs ...time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldIn(FieldWeeklyWindowStart, vs...))
}
// WeeklyWindowStartNotIn applies the NotIn predicate on the "weekly_window_start" field.
func WeeklyWindowStartNotIn(vs ...time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldWeeklyWindowStart, vs...))
}
// WeeklyWindowStartGT applies the GT predicate on the "weekly_window_start" field.
func WeeklyWindowStartGT(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldGT(FieldWeeklyWindowStart, v))
}
// WeeklyWindowStartGTE applies the GTE predicate on the "weekly_window_start" field.
func WeeklyWindowStartGTE(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldGTE(FieldWeeklyWindowStart, v))
}
// WeeklyWindowStartLT applies the LT predicate on the "weekly_window_start" field.
func WeeklyWindowStartLT(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldLT(FieldWeeklyWindowStart, v))
}
// WeeklyWindowStartLTE applies the LTE predicate on the "weekly_window_start" field.
func WeeklyWindowStartLTE(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldLTE(FieldWeeklyWindowStart, v))
}
// WeeklyWindowStartIsNil applies the IsNil predicate on the "weekly_window_start" field.
func WeeklyWindowStartIsNil() predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldIsNull(FieldWeeklyWindowStart))
}
// WeeklyWindowStartNotNil applies the NotNil predicate on the "weekly_window_start" field.
func WeeklyWindowStartNotNil() predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNotNull(FieldWeeklyWindowStart))
}
// MonthlyWindowStartEQ applies the EQ predicate on the "monthly_window_start" field.
func MonthlyWindowStartEQ(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldEQ(FieldMonthlyWindowStart, v))
}
// MonthlyWindowStartNEQ applies the NEQ predicate on the "monthly_window_start" field.
func MonthlyWindowStartNEQ(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldMonthlyWindowStart, v))
}
// MonthlyWindowStartIn applies the In predicate on the "monthly_window_start" field.
func MonthlyWindowStartIn(vs ...time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldIn(FieldMonthlyWindowStart, vs...))
}
// MonthlyWindowStartNotIn applies the NotIn predicate on the "monthly_window_start" field.
func MonthlyWindowStartNotIn(vs ...time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldMonthlyWindowStart, vs...))
}
// MonthlyWindowStartGT applies the GT predicate on the "monthly_window_start" field.
func MonthlyWindowStartGT(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldGT(FieldMonthlyWindowStart, v))
}
// MonthlyWindowStartGTE applies the GTE predicate on the "monthly_window_start" field.
func MonthlyWindowStartGTE(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldGTE(FieldMonthlyWindowStart, v))
}
// MonthlyWindowStartLT applies the LT predicate on the "monthly_window_start" field.
func MonthlyWindowStartLT(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldLT(FieldMonthlyWindowStart, v))
}
// MonthlyWindowStartLTE applies the LTE predicate on the "monthly_window_start" field.
func MonthlyWindowStartLTE(v time.Time) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldLTE(FieldMonthlyWindowStart, v))
}
// MonthlyWindowStartIsNil applies the IsNil predicate on the "monthly_window_start" field.
func MonthlyWindowStartIsNil() predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldIsNull(FieldMonthlyWindowStart))
}
// MonthlyWindowStartNotNil applies the NotNil predicate on the "monthly_window_start" field.
func MonthlyWindowStartNotNil() predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.FieldNotNull(FieldMonthlyWindowStart))
}
// HasUser applies the HasEdge predicate on the "user" edge.
func HasUser() predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(func(s *sql.Selector) {
step := sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn),
)
sqlgraph.HasNeighbors(s, step)
})
}
// HasUserWith applies the HasEdge predicate on the "user" edge with a given conditions (other predicates).
func HasUserWith(preds ...predicate.User) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(func(s *sql.Selector) {
step := newUserStep()
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
for _, p := range preds {
p(s)
}
})
})
}
// And groups predicates with the AND operator between them.
func And(predicates ...predicate.UserPlatformQuota) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.AndPredicates(predicates...))
}
// Or groups predicates with the OR operator between them.
func Or(predicates ...predicate.UserPlatformQuota) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.OrPredicates(predicates...))
}
// Not applies the not operator on the given predicate.
func Not(p predicate.UserPlatformQuota) predicate.UserPlatformQuota {
return predicate.UserPlatformQuota(sql.NotPredicates(p))
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,88 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/userplatformquota"
)
// UserPlatformQuotaDelete is the builder for deleting a UserPlatformQuota entity.
type UserPlatformQuotaDelete struct {
config
hooks []Hook
mutation *UserPlatformQuotaMutation
}
// Where appends a list predicates to the UserPlatformQuotaDelete builder.
func (_d *UserPlatformQuotaDelete) Where(ps ...predicate.UserPlatformQuota) *UserPlatformQuotaDelete {
_d.mutation.Where(ps...)
return _d
}
// Exec executes the deletion query and returns how many vertices were deleted.
func (_d *UserPlatformQuotaDelete) Exec(ctx context.Context) (int, error) {
return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
}
// ExecX is like Exec, but panics if an error occurs.
func (_d *UserPlatformQuotaDelete) ExecX(ctx context.Context) int {
n, err := _d.Exec(ctx)
if err != nil {
panic(err)
}
return n
}
func (_d *UserPlatformQuotaDelete) sqlExec(ctx context.Context) (int, error) {
_spec := sqlgraph.NewDeleteSpec(userplatformquota.Table, sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64))
if ps := _d.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
if err != nil && sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
_d.mutation.done = true
return affected, err
}
// UserPlatformQuotaDeleteOne is the builder for deleting a single UserPlatformQuota entity.
type UserPlatformQuotaDeleteOne struct {
_d *UserPlatformQuotaDelete
}
// Where appends a list predicates to the UserPlatformQuotaDelete builder.
func (_d *UserPlatformQuotaDeleteOne) Where(ps ...predicate.UserPlatformQuota) *UserPlatformQuotaDeleteOne {
_d._d.mutation.Where(ps...)
return _d
}
// Exec executes the deletion query.
func (_d *UserPlatformQuotaDeleteOne) Exec(ctx context.Context) error {
n, err := _d._d.Exec(ctx)
switch {
case err != nil:
return err
case n == 0:
return &NotFoundError{userplatformquota.Label}
default:
return nil
}
}
// ExecX is like Exec, but panics if an error occurs.
func (_d *UserPlatformQuotaDeleteOne) ExecX(ctx context.Context) {
if err := _d.Exec(ctx); err != nil {
panic(err)
}
}

View File

@ -0,0 +1,643 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"fmt"
"math"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userplatformquota"
)
// UserPlatformQuotaQuery is the builder for querying UserPlatformQuota entities.
type UserPlatformQuotaQuery struct {
config
ctx *QueryContext
order []userplatformquota.OrderOption
inters []Interceptor
predicates []predicate.UserPlatformQuota
withUser *UserQuery
modifiers []func(*sql.Selector)
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
}
// Where adds a new predicate for the UserPlatformQuotaQuery builder.
func (_q *UserPlatformQuotaQuery) Where(ps ...predicate.UserPlatformQuota) *UserPlatformQuotaQuery {
_q.predicates = append(_q.predicates, ps...)
return _q
}
// Limit the number of records to be returned by this query.
func (_q *UserPlatformQuotaQuery) Limit(limit int) *UserPlatformQuotaQuery {
_q.ctx.Limit = &limit
return _q
}
// Offset to start from.
func (_q *UserPlatformQuotaQuery) Offset(offset int) *UserPlatformQuotaQuery {
_q.ctx.Offset = &offset
return _q
}
// Unique configures the query builder to filter duplicate records on query.
// By default, unique is set to true, and can be disabled using this method.
func (_q *UserPlatformQuotaQuery) Unique(unique bool) *UserPlatformQuotaQuery {
_q.ctx.Unique = &unique
return _q
}
// Order specifies how the records should be ordered.
func (_q *UserPlatformQuotaQuery) Order(o ...userplatformquota.OrderOption) *UserPlatformQuotaQuery {
_q.order = append(_q.order, o...)
return _q
}
// QueryUser chains the current query on the "user" edge.
func (_q *UserPlatformQuotaQuery) QueryUser() *UserQuery {
query := (&UserClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
selector := _q.sqlQuery(ctx)
if err := selector.Err(); err != nil {
return nil, err
}
step := sqlgraph.NewStep(
sqlgraph.From(userplatformquota.Table, userplatformquota.FieldID, selector),
sqlgraph.To(user.Table, user.FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, userplatformquota.UserTable, userplatformquota.UserColumn),
)
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil
}
return query
}
// First returns the first UserPlatformQuota entity from the query.
// Returns a *NotFoundError when no UserPlatformQuota was found.
func (_q *UserPlatformQuotaQuery) First(ctx context.Context) (*UserPlatformQuota, error) {
nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
if err != nil {
return nil, err
}
if len(nodes) == 0 {
return nil, &NotFoundError{userplatformquota.Label}
}
return nodes[0], nil
}
// FirstX is like First, but panics if an error occurs.
func (_q *UserPlatformQuotaQuery) FirstX(ctx context.Context) *UserPlatformQuota {
node, err := _q.First(ctx)
if err != nil && !IsNotFound(err) {
panic(err)
}
return node
}
// FirstID returns the first UserPlatformQuota ID from the query.
// Returns a *NotFoundError when no UserPlatformQuota ID was found.
func (_q *UserPlatformQuotaQuery) FirstID(ctx context.Context) (id int64, err error) {
var ids []int64
if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
return
}
if len(ids) == 0 {
err = &NotFoundError{userplatformquota.Label}
return
}
return ids[0], nil
}
// FirstIDX is like FirstID, but panics if an error occurs.
func (_q *UserPlatformQuotaQuery) FirstIDX(ctx context.Context) int64 {
id, err := _q.FirstID(ctx)
if err != nil && !IsNotFound(err) {
panic(err)
}
return id
}
// Only returns a single UserPlatformQuota entity found by the query, ensuring it only returns one.
// Returns a *NotSingularError when more than one UserPlatformQuota entity is found.
// Returns a *NotFoundError when no UserPlatformQuota entities are found.
func (_q *UserPlatformQuotaQuery) Only(ctx context.Context) (*UserPlatformQuota, error) {
nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
if err != nil {
return nil, err
}
switch len(nodes) {
case 1:
return nodes[0], nil
case 0:
return nil, &NotFoundError{userplatformquota.Label}
default:
return nil, &NotSingularError{userplatformquota.Label}
}
}
// OnlyX is like Only, but panics if an error occurs.
func (_q *UserPlatformQuotaQuery) OnlyX(ctx context.Context) *UserPlatformQuota {
node, err := _q.Only(ctx)
if err != nil {
panic(err)
}
return node
}
// OnlyID is like Only, but returns the only UserPlatformQuota ID in the query.
// Returns a *NotSingularError when more than one UserPlatformQuota ID is found.
// Returns a *NotFoundError when no entities are found.
func (_q *UserPlatformQuotaQuery) OnlyID(ctx context.Context) (id int64, err error) {
var ids []int64
if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
return
}
switch len(ids) {
case 1:
id = ids[0]
case 0:
err = &NotFoundError{userplatformquota.Label}
default:
err = &NotSingularError{userplatformquota.Label}
}
return
}
// OnlyIDX is like OnlyID, but panics if an error occurs.
func (_q *UserPlatformQuotaQuery) OnlyIDX(ctx context.Context) int64 {
id, err := _q.OnlyID(ctx)
if err != nil {
panic(err)
}
return id
}
// All executes the query and returns a list of UserPlatformQuotaSlice.
func (_q *UserPlatformQuotaQuery) All(ctx context.Context) ([]*UserPlatformQuota, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
qr := querierAll[[]*UserPlatformQuota, *UserPlatformQuotaQuery]()
return withInterceptors[[]*UserPlatformQuota](ctx, _q, qr, _q.inters)
}
// AllX is like All, but panics if an error occurs.
func (_q *UserPlatformQuotaQuery) AllX(ctx context.Context) []*UserPlatformQuota {
nodes, err := _q.All(ctx)
if err != nil {
panic(err)
}
return nodes
}
// IDs executes the query and returns a list of UserPlatformQuota IDs.
func (_q *UserPlatformQuotaQuery) IDs(ctx context.Context) (ids []int64, err error) {
if _q.ctx.Unique == nil && _q.path != nil {
_q.Unique(true)
}
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
if err = _q.Select(userplatformquota.FieldID).Scan(ctx, &ids); err != nil {
return nil, err
}
return ids, nil
}
// IDsX is like IDs, but panics if an error occurs.
func (_q *UserPlatformQuotaQuery) IDsX(ctx context.Context) []int64 {
ids, err := _q.IDs(ctx)
if err != nil {
panic(err)
}
return ids
}
// Count returns the count of the given query.
func (_q *UserPlatformQuotaQuery) Count(ctx context.Context) (int, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
if err := _q.prepareQuery(ctx); err != nil {
return 0, err
}
return withInterceptors[int](ctx, _q, querierCount[*UserPlatformQuotaQuery](), _q.inters)
}
// CountX is like Count, but panics if an error occurs.
func (_q *UserPlatformQuotaQuery) CountX(ctx context.Context) int {
count, err := _q.Count(ctx)
if err != nil {
panic(err)
}
return count
}
// Exist returns true if the query has elements in the graph.
func (_q *UserPlatformQuotaQuery) Exist(ctx context.Context) (bool, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
switch _, err := _q.FirstID(ctx); {
case IsNotFound(err):
return false, nil
case err != nil:
return false, fmt.Errorf("ent: check existence: %w", err)
default:
return true, nil
}
}
// ExistX is like Exist, but panics if an error occurs.
func (_q *UserPlatformQuotaQuery) ExistX(ctx context.Context) bool {
exist, err := _q.Exist(ctx)
if err != nil {
panic(err)
}
return exist
}
// Clone returns a duplicate of the UserPlatformQuotaQuery builder, including all associated steps. It can be
// used to prepare common query builders and use them differently after the clone is made.
func (_q *UserPlatformQuotaQuery) Clone() *UserPlatformQuotaQuery {
if _q == nil {
return nil
}
return &UserPlatformQuotaQuery{
config: _q.config,
ctx: _q.ctx.Clone(),
order: append([]userplatformquota.OrderOption{}, _q.order...),
inters: append([]Interceptor{}, _q.inters...),
predicates: append([]predicate.UserPlatformQuota{}, _q.predicates...),
withUser: _q.withUser.Clone(),
// clone intermediate query.
sql: _q.sql.Clone(),
path: _q.path,
}
}
// WithUser tells the query-builder to eager-load the nodes that are connected to
// the "user" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserPlatformQuotaQuery) WithUser(opts ...func(*UserQuery)) *UserPlatformQuotaQuery {
query := (&UserClient{config: _q.config}).Query()
for _, opt := range opts {
opt(query)
}
_q.withUser = query
return _q
}
// GroupBy is used to group vertices by one or more fields/columns.
// It is often used with aggregate functions, like: count, max, mean, min, sum.
//
// Example:
//
// var v []struct {
// CreatedAt time.Time `json:"created_at,omitempty"`
// Count int `json:"count,omitempty"`
// }
//
// client.UserPlatformQuota.Query().
// GroupBy(userplatformquota.FieldCreatedAt).
// Aggregate(ent.Count()).
// Scan(ctx, &v)
func (_q *UserPlatformQuotaQuery) GroupBy(field string, fields ...string) *UserPlatformQuotaGroupBy {
_q.ctx.Fields = append([]string{field}, fields...)
grbuild := &UserPlatformQuotaGroupBy{build: _q}
grbuild.flds = &_q.ctx.Fields
grbuild.label = userplatformquota.Label
grbuild.scan = grbuild.Scan
return grbuild
}
// Select allows the selection one or more fields/columns for the given query,
// instead of selecting all fields in the entity.
//
// Example:
//
// var v []struct {
// CreatedAt time.Time `json:"created_at,omitempty"`
// }
//
// client.UserPlatformQuota.Query().
// Select(userplatformquota.FieldCreatedAt).
// Scan(ctx, &v)
func (_q *UserPlatformQuotaQuery) Select(fields ...string) *UserPlatformQuotaSelect {
_q.ctx.Fields = append(_q.ctx.Fields, fields...)
sbuild := &UserPlatformQuotaSelect{UserPlatformQuotaQuery: _q}
sbuild.label = userplatformquota.Label
sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
return sbuild
}
// Aggregate returns a UserPlatformQuotaSelect configured with the given aggregations.
func (_q *UserPlatformQuotaQuery) Aggregate(fns ...AggregateFunc) *UserPlatformQuotaSelect {
return _q.Select().Aggregate(fns...)
}
func (_q *UserPlatformQuotaQuery) prepareQuery(ctx context.Context) error {
for _, inter := range _q.inters {
if inter == nil {
return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
}
if trv, ok := inter.(Traverser); ok {
if err := trv.Traverse(ctx, _q); err != nil {
return err
}
}
}
for _, f := range _q.ctx.Fields {
if !userplatformquota.ValidColumn(f) {
return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
}
}
if _q.path != nil {
prev, err := _q.path(ctx)
if err != nil {
return err
}
_q.sql = prev
}
return nil
}
func (_q *UserPlatformQuotaQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*UserPlatformQuota, error) {
var (
nodes = []*UserPlatformQuota{}
_spec = _q.querySpec()
loadedTypes = [1]bool{
_q.withUser != nil,
}
)
_spec.ScanValues = func(columns []string) ([]any, error) {
return (*UserPlatformQuota).scanValues(nil, columns)
}
_spec.Assign = func(columns []string, values []any) error {
node := &UserPlatformQuota{config: _q.config}
nodes = append(nodes, node)
node.Edges.loadedTypes = loadedTypes
return node.assignValues(columns, values)
}
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
for i := range hooks {
hooks[i](ctx, _spec)
}
if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
return nil, err
}
if len(nodes) == 0 {
return nodes, nil
}
if query := _q.withUser; query != nil {
if err := _q.loadUser(ctx, query, nodes, nil,
func(n *UserPlatformQuota, e *User) { n.Edges.User = e }); err != nil {
return nil, err
}
}
return nodes, nil
}
func (_q *UserPlatformQuotaQuery) loadUser(ctx context.Context, query *UserQuery, nodes []*UserPlatformQuota, init func(*UserPlatformQuota), assign func(*UserPlatformQuota, *User)) error {
ids := make([]int64, 0, len(nodes))
nodeids := make(map[int64][]*UserPlatformQuota)
for i := range nodes {
fk := nodes[i].UserID
if _, ok := nodeids[fk]; !ok {
ids = append(ids, fk)
}
nodeids[fk] = append(nodeids[fk], nodes[i])
}
if len(ids) == 0 {
return nil
}
query.Where(user.IDIn(ids...))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
nodes, ok := nodeids[n.ID]
if !ok {
return fmt.Errorf(`unexpected foreign-key "user_id" returned %v`, n.ID)
}
for i := range nodes {
assign(nodes[i], n)
}
}
return nil
}
func (_q *UserPlatformQuotaQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec()
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
_spec.Node.Columns = _q.ctx.Fields
if len(_q.ctx.Fields) > 0 {
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
}
return sqlgraph.CountNodes(ctx, _q.driver, _spec)
}
func (_q *UserPlatformQuotaQuery) querySpec() *sqlgraph.QuerySpec {
_spec := sqlgraph.NewQuerySpec(userplatformquota.Table, userplatformquota.Columns, sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64))
_spec.From = _q.sql
if unique := _q.ctx.Unique; unique != nil {
_spec.Unique = *unique
} else if _q.path != nil {
_spec.Unique = true
}
if fields := _q.ctx.Fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, userplatformquota.FieldID)
for i := range fields {
if fields[i] != userplatformquota.FieldID {
_spec.Node.Columns = append(_spec.Node.Columns, fields[i])
}
}
if _q.withUser != nil {
_spec.Node.AddColumnOnce(userplatformquota.FieldUserID)
}
}
if ps := _q.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if limit := _q.ctx.Limit; limit != nil {
_spec.Limit = *limit
}
if offset := _q.ctx.Offset; offset != nil {
_spec.Offset = *offset
}
if ps := _q.order; len(ps) > 0 {
_spec.Order = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
return _spec
}
func (_q *UserPlatformQuotaQuery) sqlQuery(ctx context.Context) *sql.Selector {
builder := sql.Dialect(_q.driver.Dialect())
t1 := builder.Table(userplatformquota.Table)
columns := _q.ctx.Fields
if len(columns) == 0 {
columns = userplatformquota.Columns
}
selector := builder.Select(t1.Columns(columns...)...).From(t1)
if _q.sql != nil {
selector = _q.sql
selector.Select(selector.Columns(columns...)...)
}
if _q.ctx.Unique != nil && *_q.ctx.Unique {
selector.Distinct()
}
for _, m := range _q.modifiers {
m(selector)
}
for _, p := range _q.predicates {
p(selector)
}
for _, p := range _q.order {
p(selector)
}
if offset := _q.ctx.Offset; offset != nil {
// limit is mandatory for offset clause. We start
// with default value, and override it below if needed.
selector.Offset(*offset).Limit(math.MaxInt32)
}
if limit := _q.ctx.Limit; limit != nil {
selector.Limit(*limit)
}
return selector
}
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
// either committed or rolled-back.
func (_q *UserPlatformQuotaQuery) ForUpdate(opts ...sql.LockOption) *UserPlatformQuotaQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForUpdate(opts...)
})
return _q
}
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
// on any rows that are read. Other sessions can read the rows, but cannot modify them
// until your transaction commits.
func (_q *UserPlatformQuotaQuery) ForShare(opts ...sql.LockOption) *UserPlatformQuotaQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForShare(opts...)
})
return _q
}
// UserPlatformQuotaGroupBy is the group-by builder for UserPlatformQuota entities.
type UserPlatformQuotaGroupBy struct {
selector
build *UserPlatformQuotaQuery
}
// Aggregate adds the given aggregation functions to the group-by query.
func (_g *UserPlatformQuotaGroupBy) Aggregate(fns ...AggregateFunc) *UserPlatformQuotaGroupBy {
_g.fns = append(_g.fns, fns...)
return _g
}
// Scan applies the selector query and scans the result into the given value.
func (_g *UserPlatformQuotaGroupBy) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
if err := _g.build.prepareQuery(ctx); err != nil {
return err
}
return scanWithInterceptors[*UserPlatformQuotaQuery, *UserPlatformQuotaGroupBy](ctx, _g.build, _g, _g.build.inters, v)
}
func (_g *UserPlatformQuotaGroupBy) sqlScan(ctx context.Context, root *UserPlatformQuotaQuery, v any) error {
selector := root.sqlQuery(ctx).Select()
aggregation := make([]string, 0, len(_g.fns))
for _, fn := range _g.fns {
aggregation = append(aggregation, fn(selector))
}
if len(selector.SelectedColumns()) == 0 {
columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
for _, f := range *_g.flds {
columns = append(columns, selector.C(f))
}
columns = append(columns, aggregation...)
selector.Select(columns...)
}
selector.GroupBy(selector.Columns(*_g.flds...)...)
if err := selector.Err(); err != nil {
return err
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}
// UserPlatformQuotaSelect is the builder for selecting fields of UserPlatformQuota entities.
type UserPlatformQuotaSelect struct {
*UserPlatformQuotaQuery
selector
}
// Aggregate adds the given aggregation functions to the selector query.
func (_s *UserPlatformQuotaSelect) Aggregate(fns ...AggregateFunc) *UserPlatformQuotaSelect {
_s.fns = append(_s.fns, fns...)
return _s
}
// Scan applies the selector query and scans the result into the given value.
func (_s *UserPlatformQuotaSelect) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
if err := _s.prepareQuery(ctx); err != nil {
return err
}
return scanWithInterceptors[*UserPlatformQuotaQuery, *UserPlatformQuotaSelect](ctx, _s.UserPlatformQuotaQuery, _s, _s.inters, v)
}
func (_s *UserPlatformQuotaSelect) sqlScan(ctx context.Context, root *UserPlatformQuotaQuery, v any) error {
selector := root.sqlQuery(ctx)
aggregation := make([]string, 0, len(_s.fns))
for _, fn := range _s.fns {
aggregation = append(aggregation, fn(selector))
}
switch n := len(*_s.selector.flds); {
case n == 0 && len(aggregation) > 0:
selector.Select(aggregation...)
case n != 0 && len(aggregation) > 0:
selector.AppendSelect(aggregation...)
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := _s.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}

View File

@ -0,0 +1,985 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
"errors"
"fmt"
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userplatformquota"
)
// UserPlatformQuotaUpdate is the builder for updating UserPlatformQuota entities.
type UserPlatformQuotaUpdate struct {
config
hooks []Hook
mutation *UserPlatformQuotaMutation
}
// Where appends a list predicates to the UserPlatformQuotaUpdate builder.
func (_u *UserPlatformQuotaUpdate) Where(ps ...predicate.UserPlatformQuota) *UserPlatformQuotaUpdate {
_u.mutation.Where(ps...)
return _u
}
// SetUpdatedAt sets the "updated_at" field.
func (_u *UserPlatformQuotaUpdate) SetUpdatedAt(v time.Time) *UserPlatformQuotaUpdate {
_u.mutation.SetUpdatedAt(v)
return _u
}
// SetDeletedAt sets the "deleted_at" field.
func (_u *UserPlatformQuotaUpdate) SetDeletedAt(v time.Time) *UserPlatformQuotaUpdate {
_u.mutation.SetDeletedAt(v)
return _u
}
// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil.
func (_u *UserPlatformQuotaUpdate) SetNillableDeletedAt(v *time.Time) *UserPlatformQuotaUpdate {
if v != nil {
_u.SetDeletedAt(*v)
}
return _u
}
// ClearDeletedAt clears the value of the "deleted_at" field.
func (_u *UserPlatformQuotaUpdate) ClearDeletedAt() *UserPlatformQuotaUpdate {
_u.mutation.ClearDeletedAt()
return _u
}
// SetUserID sets the "user_id" field.
func (_u *UserPlatformQuotaUpdate) SetUserID(v int64) *UserPlatformQuotaUpdate {
_u.mutation.SetUserID(v)
return _u
}
// SetNillableUserID sets the "user_id" field if the given value is not nil.
func (_u *UserPlatformQuotaUpdate) SetNillableUserID(v *int64) *UserPlatformQuotaUpdate {
if v != nil {
_u.SetUserID(*v)
}
return _u
}
// SetPlatform sets the "platform" field.
func (_u *UserPlatformQuotaUpdate) SetPlatform(v string) *UserPlatformQuotaUpdate {
_u.mutation.SetPlatform(v)
return _u
}
// SetNillablePlatform sets the "platform" field if the given value is not nil.
func (_u *UserPlatformQuotaUpdate) SetNillablePlatform(v *string) *UserPlatformQuotaUpdate {
if v != nil {
_u.SetPlatform(*v)
}
return _u
}
// SetDailyLimitUsd sets the "daily_limit_usd" field.
func (_u *UserPlatformQuotaUpdate) SetDailyLimitUsd(v float64) *UserPlatformQuotaUpdate {
_u.mutation.ResetDailyLimitUsd()
_u.mutation.SetDailyLimitUsd(v)
return _u
}
// SetNillableDailyLimitUsd sets the "daily_limit_usd" field if the given value is not nil.
func (_u *UserPlatformQuotaUpdate) SetNillableDailyLimitUsd(v *float64) *UserPlatformQuotaUpdate {
if v != nil {
_u.SetDailyLimitUsd(*v)
}
return _u
}
// AddDailyLimitUsd adds value to the "daily_limit_usd" field.
func (_u *UserPlatformQuotaUpdate) AddDailyLimitUsd(v float64) *UserPlatformQuotaUpdate {
_u.mutation.AddDailyLimitUsd(v)
return _u
}
// ClearDailyLimitUsd clears the value of the "daily_limit_usd" field.
func (_u *UserPlatformQuotaUpdate) ClearDailyLimitUsd() *UserPlatformQuotaUpdate {
_u.mutation.ClearDailyLimitUsd()
return _u
}
// SetWeeklyLimitUsd sets the "weekly_limit_usd" field.
func (_u *UserPlatformQuotaUpdate) SetWeeklyLimitUsd(v float64) *UserPlatformQuotaUpdate {
_u.mutation.ResetWeeklyLimitUsd()
_u.mutation.SetWeeklyLimitUsd(v)
return _u
}
// SetNillableWeeklyLimitUsd sets the "weekly_limit_usd" field if the given value is not nil.
func (_u *UserPlatformQuotaUpdate) SetNillableWeeklyLimitUsd(v *float64) *UserPlatformQuotaUpdate {
if v != nil {
_u.SetWeeklyLimitUsd(*v)
}
return _u
}
// AddWeeklyLimitUsd adds value to the "weekly_limit_usd" field.
func (_u *UserPlatformQuotaUpdate) AddWeeklyLimitUsd(v float64) *UserPlatformQuotaUpdate {
_u.mutation.AddWeeklyLimitUsd(v)
return _u
}
// ClearWeeklyLimitUsd clears the value of the "weekly_limit_usd" field.
func (_u *UserPlatformQuotaUpdate) ClearWeeklyLimitUsd() *UserPlatformQuotaUpdate {
_u.mutation.ClearWeeklyLimitUsd()
return _u
}
// SetMonthlyLimitUsd sets the "monthly_limit_usd" field.
func (_u *UserPlatformQuotaUpdate) SetMonthlyLimitUsd(v float64) *UserPlatformQuotaUpdate {
_u.mutation.ResetMonthlyLimitUsd()
_u.mutation.SetMonthlyLimitUsd(v)
return _u
}
// SetNillableMonthlyLimitUsd sets the "monthly_limit_usd" field if the given value is not nil.
func (_u *UserPlatformQuotaUpdate) SetNillableMonthlyLimitUsd(v *float64) *UserPlatformQuotaUpdate {
if v != nil {
_u.SetMonthlyLimitUsd(*v)
}
return _u
}
// AddMonthlyLimitUsd adds value to the "monthly_limit_usd" field.
func (_u *UserPlatformQuotaUpdate) AddMonthlyLimitUsd(v float64) *UserPlatformQuotaUpdate {
_u.mutation.AddMonthlyLimitUsd(v)
return _u
}
// ClearMonthlyLimitUsd clears the value of the "monthly_limit_usd" field.
func (_u *UserPlatformQuotaUpdate) ClearMonthlyLimitUsd() *UserPlatformQuotaUpdate {
_u.mutation.ClearMonthlyLimitUsd()
return _u
}
// SetDailyUsageUsd sets the "daily_usage_usd" field.
func (_u *UserPlatformQuotaUpdate) SetDailyUsageUsd(v float64) *UserPlatformQuotaUpdate {
_u.mutation.ResetDailyUsageUsd()
_u.mutation.SetDailyUsageUsd(v)
return _u
}
// SetNillableDailyUsageUsd sets the "daily_usage_usd" field if the given value is not nil.
func (_u *UserPlatformQuotaUpdate) SetNillableDailyUsageUsd(v *float64) *UserPlatformQuotaUpdate {
if v != nil {
_u.SetDailyUsageUsd(*v)
}
return _u
}
// AddDailyUsageUsd adds value to the "daily_usage_usd" field.
func (_u *UserPlatformQuotaUpdate) AddDailyUsageUsd(v float64) *UserPlatformQuotaUpdate {
_u.mutation.AddDailyUsageUsd(v)
return _u
}
// SetWeeklyUsageUsd sets the "weekly_usage_usd" field.
func (_u *UserPlatformQuotaUpdate) SetWeeklyUsageUsd(v float64) *UserPlatformQuotaUpdate {
_u.mutation.ResetWeeklyUsageUsd()
_u.mutation.SetWeeklyUsageUsd(v)
return _u
}
// SetNillableWeeklyUsageUsd sets the "weekly_usage_usd" field if the given value is not nil.
func (_u *UserPlatformQuotaUpdate) SetNillableWeeklyUsageUsd(v *float64) *UserPlatformQuotaUpdate {
if v != nil {
_u.SetWeeklyUsageUsd(*v)
}
return _u
}
// AddWeeklyUsageUsd adds value to the "weekly_usage_usd" field.
func (_u *UserPlatformQuotaUpdate) AddWeeklyUsageUsd(v float64) *UserPlatformQuotaUpdate {
_u.mutation.AddWeeklyUsageUsd(v)
return _u
}
// SetMonthlyUsageUsd sets the "monthly_usage_usd" field.
func (_u *UserPlatformQuotaUpdate) SetMonthlyUsageUsd(v float64) *UserPlatformQuotaUpdate {
_u.mutation.ResetMonthlyUsageUsd()
_u.mutation.SetMonthlyUsageUsd(v)
return _u
}
// SetNillableMonthlyUsageUsd sets the "monthly_usage_usd" field if the given value is not nil.
func (_u *UserPlatformQuotaUpdate) SetNillableMonthlyUsageUsd(v *float64) *UserPlatformQuotaUpdate {
if v != nil {
_u.SetMonthlyUsageUsd(*v)
}
return _u
}
// AddMonthlyUsageUsd adds value to the "monthly_usage_usd" field.
func (_u *UserPlatformQuotaUpdate) AddMonthlyUsageUsd(v float64) *UserPlatformQuotaUpdate {
_u.mutation.AddMonthlyUsageUsd(v)
return _u
}
// SetDailyWindowStart sets the "daily_window_start" field.
func (_u *UserPlatformQuotaUpdate) SetDailyWindowStart(v time.Time) *UserPlatformQuotaUpdate {
_u.mutation.SetDailyWindowStart(v)
return _u
}
// SetNillableDailyWindowStart sets the "daily_window_start" field if the given value is not nil.
func (_u *UserPlatformQuotaUpdate) SetNillableDailyWindowStart(v *time.Time) *UserPlatformQuotaUpdate {
if v != nil {
_u.SetDailyWindowStart(*v)
}
return _u
}
// ClearDailyWindowStart clears the value of the "daily_window_start" field.
func (_u *UserPlatformQuotaUpdate) ClearDailyWindowStart() *UserPlatformQuotaUpdate {
_u.mutation.ClearDailyWindowStart()
return _u
}
// SetWeeklyWindowStart sets the "weekly_window_start" field.
func (_u *UserPlatformQuotaUpdate) SetWeeklyWindowStart(v time.Time) *UserPlatformQuotaUpdate {
_u.mutation.SetWeeklyWindowStart(v)
return _u
}
// SetNillableWeeklyWindowStart sets the "weekly_window_start" field if the given value is not nil.
func (_u *UserPlatformQuotaUpdate) SetNillableWeeklyWindowStart(v *time.Time) *UserPlatformQuotaUpdate {
if v != nil {
_u.SetWeeklyWindowStart(*v)
}
return _u
}
// ClearWeeklyWindowStart clears the value of the "weekly_window_start" field.
func (_u *UserPlatformQuotaUpdate) ClearWeeklyWindowStart() *UserPlatformQuotaUpdate {
_u.mutation.ClearWeeklyWindowStart()
return _u
}
// SetMonthlyWindowStart sets the "monthly_window_start" field.
func (_u *UserPlatformQuotaUpdate) SetMonthlyWindowStart(v time.Time) *UserPlatformQuotaUpdate {
_u.mutation.SetMonthlyWindowStart(v)
return _u
}
// SetNillableMonthlyWindowStart sets the "monthly_window_start" field if the given value is not nil.
func (_u *UserPlatformQuotaUpdate) SetNillableMonthlyWindowStart(v *time.Time) *UserPlatformQuotaUpdate {
if v != nil {
_u.SetMonthlyWindowStart(*v)
}
return _u
}
// ClearMonthlyWindowStart clears the value of the "monthly_window_start" field.
func (_u *UserPlatformQuotaUpdate) ClearMonthlyWindowStart() *UserPlatformQuotaUpdate {
_u.mutation.ClearMonthlyWindowStart()
return _u
}
// SetUser sets the "user" edge to the User entity.
func (_u *UserPlatformQuotaUpdate) SetUser(v *User) *UserPlatformQuotaUpdate {
return _u.SetUserID(v.ID)
}
// Mutation returns the UserPlatformQuotaMutation object of the builder.
func (_u *UserPlatformQuotaUpdate) Mutation() *UserPlatformQuotaMutation {
return _u.mutation
}
// ClearUser clears the "user" edge to the User entity.
func (_u *UserPlatformQuotaUpdate) ClearUser() *UserPlatformQuotaUpdate {
_u.mutation.ClearUser()
return _u
}
// Save executes the query and returns the number of nodes affected by the update operation.
func (_u *UserPlatformQuotaUpdate) Save(ctx context.Context) (int, error) {
if err := _u.defaults(); err != nil {
return 0, err
}
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
// SaveX is like Save, but panics if an error occurs.
func (_u *UserPlatformQuotaUpdate) SaveX(ctx context.Context) int {
affected, err := _u.Save(ctx)
if err != nil {
panic(err)
}
return affected
}
// Exec executes the query.
func (_u *UserPlatformQuotaUpdate) Exec(ctx context.Context) error {
_, err := _u.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_u *UserPlatformQuotaUpdate) ExecX(ctx context.Context) {
if err := _u.Exec(ctx); err != nil {
panic(err)
}
}
// defaults sets the default values of the builder before save.
func (_u *UserPlatformQuotaUpdate) defaults() error {
if _, ok := _u.mutation.UpdatedAt(); !ok {
if userplatformquota.UpdateDefaultUpdatedAt == nil {
return fmt.Errorf("ent: uninitialized userplatformquota.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)")
}
v := userplatformquota.UpdateDefaultUpdatedAt()
_u.mutation.SetUpdatedAt(v)
}
return nil
}
// check runs all checks and user-defined validators on the builder.
func (_u *UserPlatformQuotaUpdate) check() error {
if v, ok := _u.mutation.Platform(); ok {
if err := userplatformquota.PlatformValidator(v); err != nil {
return &ValidationError{Name: "platform", err: fmt.Errorf(`ent: validator failed for field "UserPlatformQuota.platform": %w`, err)}
}
}
if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
return errors.New(`ent: clearing a required unique edge "UserPlatformQuota.user"`)
}
return nil
}
func (_u *UserPlatformQuotaUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if err := _u.check(); err != nil {
return _node, err
}
_spec := sqlgraph.NewUpdateSpec(userplatformquota.Table, userplatformquota.Columns, sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64))
if ps := _u.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if value, ok := _u.mutation.UpdatedAt(); ok {
_spec.SetField(userplatformquota.FieldUpdatedAt, field.TypeTime, value)
}
if value, ok := _u.mutation.DeletedAt(); ok {
_spec.SetField(userplatformquota.FieldDeletedAt, field.TypeTime, value)
}
if _u.mutation.DeletedAtCleared() {
_spec.ClearField(userplatformquota.FieldDeletedAt, field.TypeTime)
}
if value, ok := _u.mutation.Platform(); ok {
_spec.SetField(userplatformquota.FieldPlatform, field.TypeString, value)
}
if value, ok := _u.mutation.DailyLimitUsd(); ok {
_spec.SetField(userplatformquota.FieldDailyLimitUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedDailyLimitUsd(); ok {
_spec.AddField(userplatformquota.FieldDailyLimitUsd, field.TypeFloat64, value)
}
if _u.mutation.DailyLimitUsdCleared() {
_spec.ClearField(userplatformquota.FieldDailyLimitUsd, field.TypeFloat64)
}
if value, ok := _u.mutation.WeeklyLimitUsd(); ok {
_spec.SetField(userplatformquota.FieldWeeklyLimitUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedWeeklyLimitUsd(); ok {
_spec.AddField(userplatformquota.FieldWeeklyLimitUsd, field.TypeFloat64, value)
}
if _u.mutation.WeeklyLimitUsdCleared() {
_spec.ClearField(userplatformquota.FieldWeeklyLimitUsd, field.TypeFloat64)
}
if value, ok := _u.mutation.MonthlyLimitUsd(); ok {
_spec.SetField(userplatformquota.FieldMonthlyLimitUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedMonthlyLimitUsd(); ok {
_spec.AddField(userplatformquota.FieldMonthlyLimitUsd, field.TypeFloat64, value)
}
if _u.mutation.MonthlyLimitUsdCleared() {
_spec.ClearField(userplatformquota.FieldMonthlyLimitUsd, field.TypeFloat64)
}
if value, ok := _u.mutation.DailyUsageUsd(); ok {
_spec.SetField(userplatformquota.FieldDailyUsageUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedDailyUsageUsd(); ok {
_spec.AddField(userplatformquota.FieldDailyUsageUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.WeeklyUsageUsd(); ok {
_spec.SetField(userplatformquota.FieldWeeklyUsageUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedWeeklyUsageUsd(); ok {
_spec.AddField(userplatformquota.FieldWeeklyUsageUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.MonthlyUsageUsd(); ok {
_spec.SetField(userplatformquota.FieldMonthlyUsageUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedMonthlyUsageUsd(); ok {
_spec.AddField(userplatformquota.FieldMonthlyUsageUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.DailyWindowStart(); ok {
_spec.SetField(userplatformquota.FieldDailyWindowStart, field.TypeTime, value)
}
if _u.mutation.DailyWindowStartCleared() {
_spec.ClearField(userplatformquota.FieldDailyWindowStart, field.TypeTime)
}
if value, ok := _u.mutation.WeeklyWindowStart(); ok {
_spec.SetField(userplatformquota.FieldWeeklyWindowStart, field.TypeTime, value)
}
if _u.mutation.WeeklyWindowStartCleared() {
_spec.ClearField(userplatformquota.FieldWeeklyWindowStart, field.TypeTime)
}
if value, ok := _u.mutation.MonthlyWindowStart(); ok {
_spec.SetField(userplatformquota.FieldMonthlyWindowStart, field.TypeTime, value)
}
if _u.mutation.MonthlyWindowStartCleared() {
_spec.ClearField(userplatformquota.FieldMonthlyWindowStart, field.TypeTime)
}
if _u.mutation.UserCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: userplatformquota.UserTable,
Columns: []string{userplatformquota.UserColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.UserIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: userplatformquota.UserTable,
Columns: []string{userplatformquota.UserColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{userplatformquota.Label}
} else if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return 0, err
}
_u.mutation.done = true
return _node, nil
}
// UserPlatformQuotaUpdateOne is the builder for updating a single UserPlatformQuota entity.
type UserPlatformQuotaUpdateOne struct {
config
fields []string
hooks []Hook
mutation *UserPlatformQuotaMutation
}
// SetUpdatedAt sets the "updated_at" field.
func (_u *UserPlatformQuotaUpdateOne) SetUpdatedAt(v time.Time) *UserPlatformQuotaUpdateOne {
_u.mutation.SetUpdatedAt(v)
return _u
}
// SetDeletedAt sets the "deleted_at" field.
func (_u *UserPlatformQuotaUpdateOne) SetDeletedAt(v time.Time) *UserPlatformQuotaUpdateOne {
_u.mutation.SetDeletedAt(v)
return _u
}
// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil.
func (_u *UserPlatformQuotaUpdateOne) SetNillableDeletedAt(v *time.Time) *UserPlatformQuotaUpdateOne {
if v != nil {
_u.SetDeletedAt(*v)
}
return _u
}
// ClearDeletedAt clears the value of the "deleted_at" field.
func (_u *UserPlatformQuotaUpdateOne) ClearDeletedAt() *UserPlatformQuotaUpdateOne {
_u.mutation.ClearDeletedAt()
return _u
}
// SetUserID sets the "user_id" field.
func (_u *UserPlatformQuotaUpdateOne) SetUserID(v int64) *UserPlatformQuotaUpdateOne {
_u.mutation.SetUserID(v)
return _u
}
// SetNillableUserID sets the "user_id" field if the given value is not nil.
func (_u *UserPlatformQuotaUpdateOne) SetNillableUserID(v *int64) *UserPlatformQuotaUpdateOne {
if v != nil {
_u.SetUserID(*v)
}
return _u
}
// SetPlatform sets the "platform" field.
func (_u *UserPlatformQuotaUpdateOne) SetPlatform(v string) *UserPlatformQuotaUpdateOne {
_u.mutation.SetPlatform(v)
return _u
}
// SetNillablePlatform sets the "platform" field if the given value is not nil.
func (_u *UserPlatformQuotaUpdateOne) SetNillablePlatform(v *string) *UserPlatformQuotaUpdateOne {
if v != nil {
_u.SetPlatform(*v)
}
return _u
}
// SetDailyLimitUsd sets the "daily_limit_usd" field.
func (_u *UserPlatformQuotaUpdateOne) SetDailyLimitUsd(v float64) *UserPlatformQuotaUpdateOne {
_u.mutation.ResetDailyLimitUsd()
_u.mutation.SetDailyLimitUsd(v)
return _u
}
// SetNillableDailyLimitUsd sets the "daily_limit_usd" field if the given value is not nil.
func (_u *UserPlatformQuotaUpdateOne) SetNillableDailyLimitUsd(v *float64) *UserPlatformQuotaUpdateOne {
if v != nil {
_u.SetDailyLimitUsd(*v)
}
return _u
}
// AddDailyLimitUsd adds value to the "daily_limit_usd" field.
func (_u *UserPlatformQuotaUpdateOne) AddDailyLimitUsd(v float64) *UserPlatformQuotaUpdateOne {
_u.mutation.AddDailyLimitUsd(v)
return _u
}
// ClearDailyLimitUsd clears the value of the "daily_limit_usd" field.
func (_u *UserPlatformQuotaUpdateOne) ClearDailyLimitUsd() *UserPlatformQuotaUpdateOne {
_u.mutation.ClearDailyLimitUsd()
return _u
}
// SetWeeklyLimitUsd sets the "weekly_limit_usd" field.
func (_u *UserPlatformQuotaUpdateOne) SetWeeklyLimitUsd(v float64) *UserPlatformQuotaUpdateOne {
_u.mutation.ResetWeeklyLimitUsd()
_u.mutation.SetWeeklyLimitUsd(v)
return _u
}
// SetNillableWeeklyLimitUsd sets the "weekly_limit_usd" field if the given value is not nil.
func (_u *UserPlatformQuotaUpdateOne) SetNillableWeeklyLimitUsd(v *float64) *UserPlatformQuotaUpdateOne {
if v != nil {
_u.SetWeeklyLimitUsd(*v)
}
return _u
}
// AddWeeklyLimitUsd adds value to the "weekly_limit_usd" field.
func (_u *UserPlatformQuotaUpdateOne) AddWeeklyLimitUsd(v float64) *UserPlatformQuotaUpdateOne {
_u.mutation.AddWeeklyLimitUsd(v)
return _u
}
// ClearWeeklyLimitUsd clears the value of the "weekly_limit_usd" field.
func (_u *UserPlatformQuotaUpdateOne) ClearWeeklyLimitUsd() *UserPlatformQuotaUpdateOne {
_u.mutation.ClearWeeklyLimitUsd()
return _u
}
// SetMonthlyLimitUsd sets the "monthly_limit_usd" field.
func (_u *UserPlatformQuotaUpdateOne) SetMonthlyLimitUsd(v float64) *UserPlatformQuotaUpdateOne {
_u.mutation.ResetMonthlyLimitUsd()
_u.mutation.SetMonthlyLimitUsd(v)
return _u
}
// SetNillableMonthlyLimitUsd sets the "monthly_limit_usd" field if the given value is not nil.
func (_u *UserPlatformQuotaUpdateOne) SetNillableMonthlyLimitUsd(v *float64) *UserPlatformQuotaUpdateOne {
if v != nil {
_u.SetMonthlyLimitUsd(*v)
}
return _u
}
// AddMonthlyLimitUsd adds value to the "monthly_limit_usd" field.
func (_u *UserPlatformQuotaUpdateOne) AddMonthlyLimitUsd(v float64) *UserPlatformQuotaUpdateOne {
_u.mutation.AddMonthlyLimitUsd(v)
return _u
}
// ClearMonthlyLimitUsd clears the value of the "monthly_limit_usd" field.
func (_u *UserPlatformQuotaUpdateOne) ClearMonthlyLimitUsd() *UserPlatformQuotaUpdateOne {
_u.mutation.ClearMonthlyLimitUsd()
return _u
}
// SetDailyUsageUsd sets the "daily_usage_usd" field.
func (_u *UserPlatformQuotaUpdateOne) SetDailyUsageUsd(v float64) *UserPlatformQuotaUpdateOne {
_u.mutation.ResetDailyUsageUsd()
_u.mutation.SetDailyUsageUsd(v)
return _u
}
// SetNillableDailyUsageUsd sets the "daily_usage_usd" field if the given value is not nil.
func (_u *UserPlatformQuotaUpdateOne) SetNillableDailyUsageUsd(v *float64) *UserPlatformQuotaUpdateOne {
if v != nil {
_u.SetDailyUsageUsd(*v)
}
return _u
}
// AddDailyUsageUsd adds value to the "daily_usage_usd" field.
func (_u *UserPlatformQuotaUpdateOne) AddDailyUsageUsd(v float64) *UserPlatformQuotaUpdateOne {
_u.mutation.AddDailyUsageUsd(v)
return _u
}
// SetWeeklyUsageUsd sets the "weekly_usage_usd" field.
func (_u *UserPlatformQuotaUpdateOne) SetWeeklyUsageUsd(v float64) *UserPlatformQuotaUpdateOne {
_u.mutation.ResetWeeklyUsageUsd()
_u.mutation.SetWeeklyUsageUsd(v)
return _u
}
// SetNillableWeeklyUsageUsd sets the "weekly_usage_usd" field if the given value is not nil.
func (_u *UserPlatformQuotaUpdateOne) SetNillableWeeklyUsageUsd(v *float64) *UserPlatformQuotaUpdateOne {
if v != nil {
_u.SetWeeklyUsageUsd(*v)
}
return _u
}
// AddWeeklyUsageUsd adds value to the "weekly_usage_usd" field.
func (_u *UserPlatformQuotaUpdateOne) AddWeeklyUsageUsd(v float64) *UserPlatformQuotaUpdateOne {
_u.mutation.AddWeeklyUsageUsd(v)
return _u
}
// SetMonthlyUsageUsd sets the "monthly_usage_usd" field.
func (_u *UserPlatformQuotaUpdateOne) SetMonthlyUsageUsd(v float64) *UserPlatformQuotaUpdateOne {
_u.mutation.ResetMonthlyUsageUsd()
_u.mutation.SetMonthlyUsageUsd(v)
return _u
}
// SetNillableMonthlyUsageUsd sets the "monthly_usage_usd" field if the given value is not nil.
func (_u *UserPlatformQuotaUpdateOne) SetNillableMonthlyUsageUsd(v *float64) *UserPlatformQuotaUpdateOne {
if v != nil {
_u.SetMonthlyUsageUsd(*v)
}
return _u
}
// AddMonthlyUsageUsd adds value to the "monthly_usage_usd" field.
func (_u *UserPlatformQuotaUpdateOne) AddMonthlyUsageUsd(v float64) *UserPlatformQuotaUpdateOne {
_u.mutation.AddMonthlyUsageUsd(v)
return _u
}
// SetDailyWindowStart sets the "daily_window_start" field.
func (_u *UserPlatformQuotaUpdateOne) SetDailyWindowStart(v time.Time) *UserPlatformQuotaUpdateOne {
_u.mutation.SetDailyWindowStart(v)
return _u
}
// SetNillableDailyWindowStart sets the "daily_window_start" field if the given value is not nil.
func (_u *UserPlatformQuotaUpdateOne) SetNillableDailyWindowStart(v *time.Time) *UserPlatformQuotaUpdateOne {
if v != nil {
_u.SetDailyWindowStart(*v)
}
return _u
}
// ClearDailyWindowStart clears the value of the "daily_window_start" field.
func (_u *UserPlatformQuotaUpdateOne) ClearDailyWindowStart() *UserPlatformQuotaUpdateOne {
_u.mutation.ClearDailyWindowStart()
return _u
}
// SetWeeklyWindowStart sets the "weekly_window_start" field.
func (_u *UserPlatformQuotaUpdateOne) SetWeeklyWindowStart(v time.Time) *UserPlatformQuotaUpdateOne {
_u.mutation.SetWeeklyWindowStart(v)
return _u
}
// SetNillableWeeklyWindowStart sets the "weekly_window_start" field if the given value is not nil.
func (_u *UserPlatformQuotaUpdateOne) SetNillableWeeklyWindowStart(v *time.Time) *UserPlatformQuotaUpdateOne {
if v != nil {
_u.SetWeeklyWindowStart(*v)
}
return _u
}
// ClearWeeklyWindowStart clears the value of the "weekly_window_start" field.
func (_u *UserPlatformQuotaUpdateOne) ClearWeeklyWindowStart() *UserPlatformQuotaUpdateOne {
_u.mutation.ClearWeeklyWindowStart()
return _u
}
// SetMonthlyWindowStart sets the "monthly_window_start" field.
func (_u *UserPlatformQuotaUpdateOne) SetMonthlyWindowStart(v time.Time) *UserPlatformQuotaUpdateOne {
_u.mutation.SetMonthlyWindowStart(v)
return _u
}
// SetNillableMonthlyWindowStart sets the "monthly_window_start" field if the given value is not nil.
func (_u *UserPlatformQuotaUpdateOne) SetNillableMonthlyWindowStart(v *time.Time) *UserPlatformQuotaUpdateOne {
if v != nil {
_u.SetMonthlyWindowStart(*v)
}
return _u
}
// ClearMonthlyWindowStart clears the value of the "monthly_window_start" field.
func (_u *UserPlatformQuotaUpdateOne) ClearMonthlyWindowStart() *UserPlatformQuotaUpdateOne {
_u.mutation.ClearMonthlyWindowStart()
return _u
}
// SetUser sets the "user" edge to the User entity.
func (_u *UserPlatformQuotaUpdateOne) SetUser(v *User) *UserPlatformQuotaUpdateOne {
return _u.SetUserID(v.ID)
}
// Mutation returns the UserPlatformQuotaMutation object of the builder.
func (_u *UserPlatformQuotaUpdateOne) Mutation() *UserPlatformQuotaMutation {
return _u.mutation
}
// ClearUser clears the "user" edge to the User entity.
func (_u *UserPlatformQuotaUpdateOne) ClearUser() *UserPlatformQuotaUpdateOne {
_u.mutation.ClearUser()
return _u
}
// Where appends a list predicates to the UserPlatformQuotaUpdate builder.
func (_u *UserPlatformQuotaUpdateOne) Where(ps ...predicate.UserPlatformQuota) *UserPlatformQuotaUpdateOne {
_u.mutation.Where(ps...)
return _u
}
// Select allows selecting one or more fields (columns) of the returned entity.
// The default is selecting all fields defined in the entity schema.
func (_u *UserPlatformQuotaUpdateOne) Select(field string, fields ...string) *UserPlatformQuotaUpdateOne {
_u.fields = append([]string{field}, fields...)
return _u
}
// Save executes the query and returns the updated UserPlatformQuota entity.
func (_u *UserPlatformQuotaUpdateOne) Save(ctx context.Context) (*UserPlatformQuota, error) {
if err := _u.defaults(); err != nil {
return nil, err
}
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
// SaveX is like Save, but panics if an error occurs.
func (_u *UserPlatformQuotaUpdateOne) SaveX(ctx context.Context) *UserPlatformQuota {
node, err := _u.Save(ctx)
if err != nil {
panic(err)
}
return node
}
// Exec executes the query on the entity.
func (_u *UserPlatformQuotaUpdateOne) Exec(ctx context.Context) error {
_, err := _u.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_u *UserPlatformQuotaUpdateOne) ExecX(ctx context.Context) {
if err := _u.Exec(ctx); err != nil {
panic(err)
}
}
// defaults sets the default values of the builder before save.
func (_u *UserPlatformQuotaUpdateOne) defaults() error {
if _, ok := _u.mutation.UpdatedAt(); !ok {
if userplatformquota.UpdateDefaultUpdatedAt == nil {
return fmt.Errorf("ent: uninitialized userplatformquota.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)")
}
v := userplatformquota.UpdateDefaultUpdatedAt()
_u.mutation.SetUpdatedAt(v)
}
return nil
}
// check runs all checks and user-defined validators on the builder.
func (_u *UserPlatformQuotaUpdateOne) check() error {
if v, ok := _u.mutation.Platform(); ok {
if err := userplatformquota.PlatformValidator(v); err != nil {
return &ValidationError{Name: "platform", err: fmt.Errorf(`ent: validator failed for field "UserPlatformQuota.platform": %w`, err)}
}
}
if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
return errors.New(`ent: clearing a required unique edge "UserPlatformQuota.user"`)
}
return nil
}
func (_u *UserPlatformQuotaUpdateOne) sqlSave(ctx context.Context) (_node *UserPlatformQuota, err error) {
if err := _u.check(); err != nil {
return _node, err
}
_spec := sqlgraph.NewUpdateSpec(userplatformquota.Table, userplatformquota.Columns, sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64))
id, ok := _u.mutation.ID()
if !ok {
return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "UserPlatformQuota.id" for update`)}
}
_spec.Node.ID.Value = id
if fields := _u.fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, userplatformquota.FieldID)
for _, f := range fields {
if !userplatformquota.ValidColumn(f) {
return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
}
if f != userplatformquota.FieldID {
_spec.Node.Columns = append(_spec.Node.Columns, f)
}
}
}
if ps := _u.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if value, ok := _u.mutation.UpdatedAt(); ok {
_spec.SetField(userplatformquota.FieldUpdatedAt, field.TypeTime, value)
}
if value, ok := _u.mutation.DeletedAt(); ok {
_spec.SetField(userplatformquota.FieldDeletedAt, field.TypeTime, value)
}
if _u.mutation.DeletedAtCleared() {
_spec.ClearField(userplatformquota.FieldDeletedAt, field.TypeTime)
}
if value, ok := _u.mutation.Platform(); ok {
_spec.SetField(userplatformquota.FieldPlatform, field.TypeString, value)
}
if value, ok := _u.mutation.DailyLimitUsd(); ok {
_spec.SetField(userplatformquota.FieldDailyLimitUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedDailyLimitUsd(); ok {
_spec.AddField(userplatformquota.FieldDailyLimitUsd, field.TypeFloat64, value)
}
if _u.mutation.DailyLimitUsdCleared() {
_spec.ClearField(userplatformquota.FieldDailyLimitUsd, field.TypeFloat64)
}
if value, ok := _u.mutation.WeeklyLimitUsd(); ok {
_spec.SetField(userplatformquota.FieldWeeklyLimitUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedWeeklyLimitUsd(); ok {
_spec.AddField(userplatformquota.FieldWeeklyLimitUsd, field.TypeFloat64, value)
}
if _u.mutation.WeeklyLimitUsdCleared() {
_spec.ClearField(userplatformquota.FieldWeeklyLimitUsd, field.TypeFloat64)
}
if value, ok := _u.mutation.MonthlyLimitUsd(); ok {
_spec.SetField(userplatformquota.FieldMonthlyLimitUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedMonthlyLimitUsd(); ok {
_spec.AddField(userplatformquota.FieldMonthlyLimitUsd, field.TypeFloat64, value)
}
if _u.mutation.MonthlyLimitUsdCleared() {
_spec.ClearField(userplatformquota.FieldMonthlyLimitUsd, field.TypeFloat64)
}
if value, ok := _u.mutation.DailyUsageUsd(); ok {
_spec.SetField(userplatformquota.FieldDailyUsageUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedDailyUsageUsd(); ok {
_spec.AddField(userplatformquota.FieldDailyUsageUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.WeeklyUsageUsd(); ok {
_spec.SetField(userplatformquota.FieldWeeklyUsageUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedWeeklyUsageUsd(); ok {
_spec.AddField(userplatformquota.FieldWeeklyUsageUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.MonthlyUsageUsd(); ok {
_spec.SetField(userplatformquota.FieldMonthlyUsageUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedMonthlyUsageUsd(); ok {
_spec.AddField(userplatformquota.FieldMonthlyUsageUsd, field.TypeFloat64, value)
}
if value, ok := _u.mutation.DailyWindowStart(); ok {
_spec.SetField(userplatformquota.FieldDailyWindowStart, field.TypeTime, value)
}
if _u.mutation.DailyWindowStartCleared() {
_spec.ClearField(userplatformquota.FieldDailyWindowStart, field.TypeTime)
}
if value, ok := _u.mutation.WeeklyWindowStart(); ok {
_spec.SetField(userplatformquota.FieldWeeklyWindowStart, field.TypeTime, value)
}
if _u.mutation.WeeklyWindowStartCleared() {
_spec.ClearField(userplatformquota.FieldWeeklyWindowStart, field.TypeTime)
}
if value, ok := _u.mutation.MonthlyWindowStart(); ok {
_spec.SetField(userplatformquota.FieldMonthlyWindowStart, field.TypeTime, value)
}
if _u.mutation.MonthlyWindowStartCleared() {
_spec.ClearField(userplatformquota.FieldMonthlyWindowStart, field.TypeTime)
}
if _u.mutation.UserCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: userplatformquota.UserTable,
Columns: []string{userplatformquota.UserColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.UserIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
Inverse: true,
Table: userplatformquota.UserTable,
Columns: []string{userplatformquota.UserColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
_node = &UserPlatformQuota{config: _u.config}
_spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues
if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{userplatformquota.Label}
} else if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return nil, err
}
_u.mutation.done = true
return _node, nil
}

View File

@ -5,12 +5,14 @@ go 1.26.3
require (
entgo.io/ent v0.14.5
github.com/DATA-DOG/go-sqlmock v1.5.2
github.com/alicebob/miniredis/v2 v2.38.0
github.com/alitto/pond/v2 v2.6.2
github.com/andybalholm/brotli v1.2.0
github.com/aws/aws-sdk-go-v2 v1.41.3
github.com/aws/aws-sdk-go-v2/config v1.32.10
github.com/aws/aws-sdk-go-v2/credentials v1.19.10
github.com/aws/aws-sdk-go-v2/service/s3 v1.96.2
github.com/aws/smithy-go v1.24.2
github.com/cespare/xxhash/v2 v2.3.0
github.com/coder/websocket v1.8.14
github.com/dgraph-io/ristretto v0.2.0
@ -71,7 +73,6 @@ require (
github.com/aws/aws-sdk-go-v2/service/sso v1.30.11 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 // indirect
github.com/aws/smithy-go v1.24.2 // indirect
github.com/bmatcuk/doublestar v1.3.4 // indirect
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect
github.com/bytedance/sonic v1.9.1 // indirect
@ -158,6 +159,7 @@ require (
github.com/tklauser/numcpus v0.6.1 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect
github.com/yuin/gopher-lua v1.1.1 // indirect
github.com/yusufpapurcu/wmi v1.2.4 // indirect
github.com/zclconf/go-cty v1.14.4 // indirect
github.com/zclconf/go-cty-yaml v1.1.0 // indirect

View File

@ -16,6 +16,8 @@ github.com/agext/levenshtein v1.2.3 h1:YB2fHEn0UJagG8T1rrWknE3ZQzWM06O8AMAatNn7l
github.com/agext/levenshtein v1.2.3/go.mod h1:JEDfjyjHDjOF/1e4FlBE/PkbqA9OfWu2ki2W0IB5558=
github.com/agiledragon/gomonkey v2.0.2+incompatible h1:eXKi9/piiC3cjJD1658mEE2o3NjkJ5vDLgYjCQu0Xlw=
github.com/agiledragon/gomonkey v2.0.2+incompatible/go.mod h1:2NGfXu1a80LLr2cmWXGBDaHEjb1idR6+FVlX5T3D9hw=
github.com/alicebob/miniredis/v2 v2.38.0 h1:nZAzCR+Lj+Vxk4ZXzm2NuKq2O33RXj1XxJ2e2uP9jiw=
github.com/alicebob/miniredis/v2 v2.38.0/go.mod h1:TcL7YfarKPGDAthEtl5NBeHZfeUQj6OXMm/+iu5cLMM=
github.com/alitto/pond/v2 v2.6.2 h1:Sphe40g0ILeM1pA2c2K+Th0DGU+pt0A/Kprr+WB24Pw=
github.com/alitto/pond/v2 v2.6.2/go.mod h1:xkjYEgQ05RSpWdfSd1nM3OVv7TBhLdy7rMp3+2Nq+yE=
github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ=
@ -162,8 +164,6 @@ github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE=
github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4=
@ -370,6 +370,8 @@ github.com/wechatpay-apiv3/wechatpay-go v0.2.21 h1:uIyMpzvcaHA33W/QPtHstccw+X52H
github.com/wechatpay-apiv3/wechatpay-go v0.2.21/go.mod h1:A254AUBVB6R+EqQFo3yTgeh7HtyqRRtN2w9hQSOrd4Q=
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M=
github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw=
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
github.com/zclconf/go-cty v1.14.4 h1:uXXczd9QDGsgu0i/QFR/hzI5NYCHLf6NQw/atrbnhq8=

View File

@ -643,6 +643,12 @@ type ProxyProbeConfig struct {
type BillingConfig struct {
CircuitBreaker CircuitBreakerConfig `mapstructure:"circuit_breaker"`
// UserPlatformQuotaCacheTTLSeconds 用户 × 平台 quota 缓存 TTL默认 86400=1天覆盖典型 daily 窗口。
// 消费点:
// - billing_cache_service.cacheWriteWorker 异步累加
// - billing_cache_service.checkUserPlatformQuotaEligibility 首次缓存装载
// 读写两端必须共用同一 TTL避免缓存生命周期不一致导致 quota 计数漂移。
UserPlatformQuotaCacheTTLSeconds int `mapstructure:"user_platform_quota_cache_ttl_seconds"`
}
type CircuitBreakerConfig struct {
@ -1544,6 +1550,7 @@ func setDefaults() {
viper.SetDefault("billing.circuit_breaker.failure_threshold", 5)
viper.SetDefault("billing.circuit_breaker.reset_timeout_seconds", 30)
viper.SetDefault("billing.circuit_breaker.half_open_requests", 3)
viper.SetDefault("billing.user_platform_quota_cache_ttl_seconds", 86400)
// Turnstile
viper.SetDefault("turnstile.required", false)

View File

@ -16,7 +16,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) {
router := gin.New()
adminSvc := newStubAdminService()
userHandler := NewUserHandler(adminSvc, nil)
userHandler := NewUserHandler(adminSvc, nil, nil, nil)
groupHandler := NewGroupHandler(adminSvc, nil, nil)
proxyHandler := NewProxyHandler(adminSvc)
redeemHandler := NewRedeemHandler(adminSvc, nil)

View File

@ -24,6 +24,7 @@ type stubAdminService struct {
updatedProxyIDs []int64
updatedProxies []*service.UpdateProxyInput
testedProxyIDs []int64
getUserErr error
createAccountErr error
updateAccountErr error
bulkUpdateAccountErr error
@ -147,6 +148,9 @@ func (s *stubAdminService) ListUsers(ctx context.Context, page, pageSize int, fi
}
func (s *stubAdminService) GetUser(ctx context.Context, id int64) (*service.User, error) {
if s.getUserErr != nil {
return nil, s.getUserErr
}
for i := range s.users {
if s.users[i].ID == id {
return &s.users[i], nil

View File

@ -305,6 +305,13 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
payload.OpenAIFastPolicySettings = openaiFastPolicySettingsToDTO(fastPolicy)
}
// Default platform quotasJSON map
if platformQuotas, err := h.settingService.GetDefaultPlatformQuotas(c.Request.Context()); err != nil {
slog.Error("default_platform_quotas_get_failed", "error", err)
} else {
payload.DefaultPlatformQuotas = platformQuotas
}
response.Success(c, systemSettingsResponseData(payload, authSourceDefaults))
}
@ -637,6 +644,18 @@ type UpdateSettingsRequest struct {
// OpenAI fast/flex policy (optional, only updated when provided)
OpenAIFastPolicySettings *dto.OpenAIFastPolicySettings `json:"openai_fast_policy_settings,omitempty"`
// 系统全局 platform quota 默认值整体替换语义nil = 不修改non-nil = 整体覆盖)。
DefaultPlatformQuotas map[string]*service.DefaultPlatformQuotaSetting `json:"default_platform_quotas"`
// auth-source 层 platform quota 覆盖override 语义nil = 不修改non-nil = 整体覆盖该 source 的 quota 配置)。
AuthSourceEmailPlatformQuotas map[string]*service.DefaultPlatformQuotaSetting `json:"auth_source_default_email_platform_quotas"`
AuthSourceLinuxDoPlatformQuotas map[string]*service.DefaultPlatformQuotaSetting `json:"auth_source_default_linuxdo_platform_quotas"`
AuthSourceOIDCPlatformQuotas map[string]*service.DefaultPlatformQuotaSetting `json:"auth_source_default_oidc_platform_quotas"`
AuthSourceWeChatPlatformQuotas map[string]*service.DefaultPlatformQuotaSetting `json:"auth_source_default_wechat_platform_quotas"`
AuthSourceGitHubPlatformQuotas map[string]*service.DefaultPlatformQuotaSetting `json:"auth_source_default_github_platform_quotas"`
AuthSourceGooglePlatformQuotas map[string]*service.DefaultPlatformQuotaSetting `json:"auth_source_default_google_platform_quotas"`
AuthSourceDingTalkPlatformQuotas map[string]*service.DefaultPlatformQuotaSetting `json:"auth_source_default_dingtalk_platform_quotas"`
}
// UpdateSettings 更新系统设置
@ -1438,6 +1457,9 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
settings := &service.SystemSettings{
// 系统全局 platform quota 默认值(整体替换语义)
DefaultPlatformQuotas: req.DefaultPlatformQuotas,
RegistrationEnabled: req.RegistrationEnabled,
EmailVerifyEnabled: req.EmailVerifyEnabled,
RegistrationEmailSuffixWhitelist: req.RegistrationEmailSuffixWhitelist,
@ -1731,6 +1753,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}(),
}
// req.AuthSourceXxxPlatformQuotas 为 nil 表示本次请求未包含该 source 的 quota 配置(保留 previousAuthSourceDefaults 中的值);
// non-nil含 empty map表示整体覆盖empty map = 清空该 source 的所有 quota 配置。
authSourceDefaults := &service.AuthSourceDefaultSettings{
Email: service.ProviderDefaultGrantSettings{
Balance: float64ValueOrDefault(req.AuthSourceDefaultEmailBalance, previousAuthSourceDefaults.Email.Balance),
@ -1738,6 +1762,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultEmailSubscriptions, previousAuthSourceDefaults.Email.Subscriptions),
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultEmailGrantOnSignup, previousAuthSourceDefaults.Email.GrantOnSignup),
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultEmailGrantOnFirstBind, previousAuthSourceDefaults.Email.GrantOnFirstBind),
PlatformQuotas: platformQuotasValueOrDefault(req.AuthSourceEmailPlatformQuotas, previousAuthSourceDefaults.Email.PlatformQuotas),
},
LinuxDo: service.ProviderDefaultGrantSettings{
Balance: float64ValueOrDefault(req.AuthSourceDefaultLinuxDoBalance, previousAuthSourceDefaults.LinuxDo.Balance),
@ -1745,6 +1770,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultLinuxDoSubscriptions, previousAuthSourceDefaults.LinuxDo.Subscriptions),
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultLinuxDoGrantOnSignup, previousAuthSourceDefaults.LinuxDo.GrantOnSignup),
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultLinuxDoGrantOnFirstBind, previousAuthSourceDefaults.LinuxDo.GrantOnFirstBind),
PlatformQuotas: platformQuotasValueOrDefault(req.AuthSourceLinuxDoPlatformQuotas, previousAuthSourceDefaults.LinuxDo.PlatformQuotas),
},
OIDC: service.ProviderDefaultGrantSettings{
Balance: float64ValueOrDefault(req.AuthSourceDefaultOIDCBalance, previousAuthSourceDefaults.OIDC.Balance),
@ -1752,6 +1778,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultOIDCSubscriptions, previousAuthSourceDefaults.OIDC.Subscriptions),
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultOIDCGrantOnSignup, previousAuthSourceDefaults.OIDC.GrantOnSignup),
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultOIDCGrantOnFirstBind, previousAuthSourceDefaults.OIDC.GrantOnFirstBind),
PlatformQuotas: platformQuotasValueOrDefault(req.AuthSourceOIDCPlatformQuotas, previousAuthSourceDefaults.OIDC.PlatformQuotas),
},
WeChat: service.ProviderDefaultGrantSettings{
Balance: float64ValueOrDefault(req.AuthSourceDefaultWeChatBalance, previousAuthSourceDefaults.WeChat.Balance),
@ -1759,6 +1786,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultWeChatSubscriptions, previousAuthSourceDefaults.WeChat.Subscriptions),
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultWeChatGrantOnSignup, previousAuthSourceDefaults.WeChat.GrantOnSignup),
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultWeChatGrantOnFirstBind, previousAuthSourceDefaults.WeChat.GrantOnFirstBind),
PlatformQuotas: platformQuotasValueOrDefault(req.AuthSourceWeChatPlatformQuotas, previousAuthSourceDefaults.WeChat.PlatformQuotas),
},
GitHub: service.ProviderDefaultGrantSettings{
Balance: float64ValueOrDefault(req.AuthSourceDefaultGitHubBalance, previousAuthSourceDefaults.GitHub.Balance),
@ -1766,6 +1794,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultGitHubSubscriptions, previousAuthSourceDefaults.GitHub.Subscriptions),
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultGitHubGrantOnSignup, previousAuthSourceDefaults.GitHub.GrantOnSignup),
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultGitHubGrantOnFirstBind, previousAuthSourceDefaults.GitHub.GrantOnFirstBind),
PlatformQuotas: platformQuotasValueOrDefault(req.AuthSourceGitHubPlatformQuotas, previousAuthSourceDefaults.GitHub.PlatformQuotas),
},
Google: service.ProviderDefaultGrantSettings{
Balance: float64ValueOrDefault(req.AuthSourceDefaultGoogleBalance, previousAuthSourceDefaults.Google.Balance),
@ -1773,6 +1802,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultGoogleSubscriptions, previousAuthSourceDefaults.Google.Subscriptions),
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultGoogleGrantOnSignup, previousAuthSourceDefaults.Google.GrantOnSignup),
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultGoogleGrantOnFirstBind, previousAuthSourceDefaults.Google.GrantOnFirstBind),
PlatformQuotas: platformQuotasValueOrDefault(req.AuthSourceGooglePlatformQuotas, previousAuthSourceDefaults.Google.PlatformQuotas),
},
DingTalk: service.ProviderDefaultGrantSettings{
Balance: float64ValueOrDefault(req.AuthSourceDefaultDingTalkBalance, previousAuthSourceDefaults.DingTalk.Balance),
@ -1780,6 +1810,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultDingTalkSubscriptions, previousAuthSourceDefaults.DingTalk.Subscriptions),
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultDingTalkGrantOnSignup, previousAuthSourceDefaults.DingTalk.GrantOnSignup),
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultDingTalkGrantOnFirstBind, previousAuthSourceDefaults.DingTalk.GrantOnFirstBind),
PlatformQuotas: platformQuotasValueOrDefault(req.AuthSourceDingTalkPlatformQuotas, previousAuthSourceDefaults.DingTalk.PlatformQuotas),
},
ForceEmailOnThirdPartySignup: boolValueOrDefault(req.ForceEmailOnThirdPartySignup, previousAuthSourceDefaults.ForceEmailOnThirdPartySignup),
}
@ -2047,6 +2078,13 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
} else if fastPolicy != nil {
payload.OpenAIFastPolicySettings = openaiFastPolicySettingsToDTO(fastPolicy)
}
// Default platform quotasJSON map—— 与 GetSettings 一致,避免保存后响应缺失该字段
if platformQuotas, err := h.settingService.GetDefaultPlatformQuotas(c.Request.Context()); err != nil {
slog.Error("default_platform_quotas_get_failed", "error", err)
} else {
payload.DefaultPlatformQuotas = platformQuotas
}
response.Success(c, systemSettingsResponseData(payload, updatedAuthSourceDefaults))
}
@ -2511,6 +2549,10 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.RiskControlEnabled != after.RiskControlEnabled {
changed = append(changed, "risk_control_enabled")
}
// Default platform quotasJSON map整体比较
if !equalPlatformQuotaSettings(before.DefaultPlatformQuotas, after.DefaultPlatformQuotas) {
changed = append(changed, service.SettingKeyDefaultPlatformQuotas)
}
changed = appendAuthSourceDefaultChanges(changed, beforeAuthSourceDefaults, afterAuthSourceDefaults)
return changed
}
@ -2554,6 +2596,10 @@ func appendAuthSourceDefaultChanges(changed []string, before *service.AuthSource
if field.before.GrantOnFirstBind != field.after.GrantOnFirstBind {
changed = append(changed, "auth_source_default_"+field.name+"_grant_on_first_bind")
}
// Platform quotas diff整体替换语义发单个 JSON key。
if !equalPlatformQuotaSettings(field.before.PlatformQuotas, field.after.PlatformQuotas) {
changed = append(changed, service.SettingKeyAuthSourcePlatformQuotas(field.name))
}
}
if before.ForceEmailOnThirdPartySignup != after.ForceEmailOnThirdPartySignup {
changed = append(changed, "force_email_on_third_party_signup")
@ -2621,6 +2667,17 @@ func defaultSubscriptionsValueOrDefault(input *[]dto.DefaultSubscriptionSetting,
return result
}
// platformQuotasValueOrDefault 处理 auth-source platform quota 的 nil 语义:
// nil = 请求未包含该字段(保留 fallbacknon-nil含 empty map= 整体覆盖。
// 注意JSON null 与字段省略等价——两者均反序列化为 nil map因此都保留旧值
// 若要清空某 source 的所有 quota 配置,须显式发空对象 {}。
func platformQuotasValueOrDefault(value, fallback map[string]*service.DefaultPlatformQuotaSetting) map[string]*service.DefaultPlatformQuotaSetting {
if value == nil {
return fallback
}
return value
}
func systemSettingsResponseData(settings dto.SystemSettings, authSourceDefaults *service.AuthSourceDefaultSettings) map[string]any {
data := make(map[string]any)
raw, err := json.Marshal(settings)
@ -2666,6 +2723,13 @@ func systemSettingsResponseData(settings dto.SystemSettings, authSourceDefaults
data["auth_source_default_google_subscriptions"] = authSourceDefaults.Google.Subscriptions
data["auth_source_default_google_grant_on_signup"] = authSourceDefaults.Google.GrantOnSignup
data["auth_source_default_google_grant_on_first_bind"] = authSourceDefaults.Google.GrantOnFirstBind
data["auth_source_default_email_platform_quotas"] = authSourceDefaults.Email.PlatformQuotas
data["auth_source_default_linuxdo_platform_quotas"] = authSourceDefaults.LinuxDo.PlatformQuotas
data["auth_source_default_oidc_platform_quotas"] = authSourceDefaults.OIDC.PlatformQuotas
data["auth_source_default_wechat_platform_quotas"] = authSourceDefaults.WeChat.PlatformQuotas
data["auth_source_default_github_platform_quotas"] = authSourceDefaults.GitHub.PlatformQuotas
data["auth_source_default_google_platform_quotas"] = authSourceDefaults.Google.PlatformQuotas
data["auth_source_default_dingtalk_platform_quotas"] = authSourceDefaults.DingTalk.PlatformQuotas
data["force_email_on_third_party_signup"] = authSourceDefaults.ForceEmailOnThirdPartySignup
return data
@ -3552,3 +3616,48 @@ func emailTemplatePlaceholderUnion(events []service.NotificationEmailEventInfo)
}
return placeholders
}
// equalNullableFloat compares two *float64 values treating nil as a distinct case.
func equalNullableFloat(a, b *float64) bool {
if a == nil && b == nil {
return true
}
if a == nil || b == nil {
return false
}
return *a == *b
}
// slotOf returns the *float64 for the given window from a DefaultPlatformQuotaSetting.
func slotOf(s *service.DefaultPlatformQuotaSetting, win string) *float64 {
if s == nil {
return nil
}
switch win {
case "daily":
return s.DailyLimitUSD
case "weekly":
return s.WeeklyLimitUSD
case "monthly":
return s.MonthlyLimitUSD
}
return nil
}
// equalPlatformQuotaSettings reports whether two platform-quota maps are identical across all 12 slots.
func equalPlatformQuotaSettings(before, after map[string]*service.DefaultPlatformQuotaSetting) bool {
for _, platform := range service.AllowedQuotaPlatforms {
b := before[platform]
a := after[platform]
if !equalNullableFloat(slotOf(b, "daily"), slotOf(a, "daily")) {
return false
}
if !equalNullableFloat(slotOf(b, "weekly"), slotOf(a, "weekly")) {
return false
}
if !equalNullableFloat(slotOf(b, "monthly"), slotOf(a, "monthly")) {
return false
}
}
return true
}

View File

@ -0,0 +1,188 @@
//go:build unit
package admin
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestDiffSettings_DetectsGlobalPlatformQuotaChange(t *testing.T) {
five := 5.0
ten := 10.0
before := &service.SystemSettings{
DefaultPlatformQuotas: map[string]*service.DefaultPlatformQuotaSetting{
"anthropic": {DailyLimitUSD: &five},
},
}
after := &service.SystemSettings{
DefaultPlatformQuotas: map[string]*service.DefaultPlatformQuotaSetting{
"anthropic": {DailyLimitUSD: &ten},
},
}
changed := diffSettings(before, after, nil, nil, UpdateSettingsRequest{})
found := false
for _, key := range changed {
if key == service.SettingKeyDefaultPlatformQuotas {
found = true
break
}
}
if !found {
t.Errorf("expected change detection for default platform quotas, got %v", changed)
}
}
func TestDiffSettings_NoChangeWhenEqual(t *testing.T) {
five := 5.0
before := &service.SystemSettings{
DefaultPlatformQuotas: map[string]*service.DefaultPlatformQuotaSetting{
"anthropic": {DailyLimitUSD: &five},
},
}
after := &service.SystemSettings{
DefaultPlatformQuotas: map[string]*service.DefaultPlatformQuotaSetting{
"anthropic": {DailyLimitUSD: &five},
},
}
changed := diffSettings(before, after, nil, nil, UpdateSettingsRequest{})
for _, key := range changed {
if key == service.SettingKeyDefaultPlatformQuotas {
t.Error("equal values should not be detected as changed")
}
}
}
func TestEqualNullableFloat(t *testing.T) {
five := 5.0
five2 := 5.0
ten := 10.0
cases := []struct {
a, b *float64
want bool
}{
{nil, nil, true},
{&five, nil, false},
{nil, &five, false},
{&five, &five2, true},
{&five, &ten, false},
}
for _, c := range cases {
if got := equalNullableFloat(c.a, c.b); got != c.want {
t.Errorf("equalNullableFloat(%v, %v) = %v, want %v", c.a, c.b, got, c.want)
}
}
}
func TestEqualPlatformQuotaSettings_DetectsPerWindowChange(t *testing.T) {
five := 5.0
ten := 10.0
before := map[string]*service.DefaultPlatformQuotaSetting{
"anthropic": {DailyLimitUSD: &five},
}
after := map[string]*service.DefaultPlatformQuotaSetting{
"anthropic": {DailyLimitUSD: &ten},
}
if equalPlatformQuotaSettings(before, after) {
t.Error("expected unequal")
}
}
func TestAppendAuthSourceDefaultChanges_DetectsPerWindow(t *testing.T) {
five := 5.0
ten := 10.0
before := &service.AuthSourceDefaultSettings{
LinuxDo: service.ProviderDefaultGrantSettings{
PlatformQuotas: map[string]*service.DefaultPlatformQuotaSetting{
"anthropic": {DailyLimitUSD: &five},
},
},
}
after := &service.AuthSourceDefaultSettings{
LinuxDo: service.ProviderDefaultGrantSettings{
PlatformQuotas: map[string]*service.DefaultPlatformQuotaSetting{
"anthropic": {DailyLimitUSD: &ten},
},
},
}
changed := appendAuthSourceDefaultChanges([]string{}, before, after)
// 改动 B5整体替换语义审计 log 发单个 JSON key而非展开 84 个扁平 key。
key := service.SettingKeyAuthSourcePlatformQuotas("linuxdo")
found := false
for _, k := range changed {
if k == key {
found = true
break
}
}
if !found {
t.Errorf("expected %q in changed, got %v", key, changed)
}
}
// TestSettingHandler_AuthSourcePlatformQuotas_PutGetRoundTrip 验证 Bug A 修复:
// PUT 发 auth_source_default_email_platform_quotasGET 能读回相同值(端到端往返)。
func TestSettingHandler_AuthSourcePlatformQuotas_PutGetRoundTrip(t *testing.T) {
gin.SetMode(gin.TestMode)
repo := &settingHandlerRepoStub{
values: map[string]string{
service.SettingKeyPromoCodeEnabled: "true",
},
}
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil)
// PUT发 email platform quotaopenai monthly=20
putBody := map[string]any{
"auth_source_default_email_platform_quotas": map[string]any{
"openai": map[string]any{
"monthly": 20,
},
},
}
rawBody, err := json.Marshal(putBody)
require.NoError(t, err)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
c.Request.Header.Set("Content-Type", "application/json")
handler.UpdateSettings(c)
require.Equal(t, http.StatusOK, rec.Code)
// 验证 DB 中写入了 JSON key
jsonKey := service.SettingKeyAuthSourcePlatformQuotas("email")
require.NotEmpty(t, repo.values[jsonKey], "expected JSON key to be written to DB")
// GET验证响应中 auth_source_default_email_platform_quotas.openai.monthly = 20
rec2 := httptest.NewRecorder()
c2, _ := gin.CreateTestContext(rec2)
c2.Request = httptest.NewRequest(http.MethodGet, "/api/v1/admin/settings", nil)
handler.GetSettings(c2)
require.Equal(t, http.StatusOK, rec2.Code)
var resp response.Response
require.NoError(t, json.Unmarshal(rec2.Body.Bytes(), &resp))
data, ok := resp.Data.(map[string]any)
require.True(t, ok)
emailPQ, ok := data["auth_source_default_email_platform_quotas"].(map[string]any)
require.True(t, ok, "expected auth_source_default_email_platform_quotas to be a map")
openaiPQ, ok := emailPQ["openai"].(map[string]any)
require.True(t, ok, "expected openai entry in email platform quotas")
monthly, ok := openaiPQ["monthly"].(float64)
require.True(t, ok, "expected monthly to be float64")
require.Equal(t, float64(20), monthly, "expected openai monthly=20")
}

View File

@ -2,10 +2,15 @@ package admin
import (
"context"
"errors"
"log/slog"
"math"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/handler/quotaview"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
@ -20,15 +25,24 @@ type UserWithConcurrency struct {
// UserHandler handles admin user management
type UserHandler struct {
adminService service.AdminService
concurrencyService *service.ConcurrencyService
adminService service.AdminService
concurrencyService *service.ConcurrencyService
userPlatformQuotaRepo service.UserPlatformQuotaRepository // T13 admin quota view
billingCache service.BillingCache // T17/T18 缓存失效PUT/POST 路径)
}
// NewUserHandler creates a new admin user handler
func NewUserHandler(adminService service.AdminService, concurrencyService *service.ConcurrencyService) *UserHandler {
func NewUserHandler(
adminService service.AdminService,
concurrencyService *service.ConcurrencyService,
userPlatformQuotaRepo service.UserPlatformQuotaRepository,
billingCache service.BillingCache,
) *UserHandler {
return &UserHandler{
adminService: adminService,
concurrencyService: concurrencyService,
adminService: adminService,
concurrencyService: concurrencyService,
userPlatformQuotaRepo: userPlatformQuotaRepo,
billingCache: billingCache,
}
}
@ -537,3 +551,294 @@ func (h *UserHandler) BatchUpdateConcurrency(c *gin.Context) {
}
response.Success(c, gin.H{"affected": affected})
}
// GetUserPlatformQuotas GET /admin/users/:id/platform-quotas
// admin 视角D14 lazy 归零 + 暴露 *_window_start 调试字段
func (h *UserHandler) GetUserPlatformQuotas(c *gin.Context) {
idStr := c.Param("id")
userID, err := strconv.ParseInt(idStr, 10, 64)
if err != nil {
response.BadRequest(c, "invalid user id")
return
}
if h.userPlatformQuotaRepo == nil {
response.Success(c, map[string]any{"platform_quotas": []any{}})
return
}
// 校验用户存在:与 PUT/POST 路径一致,不存在返回 404 而非空数组(避免 admin 界面误判用户存在)。
if _, err := h.adminService.GetUser(c.Request.Context(), userID); err != nil {
response.ErrorFrom(c, err)
return
}
records, err := h.userPlatformQuotaRepo.ListByUser(c.Request.Context(), userID)
if err != nil {
response.ErrorFrom(c, err)
return
}
now := time.Now().UTC()
out := make([]map[string]any, 0, len(records))
for _, r := range records {
out = append(out, quotaview.LazyZeroQuotaForResponse(r, now, true)) // true = 暴露 window_start
}
response.Success(c, map[string]any{"platform_quotas": out})
}
// UpdateUserPlatformQuotasRequest is the body for PUT /admin/users/:id/platform-quotas.
type UpdateUserPlatformQuotasRequest struct {
Quotas []PlatformQuotaInput `json:"quotas" binding:"required"`
}
// PlatformQuotaInput 单平台限额输入limit 字段为 nil 表示不限制。
type PlatformQuotaInput struct {
Platform string `json:"platform" binding:"required"`
DailyLimitUSD *float64 `json:"daily_limit_usd"`
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
}
// platform 合法性由 service.IsAllowedQuotaPlatform / service.AllowedQuotaPlatforms 统一判断(单一源)。
// UpdateUserPlatformQuotas PUT /admin/users/:id/platform-quotas
// 全量替换该用户所有平台限额。
func (h *UserHandler) UpdateUserPlatformQuotas(c *gin.Context) {
if h.userPlatformQuotaRepo == nil {
response.Error(c, 503, "platform quota service not available")
return
}
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user ID")
return
}
var req UpdateUserPlatformQuotasRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if len(req.Quotas) > 4 {
response.BadRequest(c, "quotas length must be <= 4")
return
}
seen := make(map[string]struct{}, len(req.Quotas))
for _, q := range req.Quotas {
if !service.IsAllowedQuotaPlatform(q.Platform) {
response.BadRequest(c, "invalid platform: "+q.Platform)
return
}
if _, dup := seen[q.Platform]; dup {
response.BadRequest(c, "duplicate platform: "+q.Platform)
return
}
seen[q.Platform] = struct{}{}
// daily_limit_usd / weekly_limit_usd / monthly_limit_usd 的语义:
// nil / not set → 无限额(完全放行)
// 0 → 完全禁用(任何请求都会被拒绝,因为 usage >= 0 恒成立)
// > 0 → USD 限额上限
// 拦截 NaN / ±Inf客户端可发送超大数如 1e308 × 2使 JSON 反序列化得到 +Inf
// 进入 DB 后 cache check 中 usage >= limit 永不成立limit 等同失效。
for _, f := range []struct {
name string
val *float64
}{
{"daily_limit_usd", q.DailyLimitUSD},
{"weekly_limit_usd", q.WeeklyLimitUSD},
{"monthly_limit_usd", q.MonthlyLimitUSD},
} {
if f.val == nil {
continue
}
v := *f.val
if v < 0 {
response.BadRequest(c, f.name+" must be >= 0")
return
}
if math.IsNaN(v) || math.IsInf(v, 0) {
response.BadRequest(c, f.name+" must be a finite number")
return
}
}
}
records := make([]service.UserPlatformQuotaRecord, 0, len(req.Quotas))
for _, q := range req.Quotas {
records = append(records, service.UserPlatformQuotaRecord{
UserID: userID,
Platform: q.Platform,
DailyLimitUSD: q.DailyLimitUSD,
WeeklyLimitUSD: q.WeeklyLimitUSD,
MonthlyLimitUSD: q.MonthlyLimitUSD,
})
}
ctx := c.Request.Context()
// 校验用户是否存在,避免 FK 违反导致 500用户不存在时返回 404。
if _, err := h.adminService.GetUser(ctx, userID); err != nil {
response.ErrorFrom(c, err)
return
}
// 在 UpsertForUser 之前抓取 before snapshot 用于审计 before/after 对比。
// ListByUser 失败不阻断主操作best-effort仅记录降级 warn。
beforeRecords, beforeErr := h.userPlatformQuotaRepo.ListByUser(ctx, userID)
if beforeErr != nil {
slog.Warn("quota audit before snapshot failed", "user_id", userID, "err", beforeErr)
}
if err := h.userPlatformQuotaRepo.UpsertForUser(ctx, userID, records); err != nil {
response.ErrorFrom(c, err)
return
}
beforeByPlatform := make(map[string]service.UserPlatformQuotaRecord, len(beforeRecords))
for _, r := range beforeRecords {
beforeByPlatform[r.Platform] = r
}
afterPlatforms := make(map[string]struct{}, len(records))
for _, r := range records {
afterPlatforms[r.Platform] = struct{}{}
}
changes := make([]map[string]any, 0, len(records))
for _, r := range records {
entry := map[string]any{
"platform": r.Platform,
"daily_limit_usd": r.DailyLimitUSD,
"weekly_limit_usd": r.WeeklyLimitUSD,
"monthly_limit_usd": r.MonthlyLimitUSD,
}
if prev, ok := beforeByPlatform[r.Platform]; ok {
entry["before_daily_limit_usd"] = prev.DailyLimitUSD
entry["before_weekly_limit_usd"] = prev.WeeklyLimitUSD
entry["before_monthly_limit_usd"] = prev.MonthlyLimitUSD
}
changes = append(changes, entry)
}
// 补 removed 条目before 存在但 after 缺失 = 该平台被软删除。
// 缺少这条记录,审计消费方无法察觉"管理员把某平台从配额列表移除"的操作(合规盲区)。
for _, prev := range beforeRecords {
if _, kept := afterPlatforms[prev.Platform]; kept {
continue
}
changes = append(changes, map[string]any{
"platform": prev.Platform,
"removed": true,
"before_daily_limit_usd": prev.DailyLimitUSD,
"before_weekly_limit_usd": prev.WeeklyLimitUSD,
"before_monthly_limit_usd": prev.MonthlyLimitUSD,
})
}
// before_snapshot_available 让审计消费方能识别 changes 中是否带 before_* 字段;
// false 时所有 entry 都会缺失 before_*_limit_usd仅有 after 视图。
slog.Info("admin.quota_updated",
"actor_admin_id", getAdminIDFromContext(c),
"target_user_id", userID,
"platform_count", len(records),
"before_snapshot_available", beforeErr == nil,
"changes", changes)
// 失效 cache对全部允许的 platform 统一 invalidate。
// Trade-off精确失效仅 req 涉及平台 + 被软删平台)需 upsert 前额外 ListByUser
// 增加一次 DB 查询和逻辑复杂度。由于 AllowedQuotaPlatforms 只有 4 个元素,
// 全量 invalidate 的额外开销可接受,且能可靠覆盖软删除场景。
if h.billingCache != nil {
for _, p := range service.AllowedQuotaPlatforms {
if err := h.billingCache.DeleteUserPlatformQuotaCache(ctx, userID, p); err != nil {
slog.Warn("quota cache invalidation failed", "user_id", userID, "platform", p, "err", err)
}
}
}
// 返回最新状态
now := time.Now().UTC()
records2, err := h.userPlatformQuotaRepo.ListByUser(ctx, userID)
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]map[string]any, 0, len(records2))
for i := range records2 {
out = append(out, quotaview.LazyZeroQuotaForResponse(records2[i], now, true))
}
response.Success(c, map[string]any{"platform_quotas": out})
}
// ResetUserPlatformQuotaWindowRequest is the body for POST /admin/users/:id/platform-quotas/reset.
type ResetUserPlatformQuotaWindowRequest struct {
Platform string `json:"platform" binding:"required"`
Window string `json:"window" binding:"required"`
}
var allowedWindowsForQuotaReset = map[string]struct{}{
"daily": {},
"weekly": {},
"monthly": {},
}
// ResetUserPlatformQuotaWindow POST /admin/users/:id/platform-quotas/reset
// 立即归零指定 (platform, window) 的用量并更新 window_start。
func (h *UserHandler) ResetUserPlatformQuotaWindow(c *gin.Context) {
if h.userPlatformQuotaRepo == nil {
response.Error(c, 503, "platform quota service not available")
return
}
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user ID")
return
}
var req ResetUserPlatformQuotaWindowRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if !service.IsAllowedQuotaPlatform(req.Platform) {
response.BadRequest(c, "invalid platform: "+req.Platform)
return
}
if _, ok := allowedWindowsForQuotaReset[req.Window]; !ok {
response.BadRequest(c, "invalid window: "+req.Window)
return
}
ctx := c.Request.Context()
// 校验用户是否存在,避免对不存在的用户执行操作返回误导性的 500。
if _, err := h.adminService.GetUser(ctx, userID); err != nil {
response.ErrorFrom(c, err)
return
}
now := time.Now().UTC()
if err := h.userPlatformQuotaRepo.ResetExpiredWindow(ctx, userID, req.Platform, req.Window, now); err != nil {
if errors.Is(err, service.ErrUserPlatformQuotaNotFound) {
response.NotFound(c, "user platform quota not found")
return
}
response.ErrorFrom(c, err)
return
}
slog.Info("admin.quota_window_reset",
"actor_admin_id", getAdminIDFromContext(c),
"target_user_id", userID,
"platform", req.Platform,
"window", req.Window)
if h.billingCache != nil {
if err := h.billingCache.DeleteUserPlatformQuotaCache(ctx, userID, req.Platform); err != nil {
slog.Warn("quota cache invalidation failed", "user_id", userID, "platform", req.Platform, "err", err)
}
}
records, err := h.userPlatformQuotaRepo.ListByUser(ctx, userID)
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]map[string]any, 0, len(records))
for i := range records {
out = append(out, quotaview.LazyZeroQuotaForResponse(records[i], now, true))
}
response.Success(c, map[string]any{"platform_quotas": out})
}

View File

@ -35,7 +35,7 @@ func TestUserHandlerListIncludesActivityFieldsAndSortParams(t *testing.T) {
UpdatedAt: lastLoginAt,
},
}
handler := NewUserHandler(adminSvc, nil)
handler := NewUserHandler(adminSvc, nil, nil, nil)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
@ -89,7 +89,7 @@ func TestUserHandlerGetByIDIncludesActivityFields(t *testing.T) {
UpdatedAt: lastLoginAt,
},
}
handler := NewUserHandler(adminSvc, nil)
handler := NewUserHandler(adminSvc, nil, nil, nil)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)

View File

@ -0,0 +1,301 @@
//go:build unit
package admin
import (
"bytes"
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// upsertCapturingQuotaRepo 实现 service.UserPlatformQuotaRepository捕获 UpsertForUser 调用。
type upsertCapturingQuotaRepo struct {
service.UserPlatformQuotaRepository
listRecords []service.UserPlatformQuotaRecord
listErr error
upsertCalls []upsertCall
upsertErr error
resetCalls []resetCall
resetErr error
}
type upsertCall struct {
userID int64
records []service.UserPlatformQuotaRecord
}
type resetCall struct {
userID int64
platform string
window string
newStart time.Time
}
func (r *upsertCapturingQuotaRepo) ListByUser(_ context.Context, _ int64) ([]service.UserPlatformQuotaRecord, error) {
return r.listRecords, r.listErr
}
func (r *upsertCapturingQuotaRepo) UpsertForUser(_ context.Context, userID int64, records []service.UserPlatformQuotaRecord) error {
cloned := make([]service.UserPlatformQuotaRecord, len(records))
copy(cloned, records)
r.upsertCalls = append(r.upsertCalls, upsertCall{userID: userID, records: cloned})
return r.upsertErr
}
func (r *upsertCapturingQuotaRepo) ResetExpiredWindow(_ context.Context, userID int64, platform string, window string, newStart time.Time) error {
r.resetCalls = append(r.resetCalls, resetCall{userID, platform, window, newStart})
return r.resetErr
}
// billingCacheStub 实现 service.BillingCache 中本测试关心的 Delete 方法;其他方法 panic。
type billingCacheStub struct {
service.BillingCache
deleteCalls []deleteCall
deleteErr error
}
type deleteCall struct {
userID int64
platform string
}
func (b *billingCacheStub) DeleteUserPlatformQuotaCache(_ context.Context, userID int64, platform string) error {
b.deleteCalls = append(b.deleteCalls, deleteCall{userID, platform})
return b.deleteErr
}
func buildTestHandler(repo service.UserPlatformQuotaRepository, cache service.BillingCache) *UserHandler {
return &UserHandler{
userPlatformQuotaRepo: repo,
billingCache: cache,
adminService: newStubAdminService(),
}
}
func putReq(t *testing.T, body string) (*gin.Context, *httptest.ResponseRecorder) {
t.Helper()
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
req, _ := http.NewRequest(http.MethodPut, "/", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
c.Request = req
c.Params = []gin.Param{{Key: "id", Value: "42"}}
return c, w
}
func TestUpdateUserPlatformQuotas_Success(t *testing.T) {
repo := &upsertCapturingQuotaRepo{}
cache := &billingCacheStub{}
h := buildTestHandler(repo, cache)
body := `{"quotas":[
{"platform":"anthropic","daily_limit_usd":10.0,"weekly_limit_usd":null,"monthly_limit_usd":100.0},
{"platform":"openai","daily_limit_usd":null,"weekly_limit_usd":null,"monthly_limit_usd":null}
]}`
c, w := putReq(t, body)
h.UpdateUserPlatformQuotas(c)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
if len(repo.upsertCalls) != 1 {
t.Fatalf("UpsertForUser should be called once, got %d", len(repo.upsertCalls))
}
if repo.upsertCalls[0].userID != 42 || len(repo.upsertCalls[0].records) != 2 {
t.Errorf("unexpected upsert call: %+v", repo.upsertCalls[0])
}
// 缓存失效:请求中 2 个 platform + 软删除的 2 个 platformgemini, antigravity= 4 次
if len(cache.deleteCalls) != 4 {
t.Errorf("expected 4 cache delete calls, got %d: %+v", len(cache.deleteCalls), cache.deleteCalls)
}
}
func TestUpdateUserPlatformQuotas_RejectsDuplicatePlatform(t *testing.T) {
h := buildTestHandler(&upsertCapturingQuotaRepo{}, &billingCacheStub{})
body := `{"quotas":[
{"platform":"anthropic","daily_limit_usd":1},
{"platform":"anthropic","daily_limit_usd":2}
]}`
c, w := putReq(t, body)
h.UpdateUserPlatformQuotas(c)
if w.Code != http.StatusBadRequest {
t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String())
}
}
func TestUpdateUserPlatformQuotas_RejectsInvalidPlatform(t *testing.T) {
h := buildTestHandler(&upsertCapturingQuotaRepo{}, &billingCacheStub{})
body := `{"quotas":[{"platform":"unknown","daily_limit_usd":1}]}`
c, w := putReq(t, body)
h.UpdateUserPlatformQuotas(c)
if w.Code != http.StatusBadRequest {
t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String())
}
}
func TestUpdateUserPlatformQuotas_RejectsNegativeLimit(t *testing.T) {
h := buildTestHandler(&upsertCapturingQuotaRepo{}, &billingCacheStub{})
body := `{"quotas":[{"platform":"anthropic","daily_limit_usd":-1}]}`
c, w := putReq(t, body)
h.UpdateUserPlatformQuotas(c)
if w.Code != http.StatusBadRequest {
t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String())
}
}
func TestUpdateUserPlatformQuotas_RejectsTooManyEntries(t *testing.T) {
h := buildTestHandler(&upsertCapturingQuotaRepo{}, &billingCacheStub{})
body := `{"quotas":[
{"platform":"anthropic"},{"platform":"openai"},{"platform":"gemini"},{"platform":"antigravity"},{"platform":"anthropic"}
]}`
c, w := putReq(t, body)
h.UpdateUserPlatformQuotas(c)
if w.Code != http.StatusBadRequest {
t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String())
}
}
func TestUpdateUserPlatformQuotas_ReturnsLatestState(t *testing.T) {
repo := &upsertCapturingQuotaRepo{
listRecords: []service.UserPlatformQuotaRecord{
{UserID: 42, Platform: "anthropic"},
},
}
cache := &billingCacheStub{}
h := buildTestHandler(repo, cache)
body := `{"quotas":[{"platform":"anthropic","daily_limit_usd":10}]}`
c, w := putReq(t, body)
h.UpdateUserPlatformQuotas(c)
if !strings.Contains(w.Body.String(), `"platform_quotas"`) {
t.Errorf("response should contain platform_quotas array: %s", w.Body.String())
}
}
// ───────── T4: Reset 测试 ─────────
func postReq(t *testing.T, body string) (*gin.Context, *httptest.ResponseRecorder) {
t.Helper()
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
req, _ := http.NewRequest(http.MethodPost, "/", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
c.Request = req
c.Params = []gin.Param{{Key: "id", Value: "42"}}
return c, w
}
func TestResetUserPlatformQuotaWindow_Success(t *testing.T) {
repo := &upsertCapturingQuotaRepo{}
cache := &billingCacheStub{}
h := buildTestHandler(repo, cache)
body := `{"platform":"anthropic","window":"daily"}`
c, w := postReq(t, body)
h.ResetUserPlatformQuotaWindow(c)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
if len(repo.resetCalls) != 1 {
t.Fatalf("ResetExpiredWindow should be called once, got %d", len(repo.resetCalls))
}
if repo.resetCalls[0].userID != 42 ||
repo.resetCalls[0].platform != "anthropic" ||
repo.resetCalls[0].window != "daily" {
t.Errorf("unexpected reset call: %+v", repo.resetCalls[0])
}
if len(cache.deleteCalls) != 1 ||
cache.deleteCalls[0].userID != 42 ||
cache.deleteCalls[0].platform != "anthropic" {
t.Errorf("expected 1 cache delete for anthropic, got %+v", cache.deleteCalls)
}
}
func TestResetUserPlatformQuotaWindow_RejectsInvalidWindow(t *testing.T) {
h := buildTestHandler(&upsertCapturingQuotaRepo{}, &billingCacheStub{})
c, w := postReq(t, `{"platform":"anthropic","window":"yearly"}`)
h.ResetUserPlatformQuotaWindow(c)
if w.Code != http.StatusBadRequest {
t.Errorf("expected 400, got %d", w.Code)
}
}
func TestResetUserPlatformQuotaWindow_RejectsInvalidPlatform(t *testing.T) {
h := buildTestHandler(&upsertCapturingQuotaRepo{}, &billingCacheStub{})
c, w := postReq(t, `{"platform":"unknown","window":"daily"}`)
h.ResetUserPlatformQuotaWindow(c)
if w.Code != http.StatusBadRequest {
t.Errorf("expected 400, got %d", w.Code)
}
}
func TestResetUserPlatformQuotaWindow_NotFound(t *testing.T) {
// handler 检查 service.ErrUserPlatformQuotaNotFound由 adapter 包装而来)
repo := &upsertCapturingQuotaRepo{resetErr: service.ErrUserPlatformQuotaNotFound}
h := buildTestHandler(repo, &billingCacheStub{})
c, w := postReq(t, `{"platform":"anthropic","window":"daily"}`)
h.ResetUserPlatformQuotaWindow(c)
if w.Code != http.StatusNotFound {
t.Errorf("expected 404, got %d: %s", w.Code, w.Body.String())
}
}
func TestUpdateUserPlatformQuotas_JSONErrorOnRepoFailure(t *testing.T) {
repo := &upsertCapturingQuotaRepo{upsertErr: errors.New("db down")}
cache := &billingCacheStub{}
h := buildTestHandler(repo, cache)
body := `{"quotas":[{"platform":"anthropic","daily_limit_usd":10}]}`
c, w := putReq(t, body)
h.UpdateUserPlatformQuotas(c)
if w.Code < 500 {
t.Errorf("expected 5xx, got %d", w.Code)
}
// 返回 JSON 错误响应
var body2 map[string]any
if err := json.Unmarshal(w.Body.Bytes(), &body2); err != nil {
t.Errorf("expected JSON error body, got: %s", w.Body.String())
}
}
func TestUpdateUserPlatformQuotas_UserNotFound(t *testing.T) {
repo := &upsertCapturingQuotaRepo{}
cache := &billingCacheStub{}
adminSvc := newStubAdminService()
adminSvc.getUserErr = service.ErrUserNotFound
h := &UserHandler{
userPlatformQuotaRepo: repo,
billingCache: cache,
adminService: adminSvc,
}
body := `{"quotas":[{"platform":"anthropic","daily_limit_usd":10}]}`
c, w := putReq(t, body)
h.UpdateUserPlatformQuotas(c)
if w.Code != http.StatusNotFound {
t.Errorf("expected 404 when user not found, got %d: %s", w.Code, w.Body.String())
}
}
func TestResetUserPlatformQuotaWindow_UserNotFound(t *testing.T) {
repo := &upsertCapturingQuotaRepo{}
cache := &billingCacheStub{}
adminSvc := newStubAdminService()
adminSvc.getUserErr = service.ErrUserNotFound
h := &UserHandler{
userPlatformQuotaRepo: repo,
billingCache: cache,
adminService: adminSvc,
}
c, w := postReq(t, `{"platform":"anthropic","window":"daily"}`)
h.ResetUserPlatformQuotaWindow(c)
if w.Code != http.StatusNotFound {
t.Errorf("expected 404 when user not found, got %d: %s", w.Code, w.Body.String())
}
}

View File

@ -0,0 +1,124 @@
//go:build unit
package admin
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
type fakeQuotaRepoForAdmin struct {
service.UserPlatformQuotaRepository
records []service.UserPlatformQuotaRecord
err error
}
func (f *fakeQuotaRepoForAdmin) ListByUser(_ context.Context, _ int64) ([]service.UserPlatformQuotaRecord, error) {
return f.records, f.err
}
func newAdminQuotaTestContext(w *httptest.ResponseRecorder) *gin.Context {
c, _ := gin.CreateTestContext(w)
req, _ := http.NewRequest(http.MethodGet, "/", nil)
c.Request = req
return c
}
func TestAdminGetUserPlatformQuotas_IncludesWindowStart(t *testing.T) {
start := time.Now().Add(-1 * time.Hour)
repo := &fakeQuotaRepoForAdmin{records: []service.UserPlatformQuotaRecord{{
UserID: 99, Platform: "anthropic",
DailyUsageUSD: 1.0, DailyWindowStart: &start,
}}}
h := &UserHandler{userPlatformQuotaRepo: repo, adminService: newStubAdminService()}
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c := newAdminQuotaTestContext(w)
c.Params = []gin.Param{{Key: "id", Value: "99"}}
h.GetUserPlatformQuotas(c)
if w.Code != 200 {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
if !strings.Contains(w.Body.String(), `"daily_window_start"`) {
t.Errorf("admin response missing daily_window_start, got: %s", w.Body.String())
}
}
func TestAdminGetUserPlatformQuotas_InvalidIDReturns400(t *testing.T) {
h := &UserHandler{userPlatformQuotaRepo: &fakeQuotaRepoForAdmin{}}
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c := newAdminQuotaTestContext(w)
c.Params = []gin.Param{{Key: "id", Value: "abc"}}
h.GetUserPlatformQuotas(c)
if w.Code < 400 || w.Code >= 500 {
t.Errorf("invalid id should yield 4xx, got %d", w.Code)
}
}
func TestAdminGetUserPlatformQuotas_EmptyReturnsEmptyArray(t *testing.T) {
repo := &fakeQuotaRepoForAdmin{records: nil}
h := &UserHandler{userPlatformQuotaRepo: repo, adminService: newStubAdminService()}
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c := newAdminQuotaTestContext(w)
c.Params = []gin.Param{{Key: "id", Value: "99"}}
h.GetUserPlatformQuotas(c)
if w.Code != 200 {
t.Errorf("empty list should be 200, got %d", w.Code)
}
var body map[string]any
if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil {
t.Fatalf("response is not valid JSON: %v", err)
}
data, ok := body["data"].(map[string]any)
if !ok {
t.Fatalf("response missing data object: %v", body)
}
quotas, ok := data["platform_quotas"].([]any)
if !ok {
t.Fatalf("data.platform_quotas missing or wrong type: %v", data)
}
if len(quotas) != 0 {
t.Errorf("expected empty platform_quotas, got %d entries: %v", len(quotas), quotas)
}
}
func TestAdminGetUserPlatformQuotas_NilRepoReturnsEmpty(t *testing.T) {
h := &UserHandler{userPlatformQuotaRepo: nil}
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c := newAdminQuotaTestContext(w)
c.Params = []gin.Param{{Key: "id", Value: "1"}}
h.GetUserPlatformQuotas(c)
if w.Code != 200 {
t.Errorf("nil repo should return 200 empty, got %d", w.Code)
}
}
// TestAdminGetUserPlatformQuotas_UserNotFoundReturns404 验证 GET 在用户不存在时返回 404
// (与 PUT / POST reset 端点行为一致review fix原实现返回空数组会让 admin 界面误判用户存在)
func TestAdminGetUserPlatformQuotas_UserNotFoundReturns404(t *testing.T) {
adminSvc := newStubAdminService()
adminSvc.getUserErr = infraerrors.NotFound("USER_NOT_FOUND", "user not found")
repo := &fakeQuotaRepoForAdmin{records: nil}
h := &UserHandler{userPlatformQuotaRepo: repo, adminService: adminSvc}
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c := newAdminQuotaTestContext(w)
c.Params = []gin.Param{{Key: "id", Value: "999"}}
h.GetUserPlatformQuotas(c)
if w.Code != http.StatusNotFound {
t.Errorf("expected 404 for non-existent user, got %d: %s", w.Code, w.Body.String())
}
}

View File

@ -2233,6 +2233,7 @@ CREATE TABLE IF NOT EXISTS user_affiliates (
nil,
options.defaultSubAssigner,
affiliateService,
nil,
)
userSvc := service.NewUserService(userRepo, nil, nil, nil)
var totpSvc *service.TotpService

View File

@ -35,7 +35,7 @@ func TestAuthHandlerRevokeAllSessionsInvalidatesAccessTokens(t *testing.T) {
ExpireHour: 1,
},
}
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil)
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil, nil)
handler := &AuthHandler{authService: authService}
recorder := httptest.NewRecorder()

View File

@ -1400,6 +1400,7 @@ func newWeChatOAuthTestHandlerWithSettings(t *testing.T, invitationEnabled bool,
nil,
nil,
nil,
nil,
)
return &AuthHandler{

View File

@ -3,6 +3,8 @@ package dto
import (
"encoding/json"
"strings"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// CustomMenuItem represents a user-configured custom menu entry.
@ -246,6 +248,9 @@ type SystemSettings struct {
// OpenAI fast/flex policy
OpenAIFastPolicySettings *OpenAIFastPolicySettings `json:"openai_fast_policy_settings,omitempty"`
// 系统全局默认平台配额key = platformnil/缺省 = 不限制)
DefaultPlatformQuotas map[string]*service.DefaultPlatformQuotaSetting `json:"default_platform_quotas,omitempty"`
}
type DefaultSubscriptionSetting struct {

View File

@ -6,6 +6,7 @@ import (
"encoding/json"
"errors"
"fmt"
"math"
"net/http"
"strconv"
"strings"
@ -250,7 +251,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
// 2. 【新增】Wait后二次检查余额/订阅
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil {
reqLog.Info("gateway.billing_eligibility_check_failed", zap.Error(err))
status, code, message, retryAfter := billingErrorDetails(err)
if retryAfter > 0 {
@ -508,10 +509,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
quotaPlatform := service.QuotaPlatform(c.Request.Context(), apiKey)
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
ParsedRequest: parsedReq,
QuotaPlatform: quotaPlatform,
APIKey: apiKey,
User: apiKey.User,
Account: account,
@ -808,7 +811,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
fallbackAPIKey := cloneAPIKeyWithGroup(apiKey, fallbackGroup)
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), fallbackAPIKey.User, fallbackAPIKey, fallbackGroup, nil); err != nil {
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), fallbackAPIKey.User, fallbackAPIKey, fallbackGroup, nil, service.PlatformFromAPIKey(fallbackAPIKey)); err != nil {
status, code, message, retryAfter := billingErrorDetails(err)
if retryAfter > 0 {
c.Header("Retry-After", strconv.Itoa(retryAfter))
@ -900,10 +903,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
quotaPlatform := service.QuotaPlatform(c.Request.Context(), currentAPIKey)
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
ParsedRequest: parsedReq,
QuotaPlatform: quotaPlatform,
APIKey: currentAPIKey,
User: currentAPIKey.User,
Account: account,
@ -1582,7 +1587,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
// 校验 billing eligibility订阅/余额)
// 【注意】不计算并发,但需要校验订阅/余额
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil {
status, code, message, retryAfter := billingErrorDetails(err)
if retryAfter > 0 {
c.Header("Retry-After", strconv.Itoa(retryAfter))
@ -1830,6 +1835,36 @@ func sendMockInterceptResponse(c *gin.Context, model string, interceptType Inter
c.JSON(http.StatusOK, response)
}
// extractQuotaResetSeconds 从 quota 错误的 metadata 中提取 window_resets_at 并计算
// 距重置剩余秒数。fallback 路径必须返回 ≥1 秒,避免客户端立即重试无限循环。
func extractQuotaResetSeconds(err error) int {
const fallback = 60
appErr := pkgerrors.FromError(err)
if appErr == nil {
return fallback
}
raw, ok := appErr.Metadata["window_resets_at"]
if !ok || raw == "" {
return fallback
}
resetAt, parseErr := time.Parse(time.RFC3339, raw)
if parseErr != nil {
logger.L().With(
zap.String("component", "handler.gateway.billing"),
zap.String("raw", raw),
zap.Error(parseErr),
).Warn("quota.invalid_window_resets_at_format")
return fallback
}
secs := time.Until(resetAt).Seconds()
if secs <= 0 {
// reset 时间已过cache 与 DB 应该正在自愈,返回 fallback 让客户端按常规节奏退避,
// 避免返回 1 秒导致客户端立即重试仍触发限额的退避循环。
return fallback
}
return int(math.Ceil(secs))
}
func billingErrorDetails(err error) (status int, code, message string, retryAfter int) {
if errors.Is(err, service.ErrBillingServiceUnavailable) {
msg := pkgerrors.Message(err)
@ -1857,6 +1892,14 @@ func billingErrorDetails(err error) (status int, code, message string, retryAfte
retrySeconds := 60 - int(time.Now().Unix()%60)
return http.StatusTooManyRequests, "rate_limit_exceeded", msg, retrySeconds
}
if errors.Is(err, service.ErrUserPlatformDailyQuotaExhausted) ||
errors.Is(err, service.ErrUserPlatformWeeklyQuotaExhausted) ||
errors.Is(err, service.ErrUserPlatformMonthlyQuotaExhausted) {
// 与 RPM 超限一致映射 429 + Retry-After让 SDK 自动退避(而非 403 直接失败)。
// 错误码用 rate_limit_exceeded 与 OpenAI 兼容客户端一致;细分类型由 ErrCode + window_resets_at metadata 区分。
msg := pkgerrors.Message(err)
return http.StatusTooManyRequests, "rate_limit_exceeded", msg, extractQuotaResetSeconds(err)
}
msg := pkgerrors.Message(err)
if msg == "" {
logger.L().With(

View File

@ -1,8 +1,10 @@
package handler
import (
"errors"
"net/http"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
@ -52,3 +54,75 @@ func TestBillingErrorDetails_UnknownErrorFallsBackTo403(t *testing.T) {
require.Equal(t, "billing_error", code)
require.NotEmpty(t, msg)
}
func TestExtractQuotaResetSeconds_T19_HappyPath(t *testing.T) {
err := service.ErrUserPlatformDailyQuotaExhausted.WithMetadata(map[string]string{
"window_resets_at": time.Now().Add(10 * time.Second).UTC().Format(time.RFC3339),
})
got := extractQuotaResetSeconds(err)
if got < 10 || got > 11 {
t.Errorf("T19: got %d, want 10 or 11 (math.Ceil boundary)", got)
}
}
func TestExtractQuotaResetSeconds_T20_NoMetadataFallback(t *testing.T) {
if got := extractQuotaResetSeconds(errors.New("naked error")); got != 60 {
t.Errorf("T20: got %d, want 60 fallback", got)
}
}
func TestExtractQuotaResetSeconds_T21_BadFormatFallback(t *testing.T) {
err := service.ErrUserPlatformDailyQuotaExhausted.WithMetadata(map[string]string{
"window_resets_at": "not-a-time",
})
if got := extractQuotaResetSeconds(err); got != 60 {
t.Errorf("T21: got %d, want 60 fallback", got)
}
}
func TestExtractQuotaResetSeconds_T22_PastResetFallsBackToDefault(t *testing.T) {
// 当 window_resets_at 已过去时返回 fallback (60s) 而非 1s
// 1 秒会导致客户端立即重试仍触发限额的退避循环;
// 60s 让客户端按常规节奏退避cache/DB 自愈期间不会反复打抖。
err := service.ErrUserPlatformDailyQuotaExhausted.WithMetadata(map[string]string{
"window_resets_at": time.Now().Add(-5 * time.Second).UTC().Format(time.RFC3339),
})
if got := extractQuotaResetSeconds(err); got != 60 {
t.Errorf("T22: got %d, want 60 (fallback on past reset)", got)
}
}
func TestBillingErrorDetails_T10_QuotaExhaustedReturns429WithRetryAfter(t *testing.T) {
// quota 超限映射 429 + Retry-AfterRFC 6585 / 与 RPM 一致),
// 让 SDKOpenAI 兼容客户端等)能按 Retry-After 自动退避。
// 旧实现用 403 导致客户端不退避直接报错。
// 三个窗口共用同一映射分支,循环覆盖避免漏测某个窗口的 status/code。
cases := []struct {
name string
err error
}{
{"daily", service.ErrUserPlatformDailyQuotaExhausted.WithMetadata(map[string]string{
"window_resets_at": time.Now().Add(60 * time.Minute).UTC().Format(time.RFC3339),
})},
{"weekly", service.ErrUserPlatformWeeklyQuotaExhausted.WithMetadata(map[string]string{
"window_resets_at": time.Now().Add(60 * time.Minute).UTC().Format(time.RFC3339),
})},
{"monthly", service.ErrUserPlatformMonthlyQuotaExhausted.WithMetadata(map[string]string{
"window_resets_at": time.Now().Add(60 * time.Minute).UTC().Format(time.RFC3339),
})},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
status, code, _, retryAfter := billingErrorDetails(tc.err)
if status != http.StatusTooManyRequests {
t.Errorf("status = %d, want 429", status)
}
if code != "rate_limit_exceeded" {
t.Errorf("code = %q, want rate_limit_exceeded", code)
}
if retryAfter < 3599 || retryAfter > 3601 {
t.Errorf("retryAfter = %d, want ~3600", retryAfter)
}
})
}
}

View File

@ -140,7 +140,7 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
}
// 2. Re-check billing
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil {
reqLog.Info("gateway.cc.billing_check_failed", zap.Error(err))
status, code, message, retryAfter := billingErrorDetails(err)
if retryAfter > 0 {
@ -291,9 +291,11 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
quotaPlatform := service.QuotaPlatform(c.Request.Context(), apiKey)
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
QuotaPlatform: quotaPlatform,
APIKey: apiKey,
User: apiKey.User,
Account: account,

View File

@ -145,7 +145,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
}
// 2. Re-check billing
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil {
reqLog.Info("gateway.responses.billing_check_failed", zap.Error(err))
status, code, message, retryAfter := billingErrorDetails(err)
if retryAfter > 0 {
@ -266,9 +266,11 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
quotaPlatform := service.QuotaPlatform(c.Request.Context(), apiKey)
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
QuotaPlatform: quotaPlatform,
APIKey: apiKey,
User: apiKey.User,
Account: account,

View File

@ -172,11 +172,12 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
nil, // channelService
nil, // resolver
nil, // balanceNotifyService
nil, // userPlatformQuotaRepo
)
// RunModeSimple跳过计费检查避免引入 repo/cache 依赖。
cfg := &config.Config{RunMode: config.RunModeSimple}
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg)
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg, nil)
concurrencySvc := service.NewConcurrencyService(&fakeConcurrencyCache{})
concurrencyHelper := NewConcurrencyHelper(concurrencySvc, SSEPingFormatClaude, 0)

View File

@ -43,7 +43,7 @@ func newGatewayModelsHandlerForTest(repo service.AccountRepository) *GatewayHand
gatewayService: service.NewGatewayService(
repo,
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
),
}
}

View File

@ -247,7 +247,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
}
// 2) billing eligibility check (after wait)
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil {
reqLog.Info("gemini.billing_eligibility_check_failed", zap.Error(err))
status, _, message, retryAfter := billingErrorDetails(err)
if retryAfter > 0 {
@ -527,9 +527,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
requestPayloadHash := service.HashUsageRequestPayload(body)
inboundEndpoint := GetInboundEndpoint(c)
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
quotaPlatform := service.QuotaPlatform(c.Request.Context(), apiKey)
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{
Result: result,
QuotaPlatform: quotaPlatform,
APIKey: apiKey,
User: apiKey.User,
Account: account,

View File

@ -206,7 +206,7 @@ func TestOpenAIGatewayHandlerResponses_TextOnlyNotRejectedByImageConcurrency(t *
h := &OpenAIGatewayHandler{
gatewayService: &service.OpenAIGatewayService{},
billingCacheService: service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, &config.Config{RunMode: config.RunModeSimple}),
billingCacheService: service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, &config.Config{RunMode: config.RunModeSimple}, nil),
apiKeyService: &service.APIKeyService{},
concurrencyHelper: &ConcurrencyHelper{concurrencyService: service.NewConcurrencyService(&helperConcurrencyCacheStub{userSeq: []bool{true}})},
cfg: &config.Config{Gateway: config.GatewayConfig{ImageConcurrency: config.ImageConcurrencyConfig{

View File

@ -106,7 +106,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
defer userReleaseFunc()
}
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil {
reqLog.Info("openai_chat_completions.billing_eligibility_check_failed", zap.Error(err))
status, code, message, retryAfter := billingErrorDetails(err)
if retryAfter > 0 {

View File

@ -243,7 +243,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
}
// 2. Re-check billing eligibility after wait
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil {
reqLog.Info("openai.billing_eligibility_check_failed", zap.Error(err))
status, code, message, retryAfter := billingErrorDetails(err)
if retryAfter > 0 {
@ -648,7 +648,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
defer userReleaseFunc()
}
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil {
reqLog.Info("openai_messages.billing_eligibility_check_failed", zap.Error(err))
status, code, message, retryAfter := billingErrorDetails(err)
if retryAfter > 0 {
@ -1235,7 +1235,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc)
subscription, _ := middleware2.GetSubscriptionFromContext(c)
if err := h.billingCacheService.CheckBillingEligibility(ctx, apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
if err := h.billingCacheService.CheckBillingEligibility(ctx, apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil {
reqLog.Info("openai.websocket_billing_eligibility_check_failed", zap.Error(err))
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "billing check failed")
return

View File

@ -1201,7 +1201,7 @@ func runOpenAIResponsesWebSocketUsageLogCase(t *testing.T, tc openAIResponsesWSU
}, nil, nil, nil)
}
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg)
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg, nil)
gatewaySvc := service.NewOpenAIGatewayService(
accountRepo,
usageRepo,
@ -1223,6 +1223,7 @@ func runOpenAIResponsesWebSocketUsageLogCase(t *testing.T, tc openAIResponsesWSU
channelSvc,
nil,
nil,
nil, // userPlatformQuotaRepo
)
cache := &concurrencyCacheMock{

View File

@ -123,7 +123,7 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
defer userReleaseFunc()
}
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil {
reqLog.Info("openai.images.billing_eligibility_check_failed", zap.Error(err))
status, code, message, retryAfter := billingErrorDetails(err)
if retryAfter > 0 {

View File

@ -0,0 +1,104 @@
// Package quotaview provides shared quota response helpers for user and admin handlers.
// Extracted to avoid import cycles between handler and handler/admin packages.
package quotaview
import (
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// LazyZeroQuotaForResponse 按 D14 规则把过期档位归零(不写 DB
// includeWindowStart=true 时输出 *_window_start 字段admin 视角调试用)
func LazyZeroQuotaForResponse(r service.UserPlatformQuotaRecord, now time.Time, includeWindowStart bool) map[string]any {
daily := buildWindowSlice(r.DailyUsageUSD, r.DailyLimitUSD, r.DailyWindowStart, NeedsDailyReset(r.DailyWindowStart, now), nextDailyResetTime(now), includeWindowStart)
weekly := buildWindowSlice(r.WeeklyUsageUSD, r.WeeklyLimitUSD, r.WeeklyWindowStart, NeedsWeeklyReset(r.WeeklyWindowStart, now), nextWeeklyResetTime(now), includeWindowStart)
monthly := buildWindowSlice(r.MonthlyUsageUSD, r.MonthlyLimitUSD, r.MonthlyWindowStart, NeedsMonthlyReset(r.MonthlyWindowStart, now), NextMonthlyResetTimeFrom(r.MonthlyWindowStart, now), includeWindowStart)
out := map[string]any{
"platform": r.Platform,
"daily_usage_usd": daily.usage,
"daily_limit_usd": daily.limit,
"daily_window_resets_at": daily.resetsAt,
"weekly_usage_usd": weekly.usage,
"weekly_limit_usd": weekly.limit,
"weekly_window_resets_at": weekly.resetsAt,
"monthly_usage_usd": monthly.usage,
"monthly_limit_usd": monthly.limit,
"monthly_window_resets_at": monthly.resetsAt,
}
if includeWindowStart {
out["daily_window_start"] = daily.windowStart
out["weekly_window_start"] = weekly.windowStart
out["monthly_window_start"] = monthly.windowStart
}
return out
}
type windowSlice struct {
usage float64
limit *float64
resetsAt *string
windowStart *string
}
func buildWindowSlice(usage float64, limit *float64, start *time.Time, expired bool, nextReset time.Time, includeStart bool) windowSlice {
out := windowSlice{usage: usage, limit: limit}
if expired {
out.usage = 0
out.resetsAt = nil
} else if start != nil {
s := nextReset.Format(time.RFC3339)
out.resetsAt = &s
}
if includeStart && start != nil {
s := start.Format(time.RFC3339)
out.windowStart = &s
}
return out
}
// NeedsDailyReset 判断日窗口是否已过期start 早于「全局时区当天 0 点」即过期。
// 时区跟随 timezone.Location()(全局服务器时区),与 billing / repo 写入的 window_start 同口径。
func NeedsDailyReset(start *time.Time, now time.Time) bool {
if start == nil {
return false
}
return start.Before(timezone.StartOfDay(now))
}
func NeedsWeeklyReset(start *time.Time, now time.Time) bool {
if start == nil {
return false
}
return start.Before(timezone.StartOfWeek(now))
}
// NeedsMonthlyReset 30 天滚动窗口语义(与订阅模式 NeedsMonthlyReset 一致)。
func NeedsMonthlyReset(start *time.Time, now time.Time) bool {
if start == nil {
return false
}
return now.Sub(*start) >= 30*24*time.Hour
}
func nextDailyResetTime(now time.Time) time.Time {
return timezone.StartOfDay(now).AddDate(0, 0, 1)
}
func nextWeeklyResetTime(now time.Time) time.Time {
return timezone.StartOfWeek(now).AddDate(0, 0, 7)
}
// NextMonthlyResetTimeFrom 计算 30 天滚动月度窗口的下次重置时间。
// 语义:
// - start != nil → 返回 start + 30d与 billing_cache_service.nextMonthlyResetFrom 一致)
// - start == nil → 退化为 now + 30d保留旧行为避免 nil 崩溃)
//
// 导出(首字母大写)以允许测试直接调用。
func NextMonthlyResetTimeFrom(start *time.Time, now time.Time) time.Time {
if start == nil {
return now.Add(30 * 24 * time.Hour)
}
return start.Add(30 * 24 * time.Hour)
}

View File

@ -0,0 +1,133 @@
package quotaview
import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// TestNextMonthlyResetTimeFrom_FromStart 验证start 已知时返回 start+30d不随 now 漂移。
func TestNextMonthlyResetTimeFrom_FromStart(t *testing.T) {
t0 := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
now := t0.Add(15 * 24 * time.Hour) // t0 + 15d
want := t0.Add(30 * 24 * time.Hour) // t0 + 30d
got := NextMonthlyResetTimeFrom(&t0, now)
if !got.Equal(want) {
t.Errorf("NextMonthlyResetTimeFrom: want %v, got %v", want, got)
}
}
// TestNextMonthlyResetTimeFrom_NilStart 验证start=nil 时退化为 now+30d不 panic
func TestNextMonthlyResetTimeFrom_NilStart(t *testing.T) {
now := time.Date(2024, 3, 15, 12, 0, 0, 0, time.UTC)
want := now.Add(30 * 24 * time.Hour)
got := NextMonthlyResetTimeFrom(nil, now)
if !got.Equal(want) {
t.Errorf("NextMonthlyResetTimeFrom(nil): want %v, got %v", want, got)
}
}
// TestLazyZeroQuotaForResponse_MonthlyResetsAt_NotDrifting 验证:
// 连续两次以不同 now 调用、但 MonthlyWindowStart 相同的 record
// monthly_window_resets_at 始终等于 windowStart+30d不随 now 漂移。
func TestLazyZeroQuotaForResponse_MonthlyResetsAt_NotDrifting(t *testing.T) {
windowStart := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
wantResetsAt := windowStart.Add(30 * 24 * time.Hour).Format(time.RFC3339)
r := service.UserPlatformQuotaRecord{
Platform: "openai",
MonthlyUsageUSD: 5.0,
MonthlyWindowStart: &windowStart,
}
// 第一次调用now = windowStart + 5d
now1 := windowStart.Add(5 * 24 * time.Hour)
out1 := LazyZeroQuotaForResponse(r, now1, false)
resetsAt1, ok1 := out1["monthly_window_resets_at"]
if !ok1 || resetsAt1 == nil {
t.Fatal("first call: monthly_window_resets_at should be set for active window")
}
s1, ok := resetsAt1.(*string)
if !ok || s1 == nil {
t.Fatalf("first call: monthly_window_resets_at should be *string, got %T", resetsAt1)
}
if *s1 != wantResetsAt {
t.Errorf("first call: want %s, got %s", wantResetsAt, *s1)
}
// 第二次调用now = windowStart + 10d不同 now但 resetsAt 应不变)
now2 := windowStart.Add(10 * 24 * time.Hour)
out2 := LazyZeroQuotaForResponse(r, now2, false)
resetsAt2, ok2 := out2["monthly_window_resets_at"]
if !ok2 || resetsAt2 == nil {
t.Fatal("second call: monthly_window_resets_at should be set for active window")
}
s2, ok := resetsAt2.(*string)
if !ok || s2 == nil {
t.Fatalf("second call: monthly_window_resets_at should be *string, got %T", resetsAt2)
}
if *s2 != wantResetsAt {
t.Errorf("second call: want %s, got %s", wantResetsAt, *s2)
}
// 两次结果必须相等
if *s1 != *s2 {
t.Errorf("resetsAt drifted between calls: %s vs %s", *s1, *s2)
}
}
// TestNeedsDailyReset_FollowsServerTimezone 验证日窗口过期判断按全局时区(北京 0 点)而非 UTC。
func TestNeedsDailyReset_FollowsServerTimezone(t *testing.T) {
if err := timezone.Init("Asia/Shanghai"); err != nil {
t.Fatalf("Init: %v", err)
}
t.Cleanup(func() { _ = timezone.Init("UTC") })
// now = 2026-05-25 23:00 UTC = 2026-05-26 07:00 +08北京 5/26
now := time.Date(2026, 5, 25, 23, 0, 0, 0, time.UTC)
// start = 2026-05-25 10:00 UTC = 2026-05-25 18:00 +08北京 5/25→ 应判定为过期
startPrevBeijingDay := time.Date(2026, 5, 25, 10, 0, 0, 0, time.UTC)
if !NeedsDailyReset(&startPrevBeijingDay, now) {
t.Error("上一个北京日的窗口应判定为过期")
}
// start = 2026-05-25 20:00 UTC = 2026-05-26 04:00 +08北京 5/26 同日)→ 不应过期
startSameBeijingDay := time.Date(2026, 5, 25, 20, 0, 0, 0, time.UTC)
if NeedsDailyReset(&startSameBeijingDay, now) {
t.Error("同一北京日的窗口不应判定为过期")
}
}
// TestNextDailyResetTime_FollowsServerTimezone 验证下次日重置 = 次日北京 0 点。
func TestNextDailyResetTime_FollowsServerTimezone(t *testing.T) {
if err := timezone.Init("Asia/Shanghai"); err != nil {
t.Fatalf("Init: %v", err)
}
t.Cleanup(func() { _ = timezone.Init("UTC") })
now := time.Date(2026, 5, 25, 23, 0, 0, 0, time.UTC) // 北京 5/26 07:00
want := time.Date(2026, 5, 27, 0, 0, 0, 0, timezone.Location()) // 北京 5/27 00:00
if got := nextDailyResetTime(now); !got.Equal(want) {
t.Errorf("nextDailyResetTime = %v, want %v", got, want)
}
}
// TestNextWeeklyResetTime_FollowsServerTimezone 验证下次周重置 = 下周一北京 0 点。
func TestNextWeeklyResetTime_FollowsServerTimezone(t *testing.T) {
if err := timezone.Init("Asia/Shanghai"); err != nil {
t.Fatalf("Init: %v", err)
}
t.Cleanup(func() { _ = timezone.Init("UTC") })
// 北京 2026-05-26周二→ 下周一是 2026-06-01
now := time.Date(2026, 5, 25, 23, 0, 0, 0, time.UTC) // 北京 5/26 07:00 周二
want := time.Date(2026, 6, 1, 0, 0, 0, 0, timezone.Location())
if got := nextWeeklyResetTime(now); !got.Equal(want) {
t.Errorf("nextWeeklyResetTime = %v, want %v", got, want)
}
}

View File

@ -3,8 +3,10 @@ package handler
import (
"context"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/handler/quotaview"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
@ -14,11 +16,12 @@ import (
// UserHandler handles user-related requests
type UserHandler struct {
userService *service.UserService
authService *service.AuthService
emailService *service.EmailService
emailCache service.EmailCache
affiliateService *service.AffiliateService
userService *service.UserService
authService *service.AuthService
emailService *service.EmailService
emailCache service.EmailCache
affiliateService *service.AffiliateService
userPlatformQuotaRepo service.UserPlatformQuotaRepository
}
// NewUserHandler creates a new UserHandler
@ -28,16 +31,44 @@ func NewUserHandler(
emailService *service.EmailService,
emailCache service.EmailCache,
affiliateService *service.AffiliateService,
userPlatformQuotaRepo service.UserPlatformQuotaRepository,
) *UserHandler {
return &UserHandler{
userService: userService,
authService: authService,
emailService: emailService,
emailCache: emailCache,
affiliateService: affiliateService,
userService: userService,
authService: authService,
emailService: emailService,
emailCache: emailCache,
affiliateService: affiliateService,
userPlatformQuotaRepo: userPlatformQuotaRepo,
}
}
// GetMyPlatformQuotas GET /user/platform-quotas
// 返回当前 JWT 用户的 platform quota 状态。
// D14: 对每条记录逐档判断窗口过期,过期档位 usage=0、window_resets_at=null不写 DB
func (h *UserHandler) GetMyPlatformQuotas(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
if h.userPlatformQuotaRepo == nil {
response.Success(c, map[string]any{"platform_quotas": []any{}})
return
}
records, err := h.userPlatformQuotaRepo.ListByUser(c.Request.Context(), subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
now := time.Now().UTC()
out := make([]map[string]any, 0, len(records))
for _, r := range records {
out = append(out, quotaview.LazyZeroQuotaForResponse(r, now, false))
}
response.Success(c, map[string]any{"platform_quotas": out})
}
// ChangePasswordRequest represents the change password request payload
type ChangePasswordRequest struct {
OldPassword string `json:"old_password" binding:"required"`

View File

@ -87,8 +87,12 @@ func (s *userHandlerRepoStub) ListWithFilters(context.Context, pagination.Pagina
func (s *userHandlerRepoStub) UpdateBalance(context.Context, int64, float64) error { return nil }
func (s *userHandlerRepoStub) DeductBalance(context.Context, int64, float64) error { return nil }
func (s *userHandlerRepoStub) UpdateConcurrency(context.Context, int64, int) error { return nil }
func (s *userHandlerRepoStub) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
func (s *userHandlerRepoStub) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
func (s *userHandlerRepoStub) BatchSetConcurrency(context.Context, []int64, int) (int, error) {
return 0, nil
}
func (s *userHandlerRepoStub) BatchAddConcurrency(context.Context, []int64, int) (int, error) {
return 0, nil
}
func (s *userHandlerRepoStub) ExistsByEmail(context.Context, string) (bool, error) { return false, nil }
func (s *userHandlerRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
return 0, nil
@ -144,7 +148,7 @@ func TestUserHandlerUpdateProfileReturnsAvatarURL(t *testing.T) {
Status: service.StatusActive,
},
}
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil, nil)
body := []byte(`{"avatar_url":"https://cdn.example.com/avatar.png"}`)
recorder := httptest.NewRecorder()
@ -202,7 +206,7 @@ func TestUserHandlerGetProfileReturnsIdentitySummaries(t *testing.T) {
},
},
}
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil, nil)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
@ -272,20 +276,20 @@ func TestUserHandlerGetProfileReturnsLegacyCompatibilityFields(t *testing.T) {
AvatarURL: "https://cdn.example.com/linuxdo.png",
AvatarSource: "remote_url",
},
identities: []service.UserAuthIdentityRecord{
{
ProviderType: "linuxdo",
ProviderKey: "linuxdo",
ProviderSubject: "linuxdo-subject-21",
VerifiedAt: &verifiedAt,
Metadata: map[string]any{
"username": "linuxdo-handle",
"avatar_url": "https://cdn.example.com/linuxdo.png",
},
identities: []service.UserAuthIdentityRecord{
{
ProviderType: "linuxdo",
ProviderKey: "linuxdo",
ProviderSubject: "linuxdo-subject-21",
VerifiedAt: &verifiedAt,
Metadata: map[string]any{
"username": "linuxdo-handle",
"avatar_url": "https://cdn.example.com/linuxdo.png",
},
},
}
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
},
}
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil, nil)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
@ -364,7 +368,7 @@ func TestUserHandlerGetProfileDoesNotInferEditedProfileSourcesWithoutMatchingIde
},
},
}
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil, nil)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
@ -513,8 +517,8 @@ func TestUserHandlerBindEmailIdentityReturnsProfileResponse(t *testing.T) {
},
}
emailService := service.NewEmailService(nil, emailCache)
authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil)
authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil, nil)
body := []byte(`{"email":"new@example.com","verify_code":"123456","password":"new-password"}`)
recorder := httptest.NewRecorder()
@ -568,7 +572,7 @@ func TestUserHandlerUnbindIdentityReturnsUpdatedProfile(t *testing.T) {
},
},
}
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil, nil)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
@ -627,8 +631,8 @@ func TestUserHandlerUnbindIdentityRevokesAllUserSessionsWhenAuthServiceConfigure
ExpireHour: 1,
},
}
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil)
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil, nil)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
@ -670,8 +674,8 @@ func TestUserHandlerUnbindIdentityDoesNotRevokeSessionsWhenNothingWasUnbound(t *
ExpireHour: 1,
},
}
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil)
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil, nil)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
@ -714,8 +718,8 @@ func TestUserHandlerBindEmailIdentityRejectsWrongCurrentPasswordForBoundEmail(t
},
}
emailService := service.NewEmailService(nil, emailCache)
authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil)
authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil, nil)
body := []byte(`{"email":"new@example.com","verify_code":"123456","password":"wrong-password"}`)
recorder := httptest.NewRecorder()
@ -752,7 +756,7 @@ func TestUserHandlerStartIdentityBindingReturnsAuthorizeURL(t *testing.T) {
Status: service.StatusActive,
},
}
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil, nil)
body := []byte(`{"provider":"wechat","redirect_to":"/settings/profile"}`)
recorder := httptest.NewRecorder()

View File

@ -0,0 +1,212 @@
//go:build unit
package handler
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/quotaview"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// fakeQuotaRepoForUserHandler 实现 service.UserPlatformQuotaRepository 最小子集
type fakeQuotaRepoForUserHandler struct {
service.UserPlatformQuotaRepository
records []service.UserPlatformQuotaRecord
}
func (f *fakeQuotaRepoForUserHandler) ListByUser(_ context.Context, _ int64) ([]service.UserPlatformQuotaRecord, error) {
return f.records, nil
}
func TestGetMyPlatformQuotas_EmptyReturns200WithEmptyArray(t *testing.T) {
repo := &fakeQuotaRepoForUserHandler{records: nil}
h := &UserHandler{userPlatformQuotaRepo: repo}
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/user/platform-quotas", nil)
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 42})
h.GetMyPlatformQuotas(c)
if w.Code != 200 {
t.Fatalf("expected 200, got %d. body: %s", w.Code, w.Body.String())
}
var body struct {
Code int `json:"code"`
Data struct {
PlatformQuotas []any `json:"platform_quotas"`
} `json:"data"`
}
if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil {
t.Fatalf("unmarshal error: %v, body: %s", err, w.Body.String())
}
if body.Code != 0 {
t.Errorf("expected code=0, got %d", body.Code)
}
if body.Data.PlatformQuotas == nil {
// nil 和 empty slice 均视为可接受JSON 可能序列化为 null 或 []
// 此断言只验证 HTTP 200 + code=0 即可
}
}
func TestGetMyPlatformQuotas_D14_LazyZeroForExpiredWindow(t *testing.T) {
pastStart := time.Now().UTC().AddDate(0, 0, -2)
daily := 5.0
repo := &fakeQuotaRepoForUserHandler{records: []service.UserPlatformQuotaRecord{{
UserID: 42,
Platform: "anthropic",
DailyLimitUSD: &daily,
DailyUsageUSD: 3.0,
DailyWindowStart: &pastStart,
}}}
h := &UserHandler{userPlatformQuotaRepo: repo}
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/user/platform-quotas", nil)
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 42})
h.GetMyPlatformQuotas(c)
if w.Code != 200 {
t.Fatalf("expected 200, got %d. body: %s", w.Code, w.Body.String())
}
// 解析 response验证过期 daily 的 usage_usd=0 且 window_resets_at=null
body := w.Body.String()
if !strings.Contains(body, `"daily_usage_usd":0`) {
t.Errorf("expected daily_usage_usd:0 in body, got: %s", body)
}
if !strings.Contains(body, `"daily_window_resets_at":null`) {
t.Errorf("expected daily_window_resets_at:null in body, got: %s", body)
}
}
func TestGetMyPlatformQuotas_NilRepo_Returns200Empty(t *testing.T) {
h := &UserHandler{userPlatformQuotaRepo: nil}
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/user/platform-quotas", nil)
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 99})
h.GetMyPlatformQuotas(c)
if w.Code != 200 {
t.Fatalf("expected 200, got %d", w.Code)
}
}
func TestGetMyPlatformQuotas_NoAuth_Returns401(t *testing.T) {
h := &UserHandler{userPlatformQuotaRepo: nil}
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/user/platform-quotas", nil)
// 不设置 auth subject
h.GetMyPlatformQuotas(c)
if w.Code != 401 {
t.Fatalf("expected 401, got %d", w.Code)
}
}
func TestLazyZeroQuotaForResponse_UserViewStripsWindowStart(t *testing.T) {
start := time.Now().UTC().Add(-1 * time.Hour)
r := service.UserPlatformQuotaRecord{
Platform: "anthropic",
DailyUsageUSD: 1.0,
DailyWindowStart: &start,
}
out := quotaview.LazyZeroQuotaForResponse(r, time.Now().UTC(), false)
if _, ok := out["daily_window_start"]; ok {
t.Error("user view should not include daily_window_start")
}
}
func TestLazyZeroQuotaForResponse_AdminViewIncludesWindowStart(t *testing.T) {
start := time.Now().UTC().Add(-1 * time.Hour)
r := service.UserPlatformQuotaRecord{
Platform: "anthropic",
DailyWindowStart: &start,
}
out := quotaview.LazyZeroQuotaForResponse(r, time.Now().UTC(), true)
if _, ok := out["daily_window_start"]; !ok {
t.Error("admin view should include daily_window_start")
}
}
func TestLazyZeroQuotaForResponse_ActiveWindowPreservesUsage(t *testing.T) {
// 今天的窗口起始时间(不过期):按全局时区取当天 0 点,与 view 层同口径
now := time.Now()
today := timezone.StartOfDay(now)
usage := 2.5
r := service.UserPlatformQuotaRecord{
Platform: "openai",
DailyUsageUSD: usage,
DailyWindowStart: &today,
}
out := quotaview.LazyZeroQuotaForResponse(r, now, false)
if out["daily_usage_usd"] != usage {
t.Errorf("expected daily_usage_usd=%v, got %v", usage, out["daily_usage_usd"])
}
// 活跃窗口应有 resets_at非 nil
if out["daily_window_resets_at"] == nil {
t.Error("active window should have daily_window_resets_at set")
}
}
func TestNeedsDailyReset_NilStart_ReturnsFalse(t *testing.T) {
if quotaview.NeedsDailyReset(nil, time.Now().UTC()) {
t.Error("nil start should not need reset")
}
}
func TestNeedsDailyReset_OldStart_ReturnsTrue(t *testing.T) {
old := time.Now().UTC().AddDate(0, 0, -1)
if !quotaview.NeedsDailyReset(&old, time.Now().UTC()) {
t.Error("yesterday start should need daily reset")
}
}
func TestNeedsWeeklyReset_NilStart_ReturnsFalse(t *testing.T) {
if quotaview.NeedsWeeklyReset(nil, time.Now().UTC()) {
t.Error("nil start should not need weekly reset")
}
}
func TestNeedsMonthlyReset_NilStart_ReturnsFalse(t *testing.T) {
if quotaview.NeedsMonthlyReset(nil, time.Now().UTC()) {
t.Error("nil start should not need monthly reset")
}
}
// TestNeedsMonthlyReset_30DayRolling 验证 30 天滚动语义C-NEW-1
func TestNeedsMonthlyReset_30DayRolling_Expired(t *testing.T) {
start := time.Now().UTC().Add(-31 * 24 * time.Hour) // 31 天前,已过期
if !quotaview.NeedsMonthlyReset(&start, time.Now().UTC()) {
t.Error("31 days ago should need monthly reset (30-day rolling)")
}
}
func TestNeedsMonthlyReset_30DayRolling_Active(t *testing.T) {
start := time.Now().UTC().Add(-15 * 24 * time.Hour) // 15 天前,窗口有效
if quotaview.NeedsMonthlyReset(&start, time.Now().UTC()) {
t.Error("15 days ago should NOT need monthly reset (30-day rolling, still active)")
}
}
// TestNeedsMonthlyReset_CrossMonthBoundary 验证跨自然月时 30 天未满不重置(旧自然月语义会提前重置)。
func TestNeedsMonthlyReset_CrossMonthBoundary(t *testing.T) {
// 窗口起始 4 月 20 日5 月 1 日仅过了 11 天,不足 30 天,不应重置
windowStart := time.Date(2026, 4, 20, 0, 0, 0, 0, time.UTC)
now := time.Date(2026, 5, 1, 0, 0, 0, 0, time.UTC)
if quotaview.NeedsMonthlyReset(&windowStart, now) {
t.Error("cross-month boundary within 30 days should NOT trigger reset (30-day rolling)")
}
}

View File

@ -135,3 +135,29 @@ func TestDSTAwareness(t *testing.T) {
_ = Now()
_ = StartOfDay(Now())
}
func TestStartOfWeek_Boundaries(t *testing.T) {
if err := Init("Asia/Shanghai"); err != nil {
t.Fatalf("Init: %v", err)
}
t.Cleanup(func() { _ = Init("UTC") })
loc := Location()
wantMon := time.Date(2026, 5, 18, 0, 0, 0, 0, loc) // 2026-05-18 是周一
cases := []struct {
name string
in time.Time
}{
{"friday", time.Date(2026, 5, 22, 14, 30, 0, 0, loc)},
{"sunday", time.Date(2026, 5, 24, 10, 0, 0, 0, loc)},
{"monday-self", time.Date(2026, 5, 18, 9, 15, 30, 0, loc)},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
if got := StartOfWeek(c.in); !got.Equal(wantMon) {
t.Errorf("StartOfWeek(%v) = %v, want %v", c.in, got, wantMon)
}
})
}
}

View File

@ -328,3 +328,174 @@ func (c *billingCache) InvalidateAPIKeyRateLimit(ctx context.Context, keyID int6
key := billingRateLimitKey(keyID)
return c.rdb.Del(ctx, key).Err()
}
// ============================================
// user × platform quota 缓存
// ============================================
// userPlatformQuotaCacheKey 构造 Redis key
func userPlatformQuotaCacheKey(userID int64, platform string) string {
return fmt.Sprintf("billing:user_platform_quota:%d:%s", userID, platform)
}
func (c *billingCache) GetUserPlatformQuotaCache(ctx context.Context, userID int64, platform string) (*service.UserPlatformQuotaCacheEntry, bool, error) {
key := userPlatformQuotaCacheKey(userID, platform)
fields := []string{
"daily_usage", "weekly_usage", "monthly_usage", "version", "schema_version",
"daily_limit", "weekly_limit", "monthly_limit",
"daily_window_start", "weekly_window_start", "monthly_window_start",
}
vals, err := c.rdb.HMGet(ctx, key, fields...).Result()
if err != nil {
return nil, false, err
}
// 前4个全为nil → key 不存在
if vals[0] == nil && vals[1] == nil && vals[2] == nil && vals[3] == nil {
return nil, false, nil
}
parseFloat := func(v any) float64 {
if v == nil {
return 0
}
s, ok := v.(string)
if !ok {
return 0
}
f, _ := strconv.ParseFloat(s, 64)
return f
}
parseFloatPtr := func(v any) *float64 {
if v == nil {
return nil
}
s, ok := v.(string)
if !ok || s == "" {
return nil
}
f, err := strconv.ParseFloat(s, 64)
if err != nil {
return nil
}
return &f
}
parseTimePtr := func(v any) *time.Time {
if v == nil {
return nil
}
s, ok := v.(string)
if !ok || s == "" {
return nil
}
n, err := strconv.ParseInt(s, 10, 64)
if err != nil {
return nil
}
t := time.Unix(n, 0).UTC()
return &t
}
parseInt64 := func(v any) int64 {
if v == nil {
return 0
}
s, ok := v.(string)
if !ok {
return 0
}
n, _ := strconv.ParseInt(s, 10, 64)
return n
}
return &service.UserPlatformQuotaCacheEntry{
DailyUsageUSD: parseFloat(vals[0]),
WeeklyUsageUSD: parseFloat(vals[1]),
MonthlyUsageUSD: parseFloat(vals[2]),
Version: parseInt64(vals[3]),
SchemaVersion: parseInt64(vals[4]),
DailyLimitUSD: parseFloatPtr(vals[5]),
WeeklyLimitUSD: parseFloatPtr(vals[6]),
MonthlyLimitUSD: parseFloatPtr(vals[7]),
DailyWindowStart: parseTimePtr(vals[8]),
WeeklyWindowStart: parseTimePtr(vals[9]),
MonthlyWindowStart: parseTimePtr(vals[10]),
}, true, nil
}
func (c *billingCache) SetUserPlatformQuotaCache(ctx context.Context, userID int64, platform string, entry *service.UserPlatformQuotaCacheEntry, ttl time.Duration) error {
if entry == nil {
return nil
}
key := userPlatformQuotaCacheKey(userID, platform)
pipe := c.rdb.TxPipeline()
// 浮点可空字段nil → 空字符串(读取时 parseFloatPtr 返回 nil表示无限额
fmtFloatPtr := func(p *float64) string {
if p == nil {
return ""
}
return strconv.FormatFloat(*p, 'f', -1, 64)
}
// time.Time 可空字段nil → 空字符串;有值 → unix 秒
fmtTimePtr := func(p *time.Time) string {
if p == nil {
return ""
}
return strconv.FormatInt(p.Unix(), 10)
}
pipe.HSet(ctx, key,
"daily_usage", entry.DailyUsageUSD,
"weekly_usage", entry.WeeklyUsageUSD,
"monthly_usage", entry.MonthlyUsageUSD,
"version", entry.Version,
"schema_version", entry.SchemaVersion,
"daily_limit", fmtFloatPtr(entry.DailyLimitUSD),
"weekly_limit", fmtFloatPtr(entry.WeeklyLimitUSD),
"monthly_limit", fmtFloatPtr(entry.MonthlyLimitUSD),
"daily_window_start", fmtTimePtr(entry.DailyWindowStart),
"weekly_window_start", fmtTimePtr(entry.WeeklyWindowStart),
"monthly_window_start", fmtTimePtr(entry.MonthlyWindowStart),
)
pipe.Expire(ctx, key, ttl)
_, err := pipe.Exec(ctx)
return err
}
func (c *billingCache) DeleteUserPlatformQuotaCache(ctx context.Context, userID int64, platform string) error {
return c.rdb.Del(ctx, userPlatformQuotaCacheKey(userID, platform)).Err()
}
// updateUserPlatformQuotaUsageScript 缓存累加EXISTS + schema_version 双重守卫。
// 旧版 entryschema_version != ARGV[3],包括缺字段的 0 值)不参与累加,由上层走 DB fallback 后
// SetCache 重建为新版 entry —— 若此处仍累加,上层覆盖时会丢失这部分增量,导致 Redis usage 比真实偏小。
// key 不存在同样跳过(由下次 SetCache 重建)。
// KEYS[1] = hash key
// ARGV[1] = cost (string float)
// ARGV[2] = ttl seconds
// ARGV[3] = expected schema_version (Go 侧 UserPlatformQuotaCacheSchemaV1)
const updateUserPlatformQuotaUsageScript = `
if redis.call("EXISTS", KEYS[1]) == 0 then
return 0
end
local ver = redis.call("HGET", KEYS[1], "schema_version")
if ver == false or tonumber(ver) ~= tonumber(ARGV[3]) then
return 0
end
redis.call("HINCRBYFLOAT", KEYS[1], "daily_usage", ARGV[1])
redis.call("HINCRBYFLOAT", KEYS[1], "weekly_usage", ARGV[1])
redis.call("HINCRBYFLOAT", KEYS[1], "monthly_usage", ARGV[1])
redis.call("HINCRBY", KEYS[1], "version", 1)
redis.call("EXPIRE", KEYS[1], ARGV[2])
return 1
`
func (c *billingCache) IncrUserPlatformQuotaUsageCache(ctx context.Context, userID int64, platform string, cost float64, ttl time.Duration) error {
key := userPlatformQuotaCacheKey(userID, platform)
_, err := c.rdb.Eval(ctx, updateUserPlatformQuotaUsageScript, []string{key},
strconv.FormatFloat(cost, 'f', -1, 64),
int(ttl.Seconds()),
service.UserPlatformQuotaCacheSchemaV1,
).Result()
if err != nil && !errors.Is(err, redis.Nil) {
return err
}
return nil
}

View File

@ -0,0 +1,134 @@
//go:build unit
package repository
import (
"context"
"testing"
"time"
"github.com/alicebob/miniredis/v2"
"github.com/redis/go-redis/v9"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func newMiniRedisCache(t *testing.T) (*billingCache, *miniredis.Miniredis) {
t.Helper()
mr := miniredis.RunT(t)
rdb := redis.NewClient(&redis.Options{Addr: mr.Addr()})
return &billingCache{rdb: rdb}, mr
}
func TestUserPlatformQuotaCache_GetMissReturnsNotFound(t *testing.T) {
c, _ := newMiniRedisCache(t)
entry, ok, err := c.GetUserPlatformQuotaCache(context.Background(), 1, "anthropic")
if err != nil {
t.Fatal(err)
}
if ok || entry != nil {
t.Errorf("expected miss, got ok=%v entry=%v", ok, entry)
}
}
func TestUserPlatformQuotaCache_SetThenGet(t *testing.T) {
c, _ := newMiniRedisCache(t)
ctx := context.Background()
dailyLimit := 20.0
ts := time.Date(2024, 5, 1, 0, 0, 0, 0, time.UTC)
in := &service.UserPlatformQuotaCacheEntry{
DailyUsageUSD: 1.5,
WeeklyUsageUSD: 3.0,
MonthlyUsageUSD: 10.0,
Version: 7,
SchemaVersion: service.UserPlatformQuotaCacheSchemaV1,
DailyLimitUSD: &dailyLimit,
DailyWindowStart: &ts,
}
if err := c.SetUserPlatformQuotaCache(ctx, 1, "openai", in, time.Minute); err != nil {
t.Fatal(err)
}
got, ok, err := c.GetUserPlatformQuotaCache(ctx, 1, "openai")
if err != nil || !ok {
t.Fatalf("get: ok=%v err=%v", ok, err)
}
if got.DailyUsageUSD != 1.5 || got.WeeklyUsageUSD != 3.0 || got.MonthlyUsageUSD != 10.0 || got.Version != 7 {
t.Errorf("got = %+v, want %+v", got, in)
}
if got.SchemaVersion != service.UserPlatformQuotaCacheSchemaV1 {
t.Errorf("SchemaVersion = %d, want %d", got.SchemaVersion, service.UserPlatformQuotaCacheSchemaV1)
}
if got.DailyLimitUSD == nil || *got.DailyLimitUSD != dailyLimit {
t.Errorf("DailyLimitUSD = %v, want %v", got.DailyLimitUSD, dailyLimit)
}
if got.DailyWindowStart == nil || !got.DailyWindowStart.Equal(ts) {
t.Errorf("DailyWindowStart = %v, want %v", got.DailyWindowStart, ts)
}
}
func TestUserPlatformQuotaCache_NilLimitSetThenGet(t *testing.T) {
c, _ := newMiniRedisCache(t)
ctx := context.Background()
in := &service.UserPlatformQuotaCacheEntry{
DailyUsageUSD: 1.0,
SchemaVersion: service.UserPlatformQuotaCacheSchemaV1,
// DailyLimitUSD nil → 无限额
}
if err := c.SetUserPlatformQuotaCache(ctx, 1, "openai", in, time.Minute); err != nil {
t.Fatal(err)
}
got, ok, err := c.GetUserPlatformQuotaCache(ctx, 1, "openai")
if err != nil || !ok {
t.Fatalf("get: ok=%v err=%v", ok, err)
}
if got.DailyLimitUSD != nil {
t.Errorf("DailyLimitUSD should be nil for unlimited, got %v", got.DailyLimitUSD)
}
}
func TestUserPlatformQuotaCache_IncrMissIsNoop(t *testing.T) {
c, _ := newMiniRedisCache(t)
if err := c.IncrUserPlatformQuotaUsageCache(context.Background(), 1, "openai", 0.5, time.Minute); err != nil {
t.Fatal(err)
}
_, ok, _ := c.GetUserPlatformQuotaCache(context.Background(), 1, "openai")
if ok {
t.Error("expected key absent after no-op incr")
}
}
func TestUserPlatformQuotaCache_IncrHitAccumulates(t *testing.T) {
c, _ := newMiniRedisCache(t)
ctx := context.Background()
// SchemaVersion 必须显式设为 V1,否则 Lua 脚本会因 schema 不匹配而 return 0,跳过累加。
_ = c.SetUserPlatformQuotaCache(ctx, 1, "openai", &service.UserPlatformQuotaCacheEntry{
Version: 1,
SchemaVersion: service.UserPlatformQuotaCacheSchemaV1,
}, time.Minute)
if err := c.IncrUserPlatformQuotaUsageCache(ctx, 1, "openai", 0.5, time.Minute); err != nil {
t.Fatal(err)
}
if err := c.IncrUserPlatformQuotaUsageCache(ctx, 1, "openai", 0.25, time.Minute); err != nil {
t.Fatal(err)
}
got, _, _ := c.GetUserPlatformQuotaCache(ctx, 1, "openai")
if got.DailyUsageUSD != 0.75 || got.WeeklyUsageUSD != 0.75 || got.MonthlyUsageUSD != 0.75 {
t.Errorf("got %+v, want daily/weekly/monthly=0.75", got)
}
if got.Version != 3 {
t.Errorf("version = %d, want 3 (initial 1 + 2 incr)", got.Version)
}
}
func TestUserPlatformQuotaCache_Delete(t *testing.T) {
c, _ := newMiniRedisCache(t)
ctx := context.Background()
_ = c.SetUserPlatformQuotaCache(ctx, 1, "openai", &service.UserPlatformQuotaCacheEntry{Version: 1}, time.Minute)
if err := c.DeleteUserPlatformQuotaCache(ctx, 1, "openai"); err != nil {
t.Fatal(err)
}
_, ok, _ := c.GetUserPlatformQuotaCache(ctx, 1, "openai")
if ok {
t.Error("expected miss after delete")
}
}

View File

@ -0,0 +1,91 @@
//go:build unit
package repository
import (
"context"
"errors"
"testing"
"time"
)
type fakeRepoForAdapter struct {
upsertCalledWith []UserPlatformQuotaRecord
upsertCalledUserID int64
upsertErr error
resetCalledWith [4]any // userID, platform, window, newStart
resetErr error
}
func (f *fakeRepoForAdapter) BulkInsertInitial(_ context.Context, _ []UserPlatformQuotaRecord) error {
return nil
}
func (f *fakeRepoForAdapter) GetByUserPlatform(_ context.Context, _ int64, _ string) (*UserPlatformQuotaRecord, error) {
return nil, nil
}
func (f *fakeRepoForAdapter) ListByUser(_ context.Context, _ int64) ([]UserPlatformQuotaRecord, error) {
return nil, nil
}
func (f *fakeRepoForAdapter) IncrementUsageWithReset(_ context.Context, _ int64, _ string, _ float64, _ time.Time) error {
return nil
}
func (f *fakeRepoForAdapter) ResetExpiredWindow(_ context.Context, userID int64, platform string, window string, newStart time.Time) error {
f.resetCalledWith = [4]any{userID, platform, window, newStart}
return f.resetErr
}
func (f *fakeRepoForAdapter) UpsertForUser(_ context.Context, userID int64, records []UserPlatformQuotaRecord) error {
f.upsertCalledUserID = userID
f.upsertCalledWith = records
return f.upsertErr
}
func TestGenericAdapter_UpsertForUser_ForwardsRecords(t *testing.T) {
fake := &fakeRepoForAdapter{}
adapter := NewUserPlatformQuotaServiceAdapter(fake)
err := adapter.UpsertForUser(context.Background(), 42, nil)
if err != nil {
t.Fatalf("unexpected err: %v", err)
}
if fake.upsertCalledUserID != 42 {
t.Errorf("forwarded userID = %d, want 42", fake.upsertCalledUserID)
}
}
func TestGenericAdapter_UpsertForUser_PropagatesError(t *testing.T) {
wantErr := errors.New("boom")
fake := &fakeRepoForAdapter{upsertErr: wantErr}
adapter := NewUserPlatformQuotaServiceAdapter(fake)
err := adapter.UpsertForUser(context.Background(), 1, nil)
if !errors.Is(err, wantErr) {
t.Errorf("expected %v, got %v", wantErr, err)
}
}
func TestGenericAdapter_ResetExpiredWindow_ForwardsAllParams(t *testing.T) {
fake := &fakeRepoForAdapter{}
adapter := NewUserPlatformQuotaServiceAdapter(fake)
now := time.Date(2026, 5, 23, 10, 0, 0, 0, time.UTC)
if err := adapter.ResetExpiredWindow(context.Background(), 7, "openai", "weekly", now); err != nil {
t.Fatalf("unexpected err: %v", err)
}
if fake.resetCalledWith[0].(int64) != 7 ||
fake.resetCalledWith[1].(string) != "openai" ||
fake.resetCalledWith[2].(string) != "weekly" ||
!fake.resetCalledWith[3].(time.Time).Equal(now) {
t.Errorf("forwarded params mismatch: %+v", fake.resetCalledWith)
}
}
func TestGenericAdapter_ResetExpiredWindow_PropagatesError(t *testing.T) {
wantErr := errors.New("not found")
fake := &fakeRepoForAdapter{resetErr: wantErr}
adapter := NewUserPlatformQuotaServiceAdapter(fake)
err := adapter.ResetExpiredWindow(context.Background(), 1, "a", "daily", time.Now())
if !errors.Is(err, wantErr) {
t.Errorf("expected %v, got %v", wantErr, err)
}
}

View File

@ -0,0 +1,416 @@
package repository
import (
"context"
"fmt"
"strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/userplatformquota"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
)
// UserPlatformQuotaRecord 是 repository 层的传输结构体,
// 与 ent.UserPlatformQuota 实体解耦,供业务层使用。
type UserPlatformQuotaRecord struct {
UserID int64
Platform string
DailyLimitUSD *float64
WeeklyLimitUSD *float64
MonthlyLimitUSD *float64
DailyUsageUSD float64
WeeklyUsageUSD float64
MonthlyUsageUSD float64
DailyWindowStart *time.Time
WeeklyWindowStart *time.Time
MonthlyWindowStart *time.Time
}
// ErrUserPlatformQuotaNotFound 用于 ResetExpiredWindow 等需要"必须命中已有记录"的方法。
var ErrUserPlatformQuotaNotFound = fmt.Errorf("user platform quota record not found")
// UserPlatformQuotaRepository 定义用户平台配额的数据访问接口。
type UserPlatformQuotaRepository interface {
// BulkInsertInitial 幂等批量插入初始配额记录ON CONFLICT DO NOTHING
BulkInsertInitial(ctx context.Context, records []UserPlatformQuotaRecord) error
// GetByUserPlatform 查询单条配额记录,未找到时返回 (nil, nil)。
GetByUserPlatform(ctx context.Context, userID int64, platform string) (*UserPlatformQuotaRecord, error)
// ListByUser 查询用户的所有平台配额记录(排除软删除)。
ListByUser(ctx context.Context, userID int64) ([]UserPlatformQuotaRecord, error)
// IncrementUsageWithReset 原子地累加用量,若窗口已过期则先重置再累加。
IncrementUsageWithReset(ctx context.Context, userID int64, platform string, cost float64, now time.Time) error
// ResetExpiredWindow 重置指定窗口daily/weekly/monthly的用量与起始时间。
ResetExpiredWindow(ctx context.Context, userID int64, platform string, window string, newStart time.Time) error
// UpsertForUser 全量替换该用户所有平台限额配置(详见 service.UserPlatformQuotaRepository.UpsertForUser
UpsertForUser(ctx context.Context, userID int64, records []UserPlatformQuotaRecord) error
}
type userPlatformQuotaRepository struct {
client *dbent.Client
}
// NewUserPlatformQuotaRepository 创建 UserPlatformQuotaRepository 实现。
func NewUserPlatformQuotaRepository(client *dbent.Client) UserPlatformQuotaRepository {
return &userPlatformQuotaRepository{client: client}
}
// BulkInsertInitial 用原生 SQL ON CONFLICT 实现幂等批量插入(带条件 limit 覆盖)。
// 仅插入 limit_usd 与元数据usage_usd 用 DB 默认 0window_start 留 NULL。
// FK 约束要求 user_id 在 users 表中存在,调用方负责保证。
//
// 冲突策略CASE WHEN existing.*_limit_usd IS NULL THEN EXCLUDED.*_limit_usd ELSE existing ...
// - 若 IncrementUsageWithReset 因时序问题已先建行limit 全 NULL
// 此处会把注册时的默认 limit 写入,避免该用户在该平台永久无限额。
// - 若管理员已通过 UpsertForUser 设置了非 NULL 个性化 limit**保留不动**
// —— 旧实现无条件 EXCLUDED 覆盖会丢失个性化配置。
// - 不会改 usage_usd / window_start保留累计的用量。
// - 仅命中 deleted_at IS NULL 的活跃记录partial unique index 作用域)。
func (r *userPlatformQuotaRepository) BulkInsertInitial(ctx context.Context, records []UserPlatformQuotaRecord) error {
if len(records) == 0 {
return nil
}
client := clientFromContext(ctx, r.client)
var sb strings.Builder
_, _ = sb.WriteString("INSERT INTO user_platform_quotas (user_id, platform, daily_limit_usd, weekly_limit_usd, monthly_limit_usd, daily_usage_usd, weekly_usage_usd, monthly_usage_usd, created_at, updated_at) VALUES ")
args := make([]any, 0, len(records)*6)
// 统一时间戳:避免循环内多次 time.Now() 让同一批记录的 created_at/updated_at
// 出现亚毫秒级偏差(与 UpsertForUser 的 now := time.Now() 风格一致)。
now := time.Now()
for i, rec := range records {
base := i * 6
if i > 0 {
_, _ = sb.WriteString(",")
}
fmt.Fprintf(&sb, "($%d,$%d,$%d,$%d,$%d,0,0,0,$%d,$%d)",
base+1, base+2, base+3, base+4, base+5, base+6, base+6)
args = append(args,
rec.UserID, rec.Platform,
rec.DailyLimitUSD, rec.WeeklyLimitUSD, rec.MonthlyLimitUSD,
now,
)
}
// 精确命中 partial unique indexdeleted_at IS NULL避免对软删记录的歧义冲突。
// 条件覆盖:仅在现有 limit 为 NULL 时才写入 EXCLUDED否则保留现有非 NULL 值。
// - 修复 IncrementUsageWithReset 已用 NULL limit 建行的场景NULL → 注册默认)
// - 保护管理员通过 UpsertForUser 设置的个性化 limit 不被静默覆盖
_, _ = sb.WriteString(` ON CONFLICT (user_id, platform) WHERE deleted_at IS NULL
DO UPDATE SET
daily_limit_usd = COALESCE(user_platform_quotas.daily_limit_usd, EXCLUDED.daily_limit_usd),
weekly_limit_usd = COALESCE(user_platform_quotas.weekly_limit_usd, EXCLUDED.weekly_limit_usd),
monthly_limit_usd = COALESCE(user_platform_quotas.monthly_limit_usd, EXCLUDED.monthly_limit_usd),
updated_at = EXCLUDED.updated_at`)
_, err := client.ExecContext(ctx, sb.String(), args...)
return err
}
// GetByUserPlatform 通过 ent 查询单条配额(排除软删除)。未找到返回 (nil, nil)。
func (r *userPlatformQuotaRepository) GetByUserPlatform(ctx context.Context, userID int64, platform string) (*UserPlatformQuotaRecord, error) {
client := clientFromContext(ctx, r.client)
entity, err := client.UserPlatformQuota.Query().
Where(
userplatformquota.UserIDEQ(userID),
userplatformquota.PlatformEQ(platform),
userplatformquota.DeletedAtIsNil(),
).
Only(ctx)
if dbent.IsNotFound(err) {
return nil, nil
}
if err != nil {
return nil, err
}
return entQuotaToRecord(entity), nil
}
// ListByUser 查询用户的所有平台配额记录(排除软删除)。
func (r *userPlatformQuotaRepository) ListByUser(ctx context.Context, userID int64) ([]UserPlatformQuotaRecord, error) {
client := clientFromContext(ctx, r.client)
rows, err := client.UserPlatformQuota.Query().
Where(
userplatformquota.UserIDEQ(userID),
userplatformquota.DeletedAtIsNil(),
).
All(ctx)
if err != nil {
return nil, err
}
out := make([]UserPlatformQuotaRecord, 0, len(rows))
for _, e := range rows {
out = append(out, *entQuotaToRecord(e))
}
return out, nil
}
// IncrementUsageWithReset 原子累加 cost 到 (user, platform) 三个窗口的 *_usage_usd。
// 行为:
// - 若记录存在:在事务内 SELECT FOR UPDATE按 (prev_window_start vs current_window_start)
// 判断是否需要重置(不同 = 重置为 cost相同 = 累加 cost
// - 若记录不存在fail-open create 分支):插入新记录,**limit 字段保留 nil无限制**
// —— 这是预期行为billing 链路不能因 quota 表缺失而阻断请求,未注册路径
// 的用户 quota 默认放行,由调度层指标观测 + 后台对账补建 limit
//
// 上层正常路径(注册时 BulkInsertInitial保证 limit 在记录创建时就被写入。
func (r *userPlatformQuotaRepository) IncrementUsageWithReset(ctx context.Context, userID int64, platform string, cost float64, now time.Time) error {
return r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
existing, err := txClient.UserPlatformQuota.Query().
Where(
userplatformquota.UserIDEQ(userID),
userplatformquota.PlatformEQ(platform),
userplatformquota.DeletedAtIsNil(),
).
ForUpdate().
Only(txCtx)
if dbent.IsNotFound(err) {
// fail-open 建行limit_* 保留 NULL无限额
// 用 ON CONFLICT DO UPDATE 累加,而非裸 INSERT并发下另一请求可能在本事务
// SELECT FOR UPDATE 之后、INSERT 之前刚建行,裸 INSERT 会撞 partial unique index
// 致事务回滚、本次 cost 丢失DO UPDATE 把 cost 累加到既有 usage 上。
// 写法与本文件 insertLimitsRow / BulkInsertInitial 的 ON CONFLICT 一致。
const insertSQL = `INSERT INTO user_platform_quotas
(user_id, platform, daily_usage_usd, weekly_usage_usd, monthly_usage_usd,
daily_window_start, weekly_window_start, monthly_window_start, created_at, updated_at)
VALUES ($1, $2, $3, $3, $3, $4, $5, $6, $7, $7)
ON CONFLICT (user_id, platform) WHERE deleted_at IS NULL DO UPDATE SET
daily_usage_usd = user_platform_quotas.daily_usage_usd + EXCLUDED.daily_usage_usd,
weekly_usage_usd = user_platform_quotas.weekly_usage_usd + EXCLUDED.weekly_usage_usd,
monthly_usage_usd = user_platform_quotas.monthly_usage_usd + EXCLUDED.monthly_usage_usd,
updated_at = EXCLUDED.updated_at`
// $6 = now30 天滚动月度窗口以当前时刻为起始
_, e := txClient.ExecContext(txCtx, insertSQL,
userID, platform, cost,
timezone.StartOfDay(now), timezone.StartOfWeek(now), now, now)
return e
}
if err != nil {
return err
}
newDaily := maybeReset(existing.DailyUsageUsd, existing.DailyWindowStart, timezone.StartOfDay(now), cost)
newWeekly := maybeReset(existing.WeeklyUsageUsd, existing.WeeklyWindowStart, timezone.StartOfWeek(now), cost)
// 30 天滚动月度窗口:过期时重置为 cost 并以 now 为新起始,否则累加保留原起始
newMonthly, newMonthlyStart := monthlyMaybeReset(existing.MonthlyUsageUsd, existing.MonthlyWindowStart, cost, now)
_, e := existing.Update().
SetDailyUsageUsd(newDaily).
SetWeeklyUsageUsd(newWeekly).
SetMonthlyUsageUsd(newMonthly).
SetDailyWindowStart(timezone.StartOfDay(now)).
SetWeeklyWindowStart(timezone.StartOfWeek(now)).
SetMonthlyWindowStart(newMonthlyStart). // 30 天滚动:仅过期时更新起始
Save(txCtx)
return e
})
}
// ResetExpiredWindow 无条件重置指定窗口daily/weekly/monthly的用量与起始时间。
//
// ⚠️ 命名警告NOT a "check-then-reset" helper
//
// 名字里的 "Expired" 是历史遗留,**实现并不校验窗口是否真的过期**。
// 任何调用都会无条件把对应窗口的 *_usage_usd 清零并重写 *_window_start。
// 目前唯一合法 caller 是 admin POST /reset 接口(管理员强制归零)。
//
// 如果你想要"仅在窗口过期才重置"的语义,请直接使用 IncrementUsageWithReset
// 的内部判断maybeReset / monthlyMaybeReset或新增独立函数
// 不要复用这里的函数,否则会出现"明明窗口未过期,用量却被清零"的隐蔽 bug。
//
// 未命中活跃记录时返回 ErrUserPlatformQuotaNotFound。
func (r *userPlatformQuotaRepository) ResetExpiredWindow(ctx context.Context, userID int64, platform string, window string, newStart time.Time) error {
client := clientFromContext(ctx, r.client)
upd := client.UserPlatformQuota.Update().
Where(
userplatformquota.UserIDEQ(userID),
userplatformquota.PlatformEQ(platform),
userplatformquota.DeletedAtIsNil(),
)
switch window {
case "daily":
upd = upd.SetDailyUsageUsd(0).SetDailyWindowStart(newStart)
case "weekly":
upd = upd.SetWeeklyUsageUsd(0).SetWeeklyWindowStart(newStart)
case "monthly":
upd = upd.SetMonthlyUsageUsd(0).SetMonthlyWindowStart(newStart)
default:
return fmt.Errorf("unknown window %q", window)
}
n, err := upd.Save(ctx)
if err != nil {
return err
}
if n == 0 {
return ErrUserPlatformQuotaNotFound
}
return nil
}
// withTx 在事务中执行 fn若 ctx 中已有事务则复用。
func (r *userPlatformQuotaRepository) withTx(ctx context.Context, fn func(txCtx context.Context, txClient *dbent.Client) error) error {
if tx := dbent.TxFromContext(ctx); tx != nil {
return fn(ctx, tx.Client())
}
tx, err := r.client.Tx(ctx)
if err != nil {
return fmt.Errorf("begin user_platform_quota transaction: %w", err)
}
defer func() { _ = tx.Rollback() }()
txCtx := dbent.NewTxContext(ctx, tx)
if err := fn(txCtx, tx.Client()); err != nil {
return err
}
if err := tx.Commit(); err != nil {
return fmt.Errorf("commit user_platform_quota transaction: %w", err)
}
return nil
}
// entQuotaToRecord 将 ent entity 映射为 repository record。
// 注意 ent 生成字段名为 DailyLimitUsd非 DailyLimitUSD
func entQuotaToRecord(e *dbent.UserPlatformQuota) *UserPlatformQuotaRecord {
return &UserPlatformQuotaRecord{
UserID: e.UserID,
Platform: e.Platform,
DailyLimitUSD: e.DailyLimitUsd,
WeeklyLimitUSD: e.WeeklyLimitUsd,
MonthlyLimitUSD: e.MonthlyLimitUsd,
DailyUsageUSD: e.DailyUsageUsd,
WeeklyUsageUSD: e.WeeklyUsageUsd,
MonthlyUsageUSD: e.MonthlyUsageUsd,
DailyWindowStart: e.DailyWindowStart,
WeeklyWindowStart: e.WeeklyWindowStart,
MonthlyWindowStart: e.MonthlyWindowStart,
}
}
// maybeReset 判断是否需要重置窗口用量:
// - 若 prevStart 为 nil 或与 currStart 不同,表示窗口已过期,返回 cost重置
// - 否则返回 prevUsage + cost累加
func maybeReset(prevUsage float64, prevStart *time.Time, currStart time.Time, cost float64) float64 {
if prevStart == nil || !prevStart.Equal(currStart) {
return cost
}
return prevUsage + cost
}
// monthlyMaybeReset 判断 30 天滚动月度窗口是否需要重置。
// 过期条件prevStart 为 nil 或 now - prevStart >= 30×24h与订阅模式 NeedsMonthlyReset 语义一致)。
// 过期时重置为 cost否则累加。返回 (newUsage, newWindowStart)。
func monthlyMaybeReset(prevUsage float64, prevStart *time.Time, cost float64, now time.Time) (float64, time.Time) {
if prevStart == nil || now.Sub(*prevStart) >= 30*24*time.Hour {
return cost, now
}
return prevUsage + cost, *prevStart
}
// UpsertForUser 全量替换该用户的所有平台限额(事务内):
// 1. 软删除未在 records 中出现的所有 active 行
// 2. 对每条 record 尝试 UPDATE含 deleted_at = NULL 兼容重激活);
// UPDATE 行数为 0 时 INSERT 新行
//
// 仅改 *_limit_usd + deleted_at + updated_at保留 *_usage_usd / *_window_start。
func (r *userPlatformQuotaRepository) UpsertForUser(ctx context.Context, userID int64, records []UserPlatformQuotaRecord) error {
return r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
platforms := make([]string, 0, len(records))
for _, rec := range records {
platforms = append(platforms, rec.Platform)
}
now := time.Now()
if err := softDeleteMissingPlatforms(txCtx, txClient, userID, platforms, now); err != nil {
return err
}
for _, rec := range records {
affected, err := updateLimitsRow(txCtx, txClient, userID, rec, now)
if err != nil {
return err
}
if affected == 0 {
if err := insertLimitsRow(txCtx, txClient, userID, rec, now); err != nil {
return err
}
}
}
return nil
})
}
// softDeleteMissingPlatforms 软删除该用户所有不在 keepPlatforms 中的 active 行。
// keepPlatforms 为空时 → 软删用户所有 active 行。
// now 由调用方传入,与 updateLimitsRow / insertLimitsRow 共享同一个 Go time.Now()
// 保证事务内所有时间戳一致(避免 Postgres NOW() 与 Go time.Now() 的微小偏差)。
func softDeleteMissingPlatforms(ctx context.Context, client *dbent.Client, userID int64, keepPlatforms []string, now time.Time) error {
var (
query string
args []any
)
if len(keepPlatforms) == 0 {
query = `UPDATE user_platform_quotas SET deleted_at = $2, updated_at = $2
WHERE user_id = $1 AND deleted_at IS NULL`
args = []any{userID, now}
} else {
placeholders := make([]string, len(keepPlatforms))
args = make([]any, 0, len(keepPlatforms)+2)
args = append(args, userID, now)
for i, p := range keepPlatforms {
placeholders[i] = fmt.Sprintf("$%d", i+3)
args = append(args, p)
}
query = fmt.Sprintf(`UPDATE user_platform_quotas SET deleted_at = $2, updated_at = $2
WHERE user_id = $1 AND deleted_at IS NULL AND platform NOT IN (%s)`,
strings.Join(placeholders, ","))
}
_, err := client.ExecContext(ctx, query, args...)
return err
}
// updateLimitsRow 尝试 UPDATE active 行deleted_at IS NULL返回受影响行数。
// 仅更新 active 行:若存在多条历史软删记录,加 deleted_at IS NULL 守卫可避免
// 批量重激活导致的 partial unique indexuserplatformquota_user_id_platform_uq冲突。
// affected=0 时由调用方 UpsertForUser 走 insertLimitsRow 路径创建新行。
func updateLimitsRow(ctx context.Context, client *dbent.Client, userID int64, rec UserPlatformQuotaRecord, now time.Time) (int64, error) {
const query = `UPDATE user_platform_quotas
SET daily_limit_usd = $1, weekly_limit_usd = $2, monthly_limit_usd = $3,
deleted_at = NULL, updated_at = $4
WHERE user_id = $5 AND platform = $6 AND deleted_at IS NULL`
res, err := client.ExecContext(ctx, query,
rec.DailyLimitUSD, rec.WeeklyLimitUSD, rec.MonthlyLimitUSD, now,
userID, rec.Platform)
if err != nil {
return 0, err
}
return res.RowsAffected()
}
// insertLimitsRow 插入新限额行usage 默认 0window_start 默认 NULL
// 带 ON CONFLICT ... DO NOTHING 守卫:防止两个并发请求同时为同一 user/platform 新建行时
// 触发 unique constraint 违反userplatformquota_user_id_platform_uq 部分唯一索引)。
// affected=0 时说明另一个并发请求刚完成 INSERTfallback 到 updateLimitsRow 覆写 limits 值。
func insertLimitsRow(ctx context.Context, client *dbent.Client, userID int64, rec UserPlatformQuotaRecord, now time.Time) error {
const query = `INSERT INTO user_platform_quotas
(user_id, platform, daily_limit_usd, weekly_limit_usd, monthly_limit_usd,
daily_usage_usd, weekly_usage_usd, monthly_usage_usd, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, 0, 0, 0, $6, $6)
ON CONFLICT (user_id, platform) WHERE deleted_at IS NULL DO NOTHING`
res, err := client.ExecContext(ctx, query,
userID, rec.Platform,
rec.DailyLimitUSD, rec.WeeklyLimitUSD, rec.MonthlyLimitUSD,
now)
if err != nil {
return err
}
affected, err := res.RowsAffected()
if err != nil {
return err
}
if affected == 0 {
// 并发情形另一请求已插入该行fallback 到 UPDATE 覆写 limits 值last-writer-wins
_, err = updateLimitsRow(ctx, client, userID, rec, now)
return err
}
return nil
}

View File

@ -0,0 +1,269 @@
//go:build integration
package repository
import (
"context"
"errors"
"fmt"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
// mustCreateUserForQuota 在指定 client 上创建测试用户(满足 FK 约束)。
func mustCreateUserForQuota(t *testing.T, client *dbent.Client) int64 {
t.Helper()
u := mustCreateUser(t, client, &service.User{
Email: fmt.Sprintf("quota-test-%d@example.com", time.Now().UnixNano()),
})
return u.ID
}
func TestUserPlatformQuotaRepository_BulkInsertInitial_Idempotent(t *testing.T) {
ctx := context.Background()
tx := testEntTx(t)
txCtx := dbent.NewTxContext(ctx, tx)
client := tx.Client()
userID := mustCreateUserForQuota(t, client)
repo := NewUserPlatformQuotaRepository(client)
daily := 5.0
records := []UserPlatformQuotaRecord{
{UserID: userID, Platform: "anthropic", DailyLimitUSD: &daily},
{UserID: userID, Platform: "openai"},
}
// 第一次插入
require.NoError(t, repo.BulkInsertInitial(txCtx, records), "first insert")
// 第二次插入应为 no-opON CONFLICT DO NOTHING
require.NoError(t, repo.BulkInsertInitial(txCtx, records), "second insert (idempotent)")
list, err := repo.ListByUser(txCtx, userID)
require.NoError(t, err, "list")
require.Len(t, list, 2, "expected 2 records after idempotent insert")
// 校验 daily_limit_usd 保留
var anthropicRec *UserPlatformQuotaRecord
for i := range list {
if list[i].Platform == "anthropic" {
anthropicRec = &list[i]
}
}
require.NotNil(t, anthropicRec, "anthropic record should exist")
require.NotNil(t, anthropicRec.DailyLimitUSD, "daily limit should be set")
require.InDelta(t, 5.0, *anthropicRec.DailyLimitUSD, 1e-9, "daily limit should be 5.0")
}
func TestUserPlatformQuotaRepository_BulkInsertInitial_Empty(t *testing.T) {
ctx := context.Background()
tx := testEntTx(t)
txCtx := dbent.NewTxContext(ctx, tx)
client := tx.Client()
repo := NewUserPlatformQuotaRepository(client)
// 空切片不应报错
require.NoError(t, repo.BulkInsertInitial(txCtx, nil))
require.NoError(t, repo.BulkInsertInitial(txCtx, []UserPlatformQuotaRecord{}))
}
func TestUserPlatformQuotaRepository_GetByUserPlatform(t *testing.T) {
ctx := context.Background()
tx := testEntTx(t)
txCtx := dbent.NewTxContext(ctx, tx)
client := tx.Client()
userID := mustCreateUserForQuota(t, client)
repo := NewUserPlatformQuotaRepository(client)
// 未插入时应返回 nil
rec, err := repo.GetByUserPlatform(txCtx, userID, "anthropic")
require.NoError(t, err, "get before insert should not error")
require.Nil(t, rec, "get before insert should return nil")
// 插入后查询
daily := 10.0
require.NoError(t, repo.BulkInsertInitial(txCtx, []UserPlatformQuotaRecord{
{UserID: userID, Platform: "anthropic", DailyLimitUSD: &daily},
}))
rec, err = repo.GetByUserPlatform(txCtx, userID, "anthropic")
require.NoError(t, err)
require.NotNil(t, rec)
require.Equal(t, userID, rec.UserID)
require.Equal(t, "anthropic", rec.Platform)
require.NotNil(t, rec.DailyLimitUSD)
require.InDelta(t, 10.0, *rec.DailyLimitUSD, 1e-9)
}
func TestUserPlatformQuotaRepository_IncrementUsageWithReset_SameWindow(t *testing.T) {
ctx := context.Background()
// IncrementUsageWithReset 内部自己开事务,使用独立 ent client 确保跨事务可见
client := testEntClient(t)
userID := mustCreateUserForQuota(t, client)
repo := NewUserPlatformQuotaRepository(client)
now := time.Date(2026, 5, 22, 10, 0, 0, 0, time.UTC) // 周五
// 首次调用:应新建记录
require.NoError(t, repo.IncrementUsageWithReset(ctx, userID, "anthropic", 1.5, now))
rec, err := repo.GetByUserPlatform(ctx, userID, "anthropic")
require.NoError(t, err)
require.NotNil(t, rec)
require.InDelta(t, 1.5, rec.DailyUsageUSD, 1e-9, "initial daily usage")
require.InDelta(t, 1.5, rec.WeeklyUsageUSD, 1e-9, "initial weekly usage")
require.InDelta(t, 1.5, rec.MonthlyUsageUSD, 1e-9, "initial monthly usage")
// 同一天再次调用:应累加
require.NoError(t, repo.IncrementUsageWithReset(ctx, userID, "anthropic", 0.5, now))
rec, err = repo.GetByUserPlatform(ctx, userID, "anthropic")
require.NoError(t, err)
require.InDelta(t, 2.0, rec.DailyUsageUSD, 1e-9, "accumulated daily usage")
require.InDelta(t, 2.0, rec.WeeklyUsageUSD, 1e-9, "accumulated weekly usage")
require.InDelta(t, 2.0, rec.MonthlyUsageUSD, 1e-9, "accumulated monthly usage")
}
func TestUserPlatformQuotaRepository_IncrementUsageWithReset_DailyReset(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
userID := mustCreateUserForQuota(t, client)
repo := NewUserPlatformQuotaRepository(client)
day1 := time.Date(2026, 5, 22, 10, 0, 0, 0, time.UTC) // 周五(同一周、同一月)
day2 := time.Date(2026, 5, 23, 10, 0, 0, 0, time.UTC) // 周六(同一周、同一月)
require.NoError(t, repo.IncrementUsageWithReset(ctx, userID, "anthropic", 3.0, day1))
require.NoError(t, repo.IncrementUsageWithReset(ctx, userID, "anthropic", 1.0, day2))
rec, err := repo.GetByUserPlatform(ctx, userID, "anthropic")
require.NoError(t, err)
require.InDelta(t, 1.0, rec.DailyUsageUSD, 1e-9, "daily should reset to 1.0")
require.InDelta(t, 4.0, rec.WeeklyUsageUSD, 1e-9, "weekly should accumulate to 4.0 (same week)")
require.InDelta(t, 4.0, rec.MonthlyUsageUSD, 1e-9, "monthly should accumulate to 4.0 (same month)")
}
func TestUserPlatformQuotaRepository_IncrementUsageWithReset_WeeklyReset(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
userID := mustCreateUserForQuota(t, client)
repo := NewUserPlatformQuotaRepository(client)
// 5月22日周五和 5月25日下周一不同周
fri := time.Date(2026, 5, 22, 10, 0, 0, 0, time.UTC)
nextMon := time.Date(2026, 5, 25, 10, 0, 0, 0, time.UTC) // 下一周周一
require.NoError(t, repo.IncrementUsageWithReset(ctx, userID, "openai", 5.0, fri))
require.NoError(t, repo.IncrementUsageWithReset(ctx, userID, "openai", 2.0, nextMon))
rec, err := repo.GetByUserPlatform(ctx, userID, "openai")
require.NoError(t, err)
require.InDelta(t, 2.0, rec.DailyUsageUSD, 1e-9, "daily resets to new cost")
require.InDelta(t, 2.0, rec.WeeklyUsageUSD, 1e-9, "weekly resets (new week)")
require.InDelta(t, 7.0, rec.MonthlyUsageUSD, 1e-9, "monthly accumulates (same month)")
}
func TestUserPlatformQuotaRepository_ResetExpiredWindow(t *testing.T) {
ctx := context.Background()
tx := testEntTx(t)
txCtx := dbent.NewTxContext(ctx, tx)
client := tx.Client()
userID := mustCreateUserForQuota(t, client)
repo := NewUserPlatformQuotaRepository(client)
// 先通过 ent 直接建一条记录
_, err := client.UserPlatformQuota.Create().
SetUserID(userID).
SetPlatform("gemini").
SetDailyUsageUsd(10.0).
SetWeeklyUsageUsd(20.0).
SetMonthlyUsageUsd(50.0).
SetDailyWindowStart(time.Date(2026, 5, 21, 0, 0, 0, 0, time.UTC)).
SetWeeklyWindowStart(time.Date(2026, 5, 18, 0, 0, 0, 0, time.UTC)).
SetMonthlyWindowStart(time.Date(2026, 5, 1, 0, 0, 0, 0, time.UTC)).
Save(txCtx)
require.NoError(t, err)
newStart := time.Date(2026, 5, 22, 0, 0, 0, 0, time.UTC)
require.NoError(t, repo.ResetExpiredWindow(txCtx, userID, "gemini", "daily", newStart))
rec, err := repo.GetByUserPlatform(txCtx, userID, "gemini")
require.NoError(t, err)
require.InDelta(t, 0.0, rec.DailyUsageUSD, 1e-9, "daily usage reset to 0")
require.NotNil(t, rec.DailyWindowStart)
require.True(t, rec.DailyWindowStart.Equal(newStart), "daily window start updated")
// 其他窗口不变
require.InDelta(t, 20.0, rec.WeeklyUsageUSD, 1e-9, "weekly usage unchanged")
require.InDelta(t, 50.0, rec.MonthlyUsageUSD, 1e-9, "monthly usage unchanged")
}
func TestUserPlatformQuotaRepository_ResetExpiredWindow_UnknownWindow(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := NewUserPlatformQuotaRepository(client)
err := repo.ResetExpiredWindow(ctx, 999, "anthropic", "yearly", time.Now())
require.Error(t, err, "unknown window should return error")
}
func TestUserPlatformQuotaRepository_BulkInsertInitial_MultiRow(t *testing.T) {
ctx := context.Background()
tx := testEntTx(t)
txCtx := dbent.NewTxContext(ctx, tx)
client := tx.Client()
userID := mustCreateUserForQuota(t, client)
repo := NewUserPlatformQuotaRepository(client)
d1, d2, d3 := 5.0, 10.0, 15.0
records := []UserPlatformQuotaRecord{
{UserID: userID, Platform: "anthropic", DailyLimitUSD: &d1},
{UserID: userID, Platform: "openai", DailyLimitUSD: &d2},
{UserID: userID, Platform: "gemini", DailyLimitUSD: &d3},
}
require.NoError(t, repo.BulkInsertInitial(txCtx, records), "multi-row insert failed")
list, err := repo.ListByUser(txCtx, userID)
require.NoError(t, err)
require.Len(t, list, 3, "expected 3 rows, got %d", len(list))
// 验证 limit 值与传入一致(防占位符串位)
byPlatform := map[string]*UserPlatformQuotaRecord{}
for i := range list {
byPlatform[list[i].Platform] = &list[i]
}
require.NotNil(t, byPlatform["anthropic"], "anthropic record should exist")
require.NotNil(t, byPlatform["anthropic"].DailyLimitUSD, "anthropic daily limit should be set")
require.InDelta(t, 5.0, *byPlatform["anthropic"].DailyLimitUSD, 1e-9, "anthropic daily_limit = want 5.0")
require.NotNil(t, byPlatform["openai"], "openai record should exist")
require.NotNil(t, byPlatform["openai"].DailyLimitUSD, "openai daily limit should be set")
require.InDelta(t, 10.0, *byPlatform["openai"].DailyLimitUSD, 1e-9, "openai daily_limit = want 10.0")
require.NotNil(t, byPlatform["gemini"], "gemini record should exist")
require.NotNil(t, byPlatform["gemini"].DailyLimitUSD, "gemini daily limit should be set")
require.InDelta(t, 15.0, *byPlatform["gemini"].DailyLimitUSD, 1e-9, "gemini daily_limit = want 15.0")
}
func TestUserPlatformQuotaRepository_ResetExpiredWindow_NotFoundReturnsSentinel(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
repo := NewUserPlatformQuotaRepository(client)
err := repo.ResetExpiredWindow(ctx, 99999, "anthropic", "daily", time.Now())
require.True(t, errors.Is(err, ErrUserPlatformQuotaNotFound),
"expected ErrUserPlatformQuotaNotFound, got %v", err)
}

View File

@ -0,0 +1,103 @@
//go:build unit
package repository
import (
"os"
"strings"
"testing"
"time"
)
func TestMaybeReset(t *testing.T) {
start := time.Date(2026, 5, 22, 0, 0, 0, 0, time.UTC)
other := start.AddDate(0, 0, -1)
cases := []struct {
name string
prevUsage float64
prevStart *time.Time
currStart time.Time
cost float64
want float64
}{
{"nil prev start resets", 10, nil, start, 1.5, 1.5},
{"different start resets", 10, &other, start, 1.5, 1.5},
{"same start accumulates", 10, &start, start, 1.5, 11.5},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
if got := maybeReset(c.prevUsage, c.prevStart, c.currStart, c.cost); got != c.want {
t.Errorf("maybeReset = %v, want %v", got, c.want)
}
})
}
}
// TestMonthlyMaybeReset_NilStart 验证 prevStart=nil 时重置。
func TestMonthlyMaybeReset_NilStart(t *testing.T) {
now := time.Date(2026, 5, 22, 12, 0, 0, 0, time.UTC)
usage, start := monthlyMaybeReset(10.0, nil, 1.5, now)
if usage != 1.5 {
t.Errorf("usage = %v, want 1.5", usage)
}
if !start.Equal(now) {
t.Errorf("start = %v, want %v", start, now)
}
}
// TestMonthlyMaybeReset_Expired 验证窗口满 30 天时重置30 天恰好到期)。
func TestMonthlyMaybeReset_Expired(t *testing.T) {
windowStart := time.Date(2026, 4, 22, 12, 0, 0, 0, time.UTC)
// now = windowStart + 30d刚好到期
now := windowStart.Add(30 * 24 * time.Hour)
usage, start := monthlyMaybeReset(8.0, &windowStart, 2.0, now)
if usage != 2.0 {
t.Errorf("usage = %v, want 2.0 (reset)", usage)
}
if !start.Equal(now) {
t.Errorf("start = %v, want %v (new window)", start, now)
}
}
// TestMonthlyMaybeReset_CrossMonthBoundary 验证跨自然月时也使用 30 天滚动(不提前重置)。
// 旧行为5 月 1 日跨月立即重置;新行为:窗口起始 4 月 20 日5 月 1 日仅过了 11 天,应累加。
func TestMonthlyMaybeReset_CrossMonthBoundary(t *testing.T) {
windowStart := time.Date(2026, 4, 20, 0, 0, 0, 0, time.UTC)
// 5 月 1 日:距起始 11 天,不足 30 天,应累加
now := time.Date(2026, 5, 1, 0, 0, 0, 0, time.UTC)
usage, start := monthlyMaybeReset(5.0, &windowStart, 1.0, now)
if usage != 6.0 {
t.Errorf("usage = %v, want 6.0 (accumulate, not reset at month boundary)", usage)
}
if !start.Equal(windowStart) {
t.Errorf("start = %v, want %v (preserved)", start, windowStart)
}
}
// TestMonthlyMaybeReset_Active 验证窗口内正常累加。
func TestMonthlyMaybeReset_Active(t *testing.T) {
windowStart := time.Date(2026, 5, 1, 0, 0, 0, 0, time.UTC)
// 15 天内,窗口有效
now := windowStart.Add(15 * 24 * time.Hour)
usage, start := monthlyMaybeReset(3.0, &windowStart, 0.5, now)
if usage != 3.5 {
t.Errorf("usage = %v, want 3.5", usage)
}
if !start.Equal(windowStart) {
t.Errorf("start = %v, want %v", start, windowStart)
}
}
// TestUpdateLimitsRowQuery_HasDeletedAtGuard 通过读取源文件验证 updateLimitsRow
// 的 SQL WHERE 子句包含 deleted_at IS NULL 守卫I-NEW-1
// 此防回归测试可在无 DB 的 CI 环境中运行,防止意外删除该守卫。
func TestUpdateLimitsRowQuery_HasDeletedAtGuard(t *testing.T) {
src, err := os.ReadFile("user_platform_quota_repo.go")
if err != nil {
t.Fatalf("failed to read source file: %v", err)
}
const guard = "AND deleted_at IS NULL"
if !strings.Contains(string(src), guard) {
t.Errorf("updateLimitsRow SQL must contain %q to prevent bulk reactivation of soft-deleted rows (I-NEW-1)", guard)
}
}

View File

@ -0,0 +1,206 @@
package repository
import (
"context"
"errors"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// userPlatformQuotaServiceAdapter 将 repository 层的 userPlatformQuotaRepository
// 适配为 service.UserPlatformQuotaRepository 接口(返回 *service.UserPlatformQuotaRecord
type userPlatformQuotaServiceAdapter struct {
inner *userPlatformQuotaRepository
}
// NewUserPlatformQuotaServiceAdapter 将 UserPlatformQuotaRepository 实现包装为
// 满足 service.UserPlatformQuotaRepository 接口的适配器。
func NewUserPlatformQuotaServiceAdapter(repo UserPlatformQuotaRepository) service.UserPlatformQuotaRepository {
impl, ok := repo.(*userPlatformQuotaRepository)
if !ok {
// 非标准实现(如测试 fake通过通用适配器包装
return &genericUserPlatformQuotaAdapter{inner: repo}
}
return &userPlatformQuotaServiceAdapter{inner: impl}
}
func (a *userPlatformQuotaServiceAdapter) GetByUserPlatform(ctx context.Context, userID int64, platform string) (*service.UserPlatformQuotaRecord, error) {
rec, err := a.inner.GetByUserPlatform(ctx, userID, platform)
if err != nil || rec == nil {
return nil, err
}
return toServiceRecord(rec), nil
}
// IncrementUsageWithReset 原子累加 cost 到 (user, platform) 三个窗口的用量。
func (a *userPlatformQuotaServiceAdapter) IncrementUsageWithReset(ctx context.Context, userID int64, platform string, cost float64, now time.Time) error {
return a.inner.IncrementUsageWithReset(ctx, userID, platform, cost, now)
}
// ListByUser 查询用户的所有平台配额记录。
func (a *userPlatformQuotaServiceAdapter) ListByUser(ctx context.Context, userID int64) ([]service.UserPlatformQuotaRecord, error) {
rows, err := a.inner.ListByUser(ctx, userID)
if err != nil {
return nil, err
}
out := make([]service.UserPlatformQuotaRecord, len(rows))
for i, r := range rows {
out[i] = service.UserPlatformQuotaRecord{
UserID: r.UserID,
Platform: r.Platform,
DailyLimitUSD: r.DailyLimitUSD,
WeeklyLimitUSD: r.WeeklyLimitUSD,
MonthlyLimitUSD: r.MonthlyLimitUSD,
DailyUsageUSD: r.DailyUsageUSD,
WeeklyUsageUSD: r.WeeklyUsageUSD,
MonthlyUsageUSD: r.MonthlyUsageUSD,
DailyWindowStart: r.DailyWindowStart,
WeeklyWindowStart: r.WeeklyWindowStart,
MonthlyWindowStart: r.MonthlyWindowStart,
}
}
return out, nil
}
// BulkInsertInitial 将 service.UserPlatformQuotaRecord 切片转换后调用底层 repo。
func (a *userPlatformQuotaServiceAdapter) BulkInsertInitial(ctx context.Context, records []service.UserPlatformQuotaRecord) error {
repoRecords := make([]UserPlatformQuotaRecord, len(records))
for i, r := range records {
repoRecords[i] = UserPlatformQuotaRecord{
UserID: r.UserID,
Platform: r.Platform,
DailyLimitUSD: r.DailyLimitUSD,
WeeklyLimitUSD: r.WeeklyLimitUSD,
MonthlyLimitUSD: r.MonthlyLimitUSD,
}
}
return a.inner.BulkInsertInitial(ctx, repoRecords)
}
// UpsertForUser 全量替换该用户所有平台限额。
func (a *userPlatformQuotaServiceAdapter) UpsertForUser(ctx context.Context, userID int64, records []service.UserPlatformQuotaRecord) error {
repoRecords := toRepoRecords(records)
return a.inner.UpsertForUser(ctx, userID, repoRecords)
}
// ResetExpiredWindow 转发至 repository.ResetExpiredWindow并将 repository sentinel 包装为 service sentinel。
func (a *userPlatformQuotaServiceAdapter) ResetExpiredWindow(ctx context.Context, userID int64, platform string, window string, newStart time.Time) error {
err := a.inner.ResetExpiredWindow(ctx, userID, platform, window, newStart)
if errors.Is(err, ErrUserPlatformQuotaNotFound) {
return fmt.Errorf("%w: %w", service.ErrUserPlatformQuotaNotFound, err)
}
return err
}
// genericUserPlatformQuotaAdapter 通过通用接口适配(用于测试 fake 或非标准实现)。
type genericUserPlatformQuotaAdapter struct {
inner UserPlatformQuotaRepository
}
func (a *genericUserPlatformQuotaAdapter) GetByUserPlatform(ctx context.Context, userID int64, platform string) (*service.UserPlatformQuotaRecord, error) {
rec, err := a.inner.GetByUserPlatform(ctx, userID, platform)
if err != nil || rec == nil {
return nil, err
}
return toServiceRecord(rec), nil
}
// IncrementUsageWithReset 原子累加 cost通用 adapter 实现)。
func (a *genericUserPlatformQuotaAdapter) IncrementUsageWithReset(ctx context.Context, userID int64, platform string, cost float64, now time.Time) error {
return a.inner.IncrementUsageWithReset(ctx, userID, platform, cost, now)
}
// ListByUser 查询用户的所有平台配额记录(通用 adapter 实现)。
func (a *genericUserPlatformQuotaAdapter) ListByUser(ctx context.Context, userID int64) ([]service.UserPlatformQuotaRecord, error) {
rows, err := a.inner.ListByUser(ctx, userID)
if err != nil {
return nil, err
}
out := make([]service.UserPlatformQuotaRecord, len(rows))
for i, r := range rows {
out[i] = service.UserPlatformQuotaRecord{
UserID: r.UserID,
Platform: r.Platform,
DailyLimitUSD: r.DailyLimitUSD,
WeeklyLimitUSD: r.WeeklyLimitUSD,
MonthlyLimitUSD: r.MonthlyLimitUSD,
DailyUsageUSD: r.DailyUsageUSD,
WeeklyUsageUSD: r.WeeklyUsageUSD,
MonthlyUsageUSD: r.MonthlyUsageUSD,
DailyWindowStart: r.DailyWindowStart,
WeeklyWindowStart: r.WeeklyWindowStart,
MonthlyWindowStart: r.MonthlyWindowStart,
}
}
return out, nil
}
// BulkInsertInitial 将 service.UserPlatformQuotaRecord 切片转换后调用底层 generic repo。
func (a *genericUserPlatformQuotaAdapter) BulkInsertInitial(ctx context.Context, records []service.UserPlatformQuotaRecord) error {
repoRecords := make([]UserPlatformQuotaRecord, len(records))
for i, r := range records {
repoRecords[i] = UserPlatformQuotaRecord{
UserID: r.UserID,
Platform: r.Platform,
DailyLimitUSD: r.DailyLimitUSD,
WeeklyLimitUSD: r.WeeklyLimitUSD,
MonthlyLimitUSD: r.MonthlyLimitUSD,
}
}
return a.inner.BulkInsertInitial(ctx, repoRecords)
}
// UpsertForUser 全量替换(通用 adapter 实现)。
func (a *genericUserPlatformQuotaAdapter) UpsertForUser(ctx context.Context, userID int64, records []service.UserPlatformQuotaRecord) error {
repoRecords := toRepoRecords(records)
return a.inner.UpsertForUser(ctx, userID, repoRecords)
}
// ResetExpiredWindow 转发至 repository.ResetExpiredWindow通用 adapter并包装 sentinel。
func (a *genericUserPlatformQuotaAdapter) ResetExpiredWindow(ctx context.Context, userID int64, platform string, window string, newStart time.Time) error {
err := a.inner.ResetExpiredWindow(ctx, userID, platform, window, newStart)
if errors.Is(err, ErrUserPlatformQuotaNotFound) {
return fmt.Errorf("%w: %w", service.ErrUserPlatformQuotaNotFound, err)
}
return err
}
// toServiceRecord 将 repository.UserPlatformQuotaRecord 转换为 service.UserPlatformQuotaRecord。
func toServiceRecord(rec *UserPlatformQuotaRecord) *service.UserPlatformQuotaRecord {
return &service.UserPlatformQuotaRecord{
UserID: rec.UserID,
Platform: rec.Platform,
DailyLimitUSD: rec.DailyLimitUSD,
WeeklyLimitUSD: rec.WeeklyLimitUSD,
MonthlyLimitUSD: rec.MonthlyLimitUSD,
DailyUsageUSD: rec.DailyUsageUSD,
WeeklyUsageUSD: rec.WeeklyUsageUSD,
MonthlyUsageUSD: rec.MonthlyUsageUSD,
DailyWindowStart: rec.DailyWindowStart,
WeeklyWindowStart: rec.WeeklyWindowStart,
MonthlyWindowStart: rec.MonthlyWindowStart,
}
}
// toRepoRecords 将 service.UserPlatformQuotaRecord 切片转换为 repository.UserPlatformQuotaRecord含 limit 字段,含 usage/window_start
func toRepoRecords(records []service.UserPlatformQuotaRecord) []UserPlatformQuotaRecord {
out := make([]UserPlatformQuotaRecord, len(records))
for i, r := range records {
out[i] = UserPlatformQuotaRecord{
UserID: r.UserID,
Platform: r.Platform,
DailyLimitUSD: r.DailyLimitUSD,
WeeklyLimitUSD: r.WeeklyLimitUSD,
MonthlyLimitUSD: r.MonthlyLimitUSD,
DailyUsageUSD: r.DailyUsageUSD,
WeeklyUsageUSD: r.WeeklyUsageUSD,
MonthlyUsageUSD: r.MonthlyUsageUSD,
DailyWindowStart: r.DailyWindowStart,
WeeklyWindowStart: r.WeeklyWindowStart,
MonthlyWindowStart: r.MonthlyWindowStart,
}
}
return out
}

View File

@ -0,0 +1,148 @@
//go:build integration
package repository
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/ent/userplatformquota"
"github.com/stretchr/testify/require"
)
func TestUpsertForUser_NewUserInsertsAllRecords(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
userID := mustCreateUserForQuota(t, client)
repo := NewUserPlatformQuotaRepository(client)
daily := 10.0
weekly := 50.0
monthly := 200.0
records := []UserPlatformQuotaRecord{
{UserID: userID, Platform: "anthropic", DailyLimitUSD: &daily, WeeklyLimitUSD: &weekly, MonthlyLimitUSD: &monthly},
{UserID: userID, Platform: "openai", DailyLimitUSD: &daily},
}
require.NoError(t, repo.UpsertForUser(ctx, userID, records))
got, err := repo.ListByUser(ctx, userID)
require.NoError(t, err)
require.Len(t, got, 2)
}
func TestUpsertForUser_PartialUpdateSoftDeletesMissingPlatforms(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
userID := mustCreateUserForQuota(t, client)
repo := NewUserPlatformQuotaRepository(client)
d1 := 10.0
d2 := 20.0
require.NoError(t, repo.UpsertForUser(ctx, userID, []UserPlatformQuotaRecord{
{UserID: userID, Platform: "anthropic", DailyLimitUSD: &d1},
{UserID: userID, Platform: "openai", DailyLimitUSD: &d1},
}))
require.NoError(t, repo.UpsertForUser(ctx, userID, []UserPlatformQuotaRecord{
{UserID: userID, Platform: "anthropic", DailyLimitUSD: &d2},
{UserID: userID, Platform: "gemini", DailyLimitUSD: &d1},
}))
active, err := repo.ListByUser(ctx, userID)
require.NoError(t, err)
platforms := map[string]float64{}
for _, r := range active {
require.NotNil(t, r.DailyLimitUSD)
platforms[r.Platform] = *r.DailyLimitUSD
}
require.Len(t, platforms, 2)
require.InDelta(t, 20.0, platforms["anthropic"], 1e-9)
require.InDelta(t, 10.0, platforms["gemini"], 1e-9)
_, openaiActive := platforms["openai"]
require.False(t, openaiActive, "openai should be soft-deleted")
}
func TestUpsertForUser_PreservesUsageAndWindowStart(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
userID := mustCreateUserForQuota(t, client)
repo := NewUserPlatformQuotaRepository(client)
d := 10.0
require.NoError(t, repo.UpsertForUser(ctx, userID, []UserPlatformQuotaRecord{
{UserID: userID, Platform: "anthropic", DailyLimitUSD: &d},
}))
now := time.Date(2026, 5, 22, 10, 0, 0, 0, time.UTC)
require.NoError(t, repo.IncrementUsageWithReset(ctx, userID, "anthropic", 3.5, now))
newD := 50.0
require.NoError(t, repo.UpsertForUser(ctx, userID, []UserPlatformQuotaRecord{
{UserID: userID, Platform: "anthropic", DailyLimitUSD: &newD},
}))
rec, err := repo.GetByUserPlatform(ctx, userID, "anthropic")
require.NoError(t, err)
require.NotNil(t, rec)
require.InDelta(t, 50.0, *rec.DailyLimitUSD, 1e-9, "limit should update")
require.InDelta(t, 3.5, rec.DailyUsageUSD, 1e-9, "usage must be preserved")
require.NotNil(t, rec.DailyWindowStart, "window_start must be preserved")
}
func TestUpsertForUser_ReactivatesSoftDeleted(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
userID := mustCreateUserForQuota(t, client)
repo := NewUserPlatformQuotaRepository(client)
d := 10.0
require.NoError(t, repo.UpsertForUser(ctx, userID, []UserPlatformQuotaRecord{
{UserID: userID, Platform: "anthropic", DailyLimitUSD: &d},
}))
require.NoError(t, repo.UpsertForUser(ctx, userID, []UserPlatformQuotaRecord{}))
gone, err := repo.GetByUserPlatform(ctx, userID, "anthropic")
require.NoError(t, err)
require.Nil(t, gone, "anthropic should be soft-deleted (not active)")
d2 := 20.0
require.NoError(t, repo.UpsertForUser(ctx, userID, []UserPlatformQuotaRecord{
{UserID: userID, Platform: "anthropic", DailyLimitUSD: &d2},
}))
back, err := repo.GetByUserPlatform(ctx, userID, "anthropic")
require.NoError(t, err)
require.NotNil(t, back, "anthropic should be active again")
require.InDelta(t, 20.0, *back.DailyLimitUSD, 1e-9)
allRows, err := client.UserPlatformQuota.Query().
Where(userplatformquota.UserIDEQ(userID), userplatformquota.PlatformEQ("anthropic")).
All(ctx)
require.NoError(t, err)
activeCount := 0
for _, r := range allRows {
if r.DeletedAt == nil {
activeCount++
}
}
require.Equal(t, 1, activeCount, "should have exactly one active row")
}
func TestUpsertForUser_EmptyClearsAll(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
userID := mustCreateUserForQuota(t, client)
repo := NewUserPlatformQuotaRepository(client)
d := 10.0
require.NoError(t, repo.UpsertForUser(ctx, userID, []UserPlatformQuotaRecord{
{UserID: userID, Platform: "anthropic", DailyLimitUSD: &d},
{UserID: userID, Platform: "openai", DailyLimitUSD: &d},
}))
require.NoError(t, repo.UpsertForUser(ctx, userID, []UserPlatformQuotaRecord{}))
got, err := repo.ListByUser(ctx, userID)
require.NoError(t, err)
require.Empty(t, got)
}

View File

@ -93,6 +93,8 @@ var ProviderSet = wire.NewSet(
NewChannelMonitorRequestTemplateRepository,
NewContentModerationRepository,
NewAffiliateRepository,
NewUserPlatformQuotaRepository, // T14: user × platform quota
NewUserPlatformQuotaServiceAdapter, // T14: adapter → service.UserPlatformQuotaRepository
// Cache implementations
NewGatewayCache,

View File

@ -798,6 +798,14 @@ func TestAPIContracts(t *testing.T) {
"force_email_on_third_party_signup": false,
"default_concurrency": 5,
"default_balance": 1.25,
"default_platform_quotas": {"anthropic":{"daily":null,"weekly":null,"monthly":null},"antigravity":{"daily":null,"weekly":null,"monthly":null},"gemini":{"daily":null,"weekly":null,"monthly":null},"openai":{"daily":null,"weekly":null,"monthly":null}},
"auth_source_default_email_platform_quotas": null,
"auth_source_default_github_platform_quotas": null,
"auth_source_default_google_platform_quotas": null,
"auth_source_default_linuxdo_platform_quotas": null,
"auth_source_default_oidc_platform_quotas": null,
"auth_source_default_wechat_platform_quotas": null,
"auth_source_default_dingtalk_platform_quotas": null,
"affiliate_rebate_rate": 20,
"affiliate_rebate_freeze_hours": 0,
"affiliate_rebate_duration_days": 0,
@ -1025,6 +1033,14 @@ func TestAPIContracts(t *testing.T) {
"purchase_subscription_url": "",
"table_default_page_size": 20,
"table_page_size_options": [10, 20, 50],
"default_platform_quotas": {"anthropic":{"daily":null,"weekly":null,"monthly":null},"antigravity":{"daily":null,"weekly":null,"monthly":null},"gemini":{"daily":null,"weekly":null,"monthly":null},"openai":{"daily":null,"weekly":null,"monthly":null}},
"auth_source_default_email_platform_quotas": null,
"auth_source_default_github_platform_quotas": null,
"auth_source_default_google_platform_quotas": null,
"auth_source_default_linuxdo_platform_quotas": null,
"auth_source_default_oidc_platform_quotas": null,
"auth_source_default_wechat_platform_quotas": null,
"auth_source_default_dingtalk_platform_quotas": null,
"custom_menu_items": [],
"custom_endpoints": [],
"default_concurrency": 0,

View File

@ -20,7 +20,7 @@ func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{JWT: config.JWTConfig{Secret: "test-secret", ExpireHour: 1}}
authService := service.NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil)
authService := service.NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil, nil)
admin := &service.User{
ID: 1,

View File

@ -60,7 +60,7 @@ func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthSer
cfg.JWT.AccessTokenExpireMinutes = 60
userRepo := &stubJWTUserRepo{users: users}
authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil)
authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil, nil)
userSvc := service.NewUserService(userRepo, nil, nil, nil)
mw := NewJWTAuthMiddleware(authSvc, userSvc)
@ -143,7 +143,7 @@ func TestJWTAuth_ValidToken_TouchesLastActive(t *testing.T) {
cfg.JWT.AccessTokenExpireMinutes = 60
userRepo := &stubJWTUserRepo{users: map[int64]*service.User{1: user}}
authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil)
authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil, nil)
userSvc := service.NewUserService(userRepo, nil, nil, nil)
toucher := &recordingActivityToucher{}

View File

@ -241,6 +241,9 @@ func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
users.POST("/:id/replace-group", h.Admin.User.ReplaceGroup)
users.GET("/:id/rpm-status", h.Admin.User.GetUserRPMStatus)
users.POST("/batch-concurrency", h.Admin.User.BatchUpdateConcurrency)
users.GET("/:id/platform-quotas", h.Admin.User.GetUserPlatformQuotas)
users.PUT("/:id/platform-quotas", h.Admin.User.UpdateUserPlatformQuotas)
users.POST("/:id/platform-quotas/reset", h.Admin.User.ResetUserPlatformQuotaWindow)
// User attribute values
users.GET("/:id/attributes", h.Admin.UserAttribute.GetUserAttributes)

View File

@ -32,6 +32,7 @@ func RegisterUserRoutes(
user.DELETE("/account-bindings/:provider", h.User.UnbindIdentity)
user.POST("/auth-identities/bind/start", h.User.StartIdentityBinding)
user.GET("/api-keys/:id/usage/daily", h.Usage.GetMyAPIKeyDailyUsage)
user.GET("/platform-quotas", h.User.GetMyPlatformQuotas)
// 通知邮箱管理
notifyEmail := user.Group("/notify-email")

View File

@ -459,6 +459,22 @@ func (s *billingCacheStub) InvalidateAPIKeyRateLimit(ctx context.Context, keyID
panic("unexpected InvalidateAPIKeyRateLimit call")
}
func (s *billingCacheStub) GetUserPlatformQuotaCache(ctx context.Context, userID int64, platform string) (*UserPlatformQuotaCacheEntry, bool, error) {
panic("unexpected GetUserPlatformQuotaCache call")
}
func (s *billingCacheStub) SetUserPlatformQuotaCache(ctx context.Context, userID int64, platform string, entry *UserPlatformQuotaCacheEntry, ttl time.Duration) error {
panic("unexpected SetUserPlatformQuotaCache call")
}
func (s *billingCacheStub) DeleteUserPlatformQuotaCache(ctx context.Context, userID int64, platform string) error {
panic("unexpected DeleteUserPlatformQuotaCache call")
}
func (s *billingCacheStub) IncrUserPlatformQuotaUsageCache(ctx context.Context, userID int64, platform string, cost float64, ttl time.Duration) error {
panic("unexpected IncrUserPlatformQuotaUsageCache call")
}
func waitForInvalidations(t *testing.T, ch <-chan subscriptionInvalidateCall, expected int) []subscriptionInvalidateCall {
t.Helper()
calls := make([]subscriptionInvalidateCall, 0, expected)

View File

@ -189,6 +189,8 @@ func (s *AuthService) createEmailOAuthUser(ctx context.Context, email, username,
}
s.postAuthUserBootstrap(ctx, user, providerType, false)
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
// snapshot user × platform quotafail-open
_ = s.snapshotPlatformQuotaDefaults(ctx, user.ID, &grantPlan)
s.bindOAuthAffiliate(ctx, user.ID, affiliateCode)
if invitationRedeemCode != nil {
if err := s.useOAuthRegistrationInvitation(ctx, invitationRedeemCode.ID, user.ID); err != nil {

View File

@ -0,0 +1,88 @@
//go:build unit
package service
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
func newEmailOAuthAutoAuthService(
userRepo UserRepository,
settings map[string]string,
quotaRepo UserPlatformQuotaRepository,
) *AuthService {
cfg := &config.Config{
JWT: config.JWTConfig{
Secret: "test-secret",
ExpireHour: 1,
AccessTokenExpireMinutes: 60,
RefreshTokenExpireDays: 7,
},
Default: config.DefaultConfig{
UserBalance: 3.5,
UserConcurrency: 2,
},
}
settingService := NewSettingService(&settingRepoStub{values: settings}, cfg)
return NewAuthService(
nil, // entClient — nil, updateUserSignupSource early return
userRepo,
nil, // redeemRepo — invitationCode="" 时不触发
&refreshTokenCacheStub{},
cfg,
settingService,
nil, // emailService
nil, // turnstileService
nil, // emailQueueService
nil, // promoService
nil, // defaultSubAssigner — nil, assignSubscriptions early return
nil, // affiliateService — nil, bindOAuthAffiliate early return
quotaRepo,
)
}
func TestEmailOAuthAuto_SnapshotsPlatformQuotaDefaults(t *testing.T) {
userRepo := &userRepoStub{nextID: 88}
quotaRepo := &userPlatformQuotaRepoStub{}
svc := newEmailOAuthAutoAuthService(
userRepo,
map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyDefaultPlatformQuotas: `{"gemini": {"monthly": 100.0}}`,
},
quotaRepo,
)
user, err := svc.createEmailOAuthUser(
context.Background(),
"newoauth@example.com",
"newoauth",
"github",
"", // invitationCode
"", // affiliateCode
)
require.NoError(t, err)
require.NotNil(t, user)
require.Equal(t, int64(88), user.ID)
require.Len(t, quotaRepo.bulkInsertCalls, 1, "createEmailOAuthUser must snapshot platform quotas via BulkInsertInitial")
records := quotaRepo.bulkInsertCalls[0]
var geminiRecord *UserPlatformQuotaRecord
for i := range records {
if records[i].Platform == "gemini" {
geminiRecord = &records[i]
break
}
}
require.NotNil(t, geminiRecord, "expected gemini platform record")
require.NotNil(t, geminiRecord.MonthlyLimitUSD)
require.InDelta(t, 100.0, *geminiRecord.MonthlyLimitUSD, 0.0001)
}

View File

@ -283,6 +283,8 @@ func (s *AuthService) FinalizeOAuthEmailAccount(
s.updateOAuthSignupSource(ctx, user.ID, signupSource)
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
// snapshot user × platform quotafail-open
_ = s.snapshotPlatformQuotaDefaults(ctx, user.ID, &grantPlan)
s.bindOAuthAffiliate(ctx, user.ID, affiliateCode)
return nil
}

View File

@ -112,6 +112,7 @@ func newOAuthEmailFlowAuthService(
refreshTokenCache RefreshTokenCache,
settings map[string]string,
emailCache EmailCache,
quotaRepo UserPlatformQuotaRepository, // 新增
) *AuthService {
cfg := &config.Config{
JWT: config.JWTConfig{
@ -142,6 +143,7 @@ func newOAuthEmailFlowAuthService(
nil,
nil,
nil,
quotaRepo, // 替换原来的 nil
)
}
@ -175,6 +177,7 @@ func TestRegisterOAuthEmailAccountRollsBackCreatedUserWhenTokenPairGenerationFai
SettingKeyEmailVerifyEnabled: "true",
},
emailCache,
nil,
)
tokenPair, user, err := authService.RegisterOAuthEmailAccount(
@ -215,6 +218,7 @@ func TestRegisterOAuthEmailAccountSetsNormalizedSignupSourceOnCreatedUser(t *tes
SettingKeyEmailVerifyEnabled: "true",
},
emailCache,
nil,
)
tokenPair, user, err := authService.RegisterOAuthEmailAccount(
@ -274,6 +278,7 @@ func TestRegisterOAuthEmailAccountKeepsGitHubAndGoogleSignupSource(t *testing.T)
SettingKeyEmailVerifyEnabled: "true",
},
emailCache,
nil,
)
tokenPair, user, err := authService.RegisterOAuthEmailAccount(
@ -313,6 +318,7 @@ func TestRegisterOAuthEmailAccountFallsBackUnknownSignupSourceToEmail(t *testing
SettingKeyEmailVerifyEnabled: "true",
},
emailCache,
nil,
)
tokenPair, user, err := authService.RegisterOAuthEmailAccount(
@ -360,6 +366,7 @@ func TestRollbackOAuthEmailAccountCreationRestoresInvitationUsage(t *testing.T)
SettingKeyInvitationCodeEnabled: "true",
},
&emailCacheStub{},
nil,
)
err := authService.RollbackOAuthEmailAccountCreation(context.Background(), 42, "INVITE123")
@ -382,6 +389,7 @@ func TestRollbackOAuthEmailAccountCreationPropagatesDeleteError(t *testing.T) {
SettingKeyRegistrationEnabled: "true",
},
&emailCacheStub{},
nil,
)
err := authService.RollbackOAuthEmailAccountCreation(context.Background(), 42, "")
@ -389,3 +397,54 @@ func TestRollbackOAuthEmailAccountCreationPropagatesDeleteError(t *testing.T) {
require.Error(t, err)
require.Contains(t, err.Error(), "delete created oauth user")
}
func TestFinalizeOAuthEmailAccount_SnapshotsPlatformQuotaDefaults(t *testing.T) {
userRepo := &userRepoStub{nextID: 99}
quotaRepo := &userPlatformQuotaRepoStub{}
authService := newOAuthEmailFlowAuthService(
userRepo,
nil,
&refreshTokenCacheStub{},
map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyEmailVerifyEnabled: "true",
SettingKeyDefaultPlatformQuotas: `{"anthropic": {"daily": 5.5}}`,
},
&emailCacheStub{},
quotaRepo,
)
user := &User{
ID: 99,
Email: "newuser@example.com",
Role: RoleUser,
Status: StatusActive,
SignupSource: "oidc",
}
err := authService.FinalizeOAuthEmailAccount(
context.Background(),
user,
"",
"oidc",
"",
)
require.NoError(t, err)
require.Len(t, quotaRepo.bulkInsertCalls, 1, "snapshotPlatformQuotaDefaults must call BulkInsertInitial once on successful OAuth signup")
records := quotaRepo.bulkInsertCalls[0]
var anthropicRecord *UserPlatformQuotaRecord
for i := range records {
if records[i].Platform == "anthropic" {
anthropicRecord = &records[i]
break
}
}
require.NotNil(t, anthropicRecord, "expected anthropic platform record")
require.Equal(t, int64(99), anthropicRecord.UserID)
require.NotNil(t, anthropicRecord.DailyLimitUSD)
require.InDelta(t, 5.5, *anthropicRecord.DailyLimitUSD, 0.0001)
}

View File

@ -62,18 +62,19 @@ type JWTClaims struct {
// AuthService 认证服务
type AuthService struct {
entClient *dbent.Client
userRepo UserRepository
redeemRepo RedeemCodeRepository
refreshTokenCache RefreshTokenCache
cfg *config.Config
settingService *SettingService
emailService *EmailService
turnstileService *TurnstileService
emailQueueService *EmailQueueService
promoService *PromoService
affiliateService *AffiliateService
defaultSubAssigner DefaultSubscriptionAssigner
entClient *dbent.Client
userRepo UserRepository
redeemRepo RedeemCodeRepository
refreshTokenCache RefreshTokenCache
cfg *config.Config
settingService *SettingService
emailService *EmailService
turnstileService *TurnstileService
emailQueueService *EmailQueueService
promoService *PromoService
affiliateService *AffiliateService
defaultSubAssigner DefaultSubscriptionAssigner
userPlatformQuotaRepo UserPlatformQuotaRepository
}
type DefaultSubscriptionAssigner interface {
@ -81,9 +82,10 @@ type DefaultSubscriptionAssigner interface {
}
type signupGrantPlan struct {
Balance float64
Concurrency int
Subscriptions []DefaultSubscriptionSetting
Balance float64
Concurrency int
Subscriptions []DefaultSubscriptionSetting
PlatformQuotas map[string]*DefaultPlatformQuotaSetting
}
// NewAuthService 创建认证服务实例
@ -100,20 +102,22 @@ func NewAuthService(
promoService *PromoService,
defaultSubAssigner DefaultSubscriptionAssigner,
affiliateService *AffiliateService,
userPlatformQuotaRepo UserPlatformQuotaRepository,
) *AuthService {
return &AuthService{
entClient: entClient,
userRepo: userRepo,
redeemRepo: redeemRepo,
refreshTokenCache: refreshTokenCache,
cfg: cfg,
settingService: settingService,
emailService: emailService,
turnstileService: turnstileService,
emailQueueService: emailQueueService,
promoService: promoService,
affiliateService: affiliateService,
defaultSubAssigner: defaultSubAssigner,
entClient: entClient,
userRepo: userRepo,
redeemRepo: redeemRepo,
refreshTokenCache: refreshTokenCache,
cfg: cfg,
settingService: settingService,
emailService: emailService,
turnstileService: turnstileService,
emailQueueService: emailQueueService,
promoService: promoService,
affiliateService: affiliateService,
defaultSubAssigner: defaultSubAssigner,
userPlatformQuotaRepo: userPlatformQuotaRepo,
}
}
@ -226,6 +230,8 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
}
s.postAuthUserBootstrap(ctx, user, "email", true)
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
// snapshot user × platform quotafail-open
_ = s.snapshotPlatformQuotaDefaults(ctx, user.ID, &grantPlan)
if s.affiliateService != nil {
if _, err := s.affiliateService.EnsureUserAffiliate(ctx, user.ID); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to initialize affiliate profile for user %d: %v", user.ID, err)
@ -535,6 +541,8 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
user = newUser
s.postAuthUserBootstrap(ctx, user, signupSource, false)
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
// snapshot user × platform quotafail-open
_ = s.snapshotPlatformQuotaDefaults(ctx, user.ID, &grantPlan)
}
} else {
logger.LegacyPrintf("service.auth", "[Auth] Database error during oauth login: %v", err)
@ -685,6 +693,8 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
user = newUser
s.postAuthUserBootstrap(ctx, user, signupSource, false)
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
// snapshot user × platform quotafail-open
_ = s.snapshotPlatformQuotaDefaults(ctx, user.ID, &grantPlan)
s.bindOAuthAffiliate(ctx, user.ID, affiliateCode)
}
} else {
@ -703,6 +713,8 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
user = newUser
s.postAuthUserBootstrap(ctx, user, signupSource, false)
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
// snapshot user × platform quotafail-open
_ = s.snapshotPlatformQuotaDefaults(ctx, user.ID, &grantPlan)
s.bindOAuthAffiliate(ctx, user.ID, affiliateCode)
if invitationRedeemCode != nil {
if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
@ -764,18 +776,39 @@ func (s *AuthService) resolveSignupGrantPlan(ctx context.Context, signupSource s
plan.Concurrency = s.settingService.GetDefaultConcurrency(ctx)
plan.Subscriptions = s.settingService.GetDefaultSubscriptions(ctx)
// ============ 全局 quota 装载(必须在 ResolveAuthSourceGrantSettings 之前) ============
// 无论 auth source 是否 enabled全局层都要先装载确保 !enabled 早退路径也携带全局 quota。
if quotas, err := s.settingService.GetDefaultPlatformQuotas(ctx); err == nil {
plan.PlatformQuotas = quotas
} else {
logger.LegacyPrintf("service.auth", "[Auth] Warning: load default platform quotas failed: %v (fail-open)", err)
}
// ============================================================================================
resolved, enabled, err := s.settingService.ResolveAuthSourceGrantSettings(ctx, signupSource, false)
if err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to load auth source signup defaults for %s: %v", signupSource, err)
return plan
}
if !enabled {
return plan
return plan // plan.PlatformQuotas 已含全局层
}
plan.Balance = resolved.Balance
plan.Concurrency = resolved.Concurrency
plan.Subscriptions = resolved.Subscriptions
// ============ auth source quota merge仅在 enabled 分支内) ============
asQuotas := s.settingService.GetAuthSourcePlatformQuotas(ctx, signupSource)
if plan.PlatformQuotas != nil {
for platform, patch := range asQuotas {
if dst := plan.PlatformQuotas[platform]; dst != nil {
mergePlatformQuotaDefaults(dst, patch)
}
}
}
// ==============================================================================
return plan
}
@ -1586,3 +1619,29 @@ func resolvedTokenVersion(user *User) int64 {
fingerprint := int64(binary.BigEndian.Uint64(sum[:8]) & 0x7fffffffffffffff)
return user.TokenVersion ^ fingerprint
}
// snapshotPlatformQuotaDefaults 把 plan.PlatformQuotas4 platform × 3 window
// BulkInsertInitial 形式写入 user_platform_quotas 表。失败 fail-open仅 warn log
func (s *AuthService) snapshotPlatformQuotaDefaults(ctx context.Context, userID int64, plan *signupGrantPlan) error {
if s.userPlatformQuotaRepo == nil || plan == nil || len(plan.PlatformQuotas) == 0 {
return nil
}
records := make([]UserPlatformQuotaRecord, 0, len(plan.PlatformQuotas))
for platform, q := range plan.PlatformQuotas {
rec := UserPlatformQuotaRecord{
UserID: userID,
Platform: platform,
}
if q != nil {
rec.DailyLimitUSD = q.DailyLimitUSD
rec.WeeklyLimitUSD = q.WeeklyLimitUSD
rec.MonthlyLimitUSD = q.MonthlyLimitUSD
}
records = append(records, rec)
}
if err := s.userPlatformQuotaRepo.BulkInsertInitial(ctx, records); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Warning: snapshot platform quota failed user=%d: %v (fail-open)", userID, err)
return nil // fail-open返回 nil让调用方继续
}
return nil
}

View File

@ -110,7 +110,7 @@ CREATE TABLE IF NOT EXISTS user_provider_default_grants (
emailSvc = service.NewEmailService(settingRepo, emailCache)
}
svc := service.NewAuthService(client, repo, nil, refreshTokenCache, cfg, settingSvc, emailSvc, nil, nil, nil, defaultSubAssigner, nil)
svc := service.NewAuthService(client, repo, nil, refreshTokenCache, cfg, settingSvc, emailSvc, nil, nil, nil, defaultSubAssigner, nil, nil)
return svc, repo, client
}
@ -467,7 +467,7 @@ func TestAuthServiceBindEmailIdentity_RevokesExistingAccessAndRefreshTokens(t *t
},
}
emailService := service.NewEmailService(nil, cache)
svc := service.NewAuthService(nil, userRepo, nil, refreshTokenCache, cfg, nil, emailService, nil, nil, nil, nil, nil)
svc := service.NewAuthService(nil, userRepo, nil, refreshTokenCache, cfg, nil, emailService, nil, nil, nil, nil, nil, nil)
oldTokenPair, err := svc.GenerateTokenPair(ctx, &service.User{
ID: 41,
@ -810,8 +810,8 @@ func (s *emailBindUserRepoStub) UpdateUserLastActiveAt(context.Context, int64, t
}
func (s *emailBindUserRepoStub) UpdateBalance(context.Context, int64, float64) error { return nil }
func (s *emailBindUserRepoStub) DeductBalance(context.Context, int64, float64) error { return nil }
func (s *emailBindUserRepoStub) UpdateConcurrency(context.Context, int64, int) error { return nil }
func (s *emailBindUserRepoStub) DeductBalance(context.Context, int64, float64) error { return nil }
func (s *emailBindUserRepoStub) UpdateConcurrency(context.Context, int64, int) error { return nil }
func (s *emailBindUserRepoStub) ExistsByEmail(_ context.Context, email string) (bool, error) {
s.mu.Lock()
@ -820,8 +820,12 @@ func (s *emailBindUserRepoStub) ExistsByEmail(_ context.Context, email string) (
return ok, nil
}
func (s *emailBindUserRepoStub) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
func (s *emailBindUserRepoStub) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
func (s *emailBindUserRepoStub) BatchSetConcurrency(context.Context, []int64, int) (int, error) {
return 0, nil
}
func (s *emailBindUserRepoStub) BatchAddConcurrency(context.Context, []int64, int) (int, error) {
return 0, nil
}
func (s *emailBindUserRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
return 0, nil

View File

@ -137,7 +137,7 @@ CREATE TABLE IF NOT EXISTS user_provider_default_grants (
values: settings,
}, cfg)
svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, nil, nil, nil, nil, defaultSubAssigner, nil)
svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, nil, nil, nil, nil, defaultSubAssigner, nil, nil)
return svc, repo, client
}

View File

@ -0,0 +1,157 @@
//go:build unit
package service
import (
"context"
"fmt"
"testing"
"time"
)
// fakeInsertRecorder 记录 BulkInsertInitial 调用,实现 UserPlatformQuotaRepository port。
type fakeInsertRecorder struct {
records []UserPlatformQuotaRecord
err error
}
func (f *fakeInsertRecorder) GetByUserPlatform(_ context.Context, _ int64, _ string) (*UserPlatformQuotaRecord, error) {
return nil, nil
}
func (f *fakeInsertRecorder) BulkInsertInitial(_ context.Context, recs []UserPlatformQuotaRecord) error {
if f.err != nil {
return f.err
}
f.records = append(f.records, recs...)
return nil
}
func (f *fakeInsertRecorder) IncrementUsageWithReset(_ context.Context, _ int64, _ string, _ float64, _ time.Time) error {
return nil
}
func (f *fakeInsertRecorder) ListByUser(_ context.Context, _ int64) ([]UserPlatformQuotaRecord, error) {
return nil, nil
}
func (f *fakeInsertRecorder) UpsertForUser(_ context.Context, _ int64, _ []UserPlatformQuotaRecord) error {
return nil
}
func (f *fakeInsertRecorder) ResetExpiredWindow(_ context.Context, _ int64, _ string, _ string, _ time.Time) error {
return nil
}
func TestSnapshotPlatformQuotaDefaults_PassesToRepoBulkInsert(t *testing.T) {
fakeRepo := &fakeInsertRecorder{}
s := &AuthService{userPlatformQuotaRepo: fakeRepo}
five := 5.0
plan := &signupGrantPlan{
PlatformQuotas: map[string]*DefaultPlatformQuotaSetting{
"anthropic": {DailyLimitUSD: &five},
"openai": {},
"gemini": {},
"antigravity": {},
},
}
if err := s.snapshotPlatformQuotaDefaults(context.Background(), 999, plan); err != nil {
t.Fatal(err)
}
if len(fakeRepo.records) != 4 {
t.Errorf("expected 4 records, got %d", len(fakeRepo.records))
}
found := false
for _, r := range fakeRepo.records {
if r.UserID == 999 && r.Platform == "anthropic" && r.DailyLimitUSD != nil && *r.DailyLimitUSD == 5 {
found = true
}
}
if !found {
t.Error("anthropic daily = 5 not snapshotted")
}
}
func TestSnapshotPlatformQuotaDefaults_NilPlanIsNoop(t *testing.T) {
fakeRepo := &fakeInsertRecorder{}
s := &AuthService{userPlatformQuotaRepo: fakeRepo}
if err := s.snapshotPlatformQuotaDefaults(context.Background(), 1, nil); err != nil {
t.Errorf("nil plan should be noop, got %v", err)
}
if len(fakeRepo.records) != 0 {
t.Errorf("expected no records, got %d", len(fakeRepo.records))
}
}
func TestSnapshotPlatformQuotaDefaults_RepoErrorFailsOpen(t *testing.T) {
fakeRepo := &fakeInsertRecorder{err: fmt.Errorf("db down")}
s := &AuthService{userPlatformQuotaRepo: fakeRepo}
five := 5.0
plan := &signupGrantPlan{
PlatformQuotas: map[string]*DefaultPlatformQuotaSetting{
"anthropic": {DailyLimitUSD: &five},
},
}
if err := s.snapshotPlatformQuotaDefaults(context.Background(), 1, plan); err != nil {
t.Errorf("fail-open: expected nil even on repo error, got %v", err)
}
}
func TestSnapshotPlatformQuotaDefaults_NilRepoIsNoop(t *testing.T) {
s := &AuthService{userPlatformQuotaRepo: nil}
five := 5.0
plan := &signupGrantPlan{
PlatformQuotas: map[string]*DefaultPlatformQuotaSetting{"a": {DailyLimitUSD: &five}},
}
if err := s.snapshotPlatformQuotaDefaults(context.Background(), 1, plan); err != nil {
t.Errorf("nil repo should be noop, got %v", err)
}
}
// resolveSignupGrantPlan 测试:依赖完整的 AuthService 构造,需要 SettingService含 settingRepoStub
// settingRepoStub 已在 auth_service_register_test.go 中定义,同 package 可直接使用。
func TestResolveSignupGrantPlan_GlobalQuotaLoadedBeforeAuthSource(t *testing.T) {
// 全局 quota JSON key新格式
settings := map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyDefaultPlatformQuotas: `{
"anthropic": {"daily": 10, "weekly": 50, "monthly": 200},
"openai": {"daily": 5, "weekly": 25, "monthly": 100},
"gemini": {"daily": 5, "weekly": 25, "monthly": 100},
"antigravity": {"daily": 5, "weekly": 25, "monthly": 100}
}`,
}
svc := newAuthService(nil, settings, nil, nil)
plan := svc.resolveSignupGrantPlan(context.Background(), "email")
if plan.PlatformQuotas == nil {
t.Fatal("expected PlatformQuotas to be non-nil after loading global quota KVs")
}
q := plan.PlatformQuotas["anthropic"]
if q == nil {
t.Fatal("expected anthropic quota to be set")
}
if q.DailyLimitUSD == nil || *q.DailyLimitUSD != 10 {
t.Errorf("expected anthropic daily=10, got %v", q.DailyLimitUSD)
}
}
// TestResolveSignupGrantPlan_DisabledAuthSourceStillCarriesGlobalQuota 验证 P1 约束:
// !enabled 早退路径仍携带全局 quotaGetDefaultPlatformQuotas 在 ResolveAuthSourceGrantSettings 之前)。
func TestResolveSignupGrantPlan_DisabledAuthSourceStillCarriesGlobalQuota(t *testing.T) {
settings := map[string]string{
SettingKeyRegistrationEnabled: "true",
// auth source 不配置(=> !enabled 路径)
SettingKeyDefaultPlatformQuotas: `{"anthropic": {"daily": 10, "weekly": 50, "monthly": 200}}`,
}
svc := newAuthService(nil, settings, nil, nil)
plan := svc.resolveSignupGrantPlan(context.Background(), "email")
// !enabled 路径plan.PlatformQuotas 应已含全局层(不是 nil
if plan.PlatformQuotas == nil {
t.Fatal("P1 violated: PlatformQuotas is nil even with global quota KVs set")
}
// P1 核心断言disabled auth source 路径不能丢失全局 quota
if _, ok := plan.PlatformQuotas["anthropic"]; !ok {
t.Error("P1 violated: disabled auth source path dropped global platform quota")
}
}

View File

@ -73,6 +73,38 @@ type defaultSubscriptionAssignerStub struct {
type refreshTokenCacheStub struct{}
type userPlatformQuotaRepoStub struct {
bulkInsertCalls [][]UserPlatformQuotaRecord
bulkInsertErr error
}
func (s *userPlatformQuotaRepoStub) BulkInsertInitial(_ context.Context, records []UserPlatformQuotaRecord) error {
cloned := make([]UserPlatformQuotaRecord, len(records))
copy(cloned, records)
s.bulkInsertCalls = append(s.bulkInsertCalls, cloned)
return s.bulkInsertErr
}
func (s *userPlatformQuotaRepoStub) GetByUserPlatform(context.Context, int64, string) (*UserPlatformQuotaRecord, error) {
panic("unexpected GetByUserPlatform call")
}
func (s *userPlatformQuotaRepoStub) ListByUser(context.Context, int64) ([]UserPlatformQuotaRecord, error) {
panic("unexpected ListByUser call")
}
func (s *userPlatformQuotaRepoStub) IncrementUsageWithReset(context.Context, int64, string, float64, time.Time) error {
panic("unexpected IncrementUsageWithReset call")
}
func (s *userPlatformQuotaRepoStub) UpsertForUser(context.Context, int64, []UserPlatformQuotaRecord) error {
panic("unexpected UpsertForUser call")
}
func (s *userPlatformQuotaRepoStub) ResetExpiredWindow(context.Context, int64, string, string, time.Time) error {
panic("unexpected ResetExpiredWindow call")
}
func (s *defaultSubscriptionAssignerStub) AssignOrExtendSubscription(_ context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error) {
if input != nil {
s.calls = append(s.calls, *input)
@ -178,7 +210,7 @@ func (s *emailCacheStub) IncrNotifyCodeUserRate(ctx context.Context, userID int6
return 0, nil
}
func newAuthService(repo *userRepoStub, settings map[string]string, emailCache EmailCache) *AuthService {
func newAuthService(repo *userRepoStub, settings map[string]string, emailCache EmailCache, quotaRepo UserPlatformQuotaRepository) *AuthService {
cfg := &config.Config{
JWT: config.JWTConfig{
Secret: "test-secret",
@ -213,6 +245,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E
nil, // promoService
nil, // defaultSubAssigner
nil, // affiliateService
quotaRepo,
)
}
@ -220,7 +253,7 @@ func TestAuthService_Register_Disabled(t *testing.T) {
repo := &userRepoStub{}
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "false",
}, nil)
}, nil, nil)
_, _, err := service.Register(context.Background(), "user@test.com", "password")
require.ErrorIs(t, err, ErrRegDisabled)
@ -229,19 +262,62 @@ func TestAuthService_Register_Disabled(t *testing.T) {
func TestAuthService_Register_DisabledByDefault(t *testing.T) {
// 当 settings 为 nil设置项不存在注册应该默认关闭
repo := &userRepoStub{}
service := newAuthService(repo, nil, nil)
service := newAuthService(repo, nil, nil, nil)
_, _, err := service.Register(context.Background(), "user@test.com", "password")
require.ErrorIs(t, err, ErrRegDisabled)
}
func TestAuthService_Register_SnapshotsPlatformQuotaDefaults(t *testing.T) {
repo := &userRepoStub{nextID: 77}
quotaRepo := &userPlatformQuotaRepoStub{}
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyDefaultPlatformQuotas: `{"openai": {"weekly": 12.34}}`,
}, nil, quotaRepo)
_, user, err := service.Register(context.Background(), "newuser@test.com", "password")
require.NoError(t, err)
require.NotNil(t, user)
require.Len(t, quotaRepo.bulkInsertCalls, 1)
records := quotaRepo.bulkInsertCalls[0]
var openaiRecord *UserPlatformQuotaRecord
for i := range records {
if records[i].Platform == "openai" {
openaiRecord = &records[i]
break
}
}
require.NotNil(t, openaiRecord, "expected openai platform record")
require.Equal(t, int64(77), openaiRecord.UserID)
require.NotNil(t, openaiRecord.WeeklyLimitUSD)
require.InDelta(t, 12.34, *openaiRecord.WeeklyLimitUSD, 0.0001)
}
func TestAuthService_Register_DoesNotSnapshotOnDisabled(t *testing.T) {
repo := &userRepoStub{}
quotaRepo := &userPlatformQuotaRepoStub{}
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "false",
}, nil, quotaRepo)
_, _, err := service.Register(context.Background(), "user@test.com", "password")
require.ErrorIs(t, err, ErrRegDisabled)
require.Empty(t, quotaRepo.bulkInsertCalls, "registration rejected before user creation must not snapshot")
}
func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testing.T) {
repo := &userRepoStub{}
// 邮件验证开启但 emailCache 为 nilemailService 未配置)
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyEmailVerifyEnabled: "true",
}, nil)
}, nil, nil)
// 应返回服务不可用错误,而不是允许绕过验证
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code", "", "", "")
@ -254,7 +330,7 @@ func TestAuthService_Register_EmailVerifyRequired(t *testing.T) {
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyEmailVerifyEnabled: "true",
}, cache)
}, cache, nil)
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "", "", "", "")
require.ErrorIs(t, err, ErrEmailVerifyRequired)
@ -268,7 +344,7 @@ func TestAuthService_Register_EmailVerifyInvalid(t *testing.T) {
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyEmailVerifyEnabled: "true",
}, cache)
}, cache, nil)
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong", "", "", "")
require.ErrorIs(t, err, ErrInvalidVerifyCode)
@ -279,7 +355,7 @@ func TestAuthService_Register_EmailExists(t *testing.T) {
repo := &userRepoStub{exists: true}
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
}, nil)
}, nil, nil)
_, _, err := service.Register(context.Background(), "user@test.com", "password")
require.ErrorIs(t, err, ErrEmailExists)
@ -289,7 +365,7 @@ func TestAuthService_Register_CheckEmailError(t *testing.T) {
repo := &userRepoStub{existsErr: errors.New("db down")}
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
}, nil)
}, nil, nil)
_, _, err := service.Register(context.Background(), "user@test.com", "password")
require.ErrorIs(t, err, ErrServiceUnavailable)
@ -299,7 +375,7 @@ func TestAuthService_Register_ReservedEmail(t *testing.T) {
repo := &userRepoStub{}
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
}, nil)
}, nil, nil)
_, _, err := service.Register(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "password")
require.ErrorIs(t, err, ErrEmailReserved)
@ -310,7 +386,7 @@ func TestAuthService_Register_EmailSuffixNotAllowed(t *testing.T) {
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyRegistrationEmailSuffixWhitelist: `["@example.com","@company.com"]`,
}, nil)
}, nil, nil)
_, _, err := service.Register(context.Background(), "user@other.com", "password")
require.ErrorIs(t, err, ErrEmailSuffixNotAllowed)
@ -327,7 +403,7 @@ func TestAuthService_Register_EmailSuffixAllowed(t *testing.T) {
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyRegistrationEmailSuffixWhitelist: `["example.com"]`,
}, nil)
}, nil, nil)
_, user, err := service.Register(context.Background(), "user@example.com", "password")
require.NoError(t, err)
@ -340,7 +416,7 @@ func TestAuthService_SendVerifyCode_EmailSuffixNotAllowed(t *testing.T) {
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyRegistrationEmailSuffixWhitelist: `["@example.com","@company.com"]`,
}, nil)
}, nil, nil)
err := service.SendVerifyCode(context.Background(), "user@other.com")
require.ErrorIs(t, err, ErrEmailSuffixNotAllowed)
@ -354,7 +430,7 @@ func TestAuthService_Register_CreateError(t *testing.T) {
repo := &userRepoStub{createErr: errors.New("create failed")}
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
}, nil)
}, nil, nil)
_, _, err := service.Register(context.Background(), "user@test.com", "password")
require.ErrorIs(t, err, ErrServiceUnavailable)
@ -365,7 +441,7 @@ func TestAuthService_Register_CreateEmailExistsRace(t *testing.T) {
repo := &userRepoStub{createErr: ErrEmailExists}
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
}, nil)
}, nil, nil)
_, _, err := service.Register(context.Background(), "user@test.com", "password")
require.ErrorIs(t, err, ErrEmailExists)
@ -376,7 +452,7 @@ func TestAuthService_Register_Success(t *testing.T) {
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false",
}, nil)
}, nil, nil)
token, user, err := service.Register(context.Background(), "user@test.com", "password")
require.NoError(t, err)
@ -394,7 +470,7 @@ func TestAuthService_Register_Success(t *testing.T) {
func TestAuthService_ValidateToken_ExpiredReturnsClaimsWithError(t *testing.T) {
repo := &userRepoStub{}
service := newAuthService(repo, nil, nil)
service := newAuthService(repo, nil, nil, nil)
// 创建用户并生成 token
user := &User{
@ -436,7 +512,7 @@ func TestAuthService_RefreshToken_ExpiredTokenNoPanic(t *testing.T) {
TokenVersion: 1,
}
repo := &userRepoStub{user: user}
service := newAuthService(repo, nil, nil)
service := newAuthService(repo, nil, nil, nil)
// 创建过期 token
service.cfg.JWT.ExpireHour = -1
@ -453,7 +529,7 @@ func TestAuthService_RefreshToken_ExpiredTokenNoPanic(t *testing.T) {
}
func TestAuthService_GetAccessTokenExpiresIn_FallbackToExpireHour(t *testing.T) {
service := newAuthService(&userRepoStub{}, nil, nil)
service := newAuthService(&userRepoStub{}, nil, nil, nil)
service.cfg.JWT.ExpireHour = 24
service.cfg.JWT.AccessTokenExpireMinutes = 0
@ -461,7 +537,7 @@ func TestAuthService_GetAccessTokenExpiresIn_FallbackToExpireHour(t *testing.T)
}
func TestAuthService_GetAccessTokenExpiresIn_MinutesHasPriority(t *testing.T) {
service := newAuthService(&userRepoStub{}, nil, nil)
service := newAuthService(&userRepoStub{}, nil, nil, nil)
service.cfg.JWT.ExpireHour = 24
service.cfg.JWT.AccessTokenExpireMinutes = 90
@ -469,7 +545,7 @@ func TestAuthService_GetAccessTokenExpiresIn_MinutesHasPriority(t *testing.T) {
}
func TestAuthService_GenerateToken_UsesExpireHourWhenMinutesZero(t *testing.T) {
service := newAuthService(&userRepoStub{}, nil, nil)
service := newAuthService(&userRepoStub{}, nil, nil, nil)
service.cfg.JWT.ExpireHour = 24
service.cfg.JWT.AccessTokenExpireMinutes = 0
@ -494,7 +570,7 @@ func TestAuthService_GenerateToken_UsesExpireHourWhenMinutesZero(t *testing.T) {
}
func TestAuthService_GenerateToken_UsesMinutesWhenConfigured(t *testing.T) {
service := newAuthService(&userRepoStub{}, nil, nil)
service := newAuthService(&userRepoStub{}, nil, nil, nil)
service.cfg.JWT.ExpireHour = 24
service.cfg.JWT.AccessTokenExpireMinutes = 90
@ -525,7 +601,7 @@ func TestAuthService_Register_AssignsDefaultSubscriptions(t *testing.T) {
SettingKeyRegistrationEnabled: "true",
SettingKeyDefaultSubscriptions: `[{"group_id":11,"validity_days":30},{"group_id":12,"validity_days":7}]`,
SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false",
}, nil)
}, nil, nil)
service.defaultSubAssigner = assigner
_, user, err := service.Register(context.Background(), "default-sub@test.com", "password")
@ -549,7 +625,7 @@ func TestAuthService_Register_UsesEmailAuthSourceDefaultsWhenGrantEnabled(t *tes
SettingKeyAuthSourceDefaultEmailConcurrency: "7",
SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
SettingKeyAuthSourceDefaultEmailGrantOnSignup: "true",
}, nil)
}, nil, nil)
service.defaultSubAssigner = assigner
_, user, err := service.Register(context.Background(), "email-defaults@test.com", "password")
@ -572,7 +648,7 @@ func TestAuthService_Register_GrantOnSignupFalseFallsBackToGlobalDefaults(t *tes
SettingKeyAuthSourceDefaultEmailConcurrency: "88",
SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":32,"validity_days":9}]`,
SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false",
}, nil)
}, nil, nil)
service.defaultSubAssigner = assigner
_, user, err := service.Register(context.Background(), "email-global@test.com", "password")
@ -595,7 +671,7 @@ func TestAuthService_Register_GrantOnSignupMergesSourceOverridesWithGlobalDefaul
SettingKeyAuthSourceDefaultEmailConcurrency: "5",
SettingKeyAuthSourceDefaultEmailSubscriptions: `[]`,
SettingKeyAuthSourceDefaultEmailGrantOnSignup: "true",
}, nil)
}, nil, nil)
service.defaultSubAssigner = assigner
_, user, err := service.Register(context.Background(), "email-merged@test.com", "password")
@ -618,7 +694,7 @@ func TestAuthService_LoginOrRegisterOAuthWithTokenPair_UsesLinuxDoAuthSourceDefa
SettingKeyAuthSourceDefaultLinuxDoConcurrency: "9",
SettingKeyAuthSourceDefaultLinuxDoSubscriptions: `[{"group_id":22,"validity_days":14}]`,
SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup: "true",
}, nil)
}, nil, nil)
service.defaultSubAssigner = assigner
service.refreshTokenCache = &refreshTokenCacheStub{}
@ -654,7 +730,7 @@ func TestAuthService_LoginOrRegisterOAuthWithTokenPair_ExistingUserDoesNotGrantA
SettingKeyAuthSourceDefaultLinuxDoConcurrency: "9",
SettingKeyAuthSourceDefaultLinuxDoSubscriptions: `[{"group_id":22,"validity_days":14}]`,
SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup: "true",
}, nil)
}, nil, nil)
service.defaultSubAssigner = assigner
service.refreshTokenCache = &refreshTokenCacheStub{}
@ -677,7 +753,7 @@ func newAuthServiceWithDingTalkCfg(settings map[string]string, dtCfg config.Ding
DingTalk: dtCfg,
}
settingService := NewSettingService(&settingRepoStub{values: settings}, cfg)
return NewAuthService(nil, nil, nil, nil, cfg, settingService, nil, nil, nil, nil, nil, nil)
return NewAuthService(nil, nil, nil, nil, cfg, settingService, nil, nil, nil, nil, nil, nil, nil)
}
// minDingTalkURLs 返回一个包含必填字段的基础 DingTalkConnectConfig不设 Enabled/BypassRegistration/Policy

View File

@ -55,6 +55,7 @@ func newAuthServiceForRegisterTurnstileTest(settings map[string]string, verifier
nil, // promoService
nil, // defaultSubAssigner
nil, // affiliateService
nil, // userPlatformQuotaRepo
)
}

View File

@ -11,18 +11,31 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"golang.org/x/sync/singleflight"
)
// 错误定义
// 注ErrInsufficientBalance在redeem_service.go中定义
// 注ErrDailyLimitExceeded/ErrWeeklyLimitExceeded/ErrMonthlyLimitExceeded在subscription_service.go中定义
// errBillingCacheUnavailable 内部哨兵:用于 quota 校验路径在 cache==nil 时
// 与"Redis 故障"走同一条 fail-open + DB 一次性检查的分支。
var errBillingCacheUnavailable = fmt.Errorf("billing cache unavailable")
var (
ErrSubscriptionInvalid = infraerrors.Forbidden("SUBSCRIPTION_INVALID", "subscription is invalid or expired")
ErrBillingServiceUnavailable = infraerrors.ServiceUnavailable("BILLING_SERVICE_ERROR", "Billing service temporarily unavailable. Please retry later.")
// RPM 超限错误。gateway_handler 负责映射为 HTTP 429。
ErrGroupRPMExceeded = infraerrors.TooManyRequests("GROUP_RPM_EXCEEDED", "group requests-per-minute limit exceeded")
ErrUserRPMExceeded = infraerrors.TooManyRequests("USER_RPM_EXCEEDED", "user requests-per-minute limit exceeded")
// user × platform quotaHTTP 429 Too Many Requests + Retry-After header
// 选用 429 而非 403限额耗尽属于"暂时性资源用尽,重试可恢复"的场景RFC 6585
// 大量 SDK如 OpenAI 兼容客户端)只对 429 触发自动退避并读取 Retry-After
// 用 403 会被视为"权限不足,重试无意义"导致客户端直接报错且不退避。
ErrUserPlatformDailyQuotaExhausted = infraerrors.TooManyRequests("USER_PLATFORM_DAILY_QUOTA_EXHAUSTED", "Daily usage quota exhausted for this platform.")
ErrUserPlatformWeeklyQuotaExhausted = infraerrors.TooManyRequests("USER_PLATFORM_WEEKLY_QUOTA_EXHAUSTED", "Weekly usage quota exhausted for this platform.")
ErrUserPlatformMonthlyQuotaExhausted = infraerrors.TooManyRequests("USER_PLATFORM_MONTHLY_QUOTA_EXHAUSTED", "Monthly usage quota exhausted for this platform.")
)
// subscriptionCacheData 订阅缓存数据结构(内部使用)
@ -94,6 +107,7 @@ type BillingCacheService struct {
userGroupRateRepo UserGroupRateRepository
cfg *config.Config
circuitBreaker *billingCircuitBreaker
userPlatformQuotaRepo UserPlatformQuotaRepository
cacheWriteChan chan cacheWriteTask
cacheWriteWg sync.WaitGroup
@ -101,6 +115,7 @@ type BillingCacheService struct {
cacheWriteMu sync.RWMutex
stopped atomic.Bool
balanceLoadSF singleflight.Group
quotaLoadSF singleflight.Group
// 丢弃日志节流计数器(减少高负载下日志噪音)
cacheWriteDropFullCount uint64
cacheWriteDropFullLastLog int64
@ -117,6 +132,7 @@ func NewBillingCacheService(
userRPMCache UserRPMCache,
userGroupRateRepo UserGroupRateRepository,
cfg *config.Config,
userPlatformQuotaRepo UserPlatformQuotaRepository,
) *BillingCacheService {
svc := &BillingCacheService{
cache: cache,
@ -126,6 +142,7 @@ func NewBillingCacheService(
userRPMCache: userRPMCache,
userGroupRateRepo: userGroupRateRepo,
cfg: cfg,
userPlatformQuotaRepo: userPlatformQuotaRepo,
}
svc.circuitBreaker = newBillingCircuitBreaker(cfg.Billing.CircuitBreaker)
svc.startCacheWriteWorkers()
@ -655,6 +672,30 @@ func (s *BillingCacheService) QueueUpdateAPIKeyRateLimitUsage(apiKeyID int64, co
})
}
// IncrementUserPlatformQuotaUsage 同步累加 user × platform usage 到 Redis 缓存。
//
// 设计:同步写入而非异步入队。同步写确保下次 preflight 立即看到最新 usage
// 把 TOCTOU 超支窗口限制在并发 in-flight 请求数量内(而非随时间无限累积)。
// 写延迟通常 < 1ms本地 Redis换取 quota 视图实时性的取舍合理。
//
// Redis 写失败用 ALERT 级 logDB 持久化由 caller 单独 goroutine 兜底gateway_service.go
func (s *BillingCacheService) IncrementUserPlatformQuotaUsage(userID int64, platform string, cost float64) {
if s.cache == nil {
return
}
if platform == "" || cost <= 0 {
return
}
ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout)
defer cancel()
ttl := time.Duration(s.cfg.Billing.UserPlatformQuotaCacheTTLSeconds) * time.Second
if err := s.cache.IncrUserPlatformQuotaUsageCache(ctx, userID, platform, cost, ttl); err != nil {
logger.LegacyPrintf("service.billing_cache",
"ALERT: incr user platform quota cache failed user=%d platform=%s cost=%f: %v",
userID, platform, cost, err)
}
}
// ============================================
// 统一检查方法
// ============================================
@ -662,7 +703,8 @@ func (s *BillingCacheService) QueueUpdateAPIKeyRateLimitUsage(apiKeyID int64, co
// CheckBillingEligibility 检查用户是否有资格发起请求
// 余额模式:检查缓存余额 > 0
// 订阅模式检查缓存用量未超过限额Group限额从参数传入
func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *User, apiKey *APIKey, group *Group, subscription *UserSubscription) error {
// platform 为请求的目标平台(如 "anthropic"),传空串 "" 时跳过 user × platform quota 检查。
func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *User, apiKey *APIKey, group *Group, subscription *UserSubscription, platform string) error {
// 简易模式:跳过所有计费检查
if s.cfg.RunMode == config.RunModeSimple {
return nil
@ -684,6 +726,13 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user
}
}
// user × platform quota 仅在 standard余额模式生效订阅模式豁免
if !isSubscriptionMode {
if err := s.checkUserPlatformQuotaEligibility(ctx, user.ID, platform); err != nil {
return err
}
}
// Check API Key rate limits (applies to both billing modes)
if apiKey != nil && apiKey.HasRateLimits() {
if err := s.checkAPIKeyRateLimits(ctx, apiKey); err != nil {
@ -975,3 +1024,257 @@ func circuitStateString(state billingCircuitBreakerState) string {
return "unknown"
}
}
// checkUserPlatformQuotaEligibility 在 standard 模式下检查 user × platform 日/周/月 quota。
// 返回 nil = 允许;返回 ErrUserPlatform{Daily/Weekly/Monthly}QuotaExhausted = 拒绝(带 window_resets_at metadata
// checkUserPlatformQuotaEligibility 检查用户在指定平台的 USD 配额。
//
// 流程Redis-first / DB-fallback
// 1. 先读 Redis cache若命中且 SchemaVersion==1直接用 entry 中的 limits 和 window_start 做校验,
// 免除 DB 查询。
// 2. cache MISS 或旧版 entrySchemaVersion==0→ 查 DB 回填完整 entry含 limits/window_start
// 3. Redis 故障err != nil→ fail-open查 DB 做一次性检查,不回填。
func (s *BillingCacheService) checkUserPlatformQuotaEligibility(
ctx context.Context,
userID int64,
platform string,
) error {
if platform == "" || s.userPlatformQuotaRepo == nil {
return nil
}
// cache 未配置(如简化部署 / 单测路径)→ 直接走 DB 查询,避免 nil panic。
// 其他 check* 方法balance/subscription/rate-limit也有类似守卫。
var (
entry *UserPlatformQuotaCacheEntry
ok bool
cacheErr error
)
if s.cache != nil {
entry, ok, cacheErr = s.cache.GetUserPlatformQuotaCache(ctx, userID, platform)
} else {
// 标记为"cache 故障"分支:跳过 HIT 路径、不回填、走 DB 一次性检查
cacheErr = errBillingCacheUnavailable
}
// --- cache HIT with current schema → 直接用 entry不查 DB ---
if cacheErr == nil && ok && entry != nil && entry.SchemaVersion == UserPlatformQuotaCacheSchemaV1 {
now := time.Now()
dailyUsage := entry.DailyUsageUSD
weeklyUsage := entry.WeeklyUsageUSD
monthlyUsage := entry.MonthlyUsageUSD
// 若窗口已更新DB 已重置但 cache 尚未失效),将对应 usage 清零再做比较,
// 同时记录新窗口起点用于后续刷新 cache entry。
// 本次请求用本地清零值继续判断;DB 层 IncrementUsageWithReset 已有窗口自愈能力,
// 持久化数据始终正确。
windowExpired := false
newDailyStart := entry.DailyWindowStart
newWeeklyStart := entry.WeeklyWindowStart
newMonthlyStart := entry.MonthlyWindowStart
if quotaWindowExpired(entry.DailyWindowStart, timezone.StartOfDay(now)) {
dailyUsage = 0
windowExpired = true
dayStart := timezone.StartOfDay(now)
newDailyStart = &dayStart
}
if quotaWindowExpired(entry.WeeklyWindowStart, timezone.StartOfWeek(now)) {
weeklyUsage = 0
windowExpired = true
weekStart := timezone.StartOfWeek(now)
newWeeklyStart = &weekStart
}
if monthlyQuotaWindowExpired(entry.MonthlyWindowStart, now) {
monthlyUsage = 0
windowExpired = true
monthStart := now
newMonthlyStart = &monthStart
}
// 检测到任意窗口过期:用 reset 后的 entry 覆盖 Redis而非 Delete
// 旧实现 Delete 后,期间到达的 IncrUserPlatformQuotaUsage 调用让 Lua 看到
// EXISTS=0 直接 return 0,并发请求的 cost 永久丢失,直到下次 cache MISS 回填。
// 改为 SetCache 原子覆盖:key 不断链,Lua INCR 可在新窗口 entry 上正确累加。
// 超时 50ms:覆盖正常路径与可接受抖动;Redis 异常时 hot path 不阻塞超过此值。
// 用 context.Background()+短超时,避免请求 ctx 取消导致刷新丢失。
// 显式 setCancel()(而非 defer):缩短 context 生命周期,避免 defer 延迟到函数返回。
if windowExpired && s.cache != nil {
refreshed := &UserPlatformQuotaCacheEntry{
DailyUsageUSD: dailyUsage,
WeeklyUsageUSD: weeklyUsage,
MonthlyUsageUSD: monthlyUsage,
SchemaVersion: UserPlatformQuotaCacheSchemaV1,
DailyLimitUSD: entry.DailyLimitUSD,
WeeklyLimitUSD: entry.WeeklyLimitUSD,
MonthlyLimitUSD: entry.MonthlyLimitUSD,
DailyWindowStart: newDailyStart,
WeeklyWindowStart: newWeeklyStart,
MonthlyWindowStart: newMonthlyStart,
}
ttl := time.Duration(s.cfg.Billing.UserPlatformQuotaCacheTTLSeconds) * time.Second
setCtx, setCancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
if setErr := s.cache.SetUserPlatformQuotaCache(setCtx, userID, platform, refreshed, ttl); setErr != nil {
logger.LegacyPrintf("service.billing_cache",
"Warning: refresh expired user platform quota cache failed user=%d platform=%s: %v",
userID, platform, setErr)
}
setCancel()
}
if entry.DailyLimitUSD != nil && dailyUsage >= *entry.DailyLimitUSD {
return withWindowResetsMetadata(ErrUserPlatformDailyQuotaExhausted, nextDailyReset(now))
}
if entry.WeeklyLimitUSD != nil && weeklyUsage >= *entry.WeeklyLimitUSD {
return withWindowResetsMetadata(ErrUserPlatformWeeklyQuotaExhausted, nextWeeklyReset(now))
}
if entry.MonthlyLimitUSD != nil && monthlyUsage >= *entry.MonthlyLimitUSD {
return withWindowResetsMetadata(ErrUserPlatformMonthlyQuotaExhausted, nextMonthlyResetFrom(entry.MonthlyWindowStart, now))
}
return nil
}
// --- cache MISS、旧版 entry 或 Redis 故障 → 查 DBsingleflight 合并并发回源)---
// 使用 DoChan 而非 Doavoid sharing the first caller's ctx among all dedupe followers.
// 若第一个 caller 的 ctx 被取消(客户端断连),后续 caller 不受影响,仍由各自 ctx 控制超时。
sfKey := strconv.FormatInt(userID, 10) + ":" + platform
ch := s.quotaLoadSF.DoChan(sfKey, func() (any, error) {
// 子查询用 detached context + 短超时,独立于任何 caller 的请求 ctx
// 防止"第一个 caller ctx 取消"使所有 follower 一起 fail。
bgCtx, bgCancel := context.WithTimeout(context.Background(), 3*time.Second)
defer bgCancel()
return s.userPlatformQuotaRepo.GetByUserPlatform(bgCtx, userID, platform)
})
var (
v any
dbErr error
)
select {
case res := <-ch:
v, dbErr = res.Val, res.Err
case <-ctx.Done():
// 当前 caller 的 ctx 被取消fail-open不阻断 (此请求已无意义)。
logger.LegacyPrintf("service.billing_cache", "Warning: user platform quota check ctx cancelled user=%d platform=%s: %v (fail-open)", userID, platform, ctx.Err())
return nil
}
if dbErr != nil {
logger.LegacyPrintf("service.billing_cache", "Warning: load user platform quota failed user=%d platform=%s: %v (fail-open)", userID, platform, dbErr)
return nil
}
rec, _ := v.(*UserPlatformQuotaRecord)
if rec == nil {
return nil
}
now := time.Now()
dailyUsage := rec.DailyUsageUSD
weeklyUsage := rec.WeeklyUsageUSD
monthlyUsage := rec.MonthlyUsageUSD
if quotaWindowExpired(rec.DailyWindowStart, timezone.StartOfDay(now)) {
dailyUsage = 0
}
if quotaWindowExpired(rec.WeeklyWindowStart, timezone.StartOfWeek(now)) {
weeklyUsage = 0
}
if monthlyQuotaWindowExpired(rec.MonthlyWindowStart, now) {
monthlyUsage = 0
}
// Redis 故障时 fail-open不回填直接用 DB 数据做一次性检查
if cacheErr != nil {
if rec.DailyLimitUSD != nil && dailyUsage >= *rec.DailyLimitUSD {
return withWindowResetsMetadata(ErrUserPlatformDailyQuotaExhausted, nextDailyReset(now))
}
if rec.WeeklyLimitUSD != nil && weeklyUsage >= *rec.WeeklyLimitUSD {
return withWindowResetsMetadata(ErrUserPlatformWeeklyQuotaExhausted, nextWeeklyReset(now))
}
if rec.MonthlyLimitUSD != nil && monthlyUsage >= *rec.MonthlyLimitUSD {
return withWindowResetsMetadata(ErrUserPlatformMonthlyQuotaExhausted, nextMonthlyResetFrom(rec.MonthlyWindowStart, now))
}
return nil
}
// cache MISS 或旧版 entry → 回填完整 entry含 limits 和 window_start
newEntry := &UserPlatformQuotaCacheEntry{
DailyUsageUSD: dailyUsage,
WeeklyUsageUSD: weeklyUsage,
MonthlyUsageUSD: monthlyUsage,
SchemaVersion: UserPlatformQuotaCacheSchemaV1,
DailyLimitUSD: rec.DailyLimitUSD,
WeeklyLimitUSD: rec.WeeklyLimitUSD,
MonthlyLimitUSD: rec.MonthlyLimitUSD,
DailyWindowStart: rec.DailyWindowStart,
WeeklyWindowStart: rec.WeeklyWindowStart,
MonthlyWindowStart: rec.MonthlyWindowStart,
}
if s.cache != nil {
ttl := time.Duration(s.cfg.Billing.UserPlatformQuotaCacheTTLSeconds) * time.Second
// 与 HIT 过期回填路径(上文 SetCache 调用)保持一致:用 context.Background()+50ms,
// 避免请求 ctx 提前取消(客户端断连/上游超时)导致 cache 回填失败,
// 让下一次 preflight 仍然 MISS 并击穿到 DB高并发下增大 DB 压力)。
setCtx, setCancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
if setErr := s.cache.SetUserPlatformQuotaCache(setCtx, userID, platform, newEntry, ttl); setErr != nil {
logger.LegacyPrintf("service.billing_cache", "Warning: set user platform quota cache failed user=%d platform=%s: %v", userID, platform, setErr)
}
setCancel()
}
if rec.DailyLimitUSD != nil && dailyUsage >= *rec.DailyLimitUSD {
return withWindowResetsMetadata(ErrUserPlatformDailyQuotaExhausted, nextDailyReset(now))
}
if rec.WeeklyLimitUSD != nil && weeklyUsage >= *rec.WeeklyLimitUSD {
return withWindowResetsMetadata(ErrUserPlatformWeeklyQuotaExhausted, nextWeeklyReset(now))
}
if rec.MonthlyLimitUSD != nil && monthlyUsage >= *rec.MonthlyLimitUSD {
return withWindowResetsMetadata(ErrUserPlatformMonthlyQuotaExhausted, nextMonthlyResetFrom(rec.MonthlyWindowStart, now))
}
return nil
}
// withWindowResetsMetadata 给 quota error 附加 window_resets_at metadataRFC3339
func withWindowResetsMetadata(err error, resetAt time.Time) error {
appErr, ok := err.(*infraerrors.ApplicationError)
if !ok || appErr == nil {
return err
}
return appErr.WithMetadata(map[string]string{
"window_resets_at": resetAt.Format(time.RFC3339),
})
}
// nextDailyReset 计算下一个日窗口起点(次日全局时区 0 点)。
// 必须与 timezone.StartOfDay 同口径,否则 Retry-After 会偏差。
func nextDailyReset(now time.Time) time.Time {
return timezone.StartOfDay(now).AddDate(0, 0, 1)
}
// nextWeeklyReset 计算下一个周窗口起点(下周一全局时区 0 点)。
// 必须与 timezone.StartOfWeek 同口径,否则 Retry-After 会偏差。
func nextWeeklyReset(now time.Time) time.Time {
return timezone.StartOfWeek(now).AddDate(0, 0, 7)
}
// nextMonthlyResetFrom 返回 30 天滚动窗口的下次重置时间start + 30d
// start 为 nil未初始化或已过期now-start >= 30d与 monthlyQuotaWindowExpired 同口径)时
// 退化为 now+30d过期窗口会在下次 increment 时重置为 now下次重置即 now+30d
// 否则按 start 计算会得到一个过去的时间,使 Retry-After 落回 fallback 并触发客户端紧凑重试。
func nextMonthlyResetFrom(start *time.Time, now time.Time) time.Time {
if start == nil || now.Sub(*start) >= 30*24*time.Hour {
return now.Add(30 * 24 * time.Hour)
}
return start.Add(30 * 24 * time.Hour)
}
// quotaWindowExpired 判断窗口是否已过期start 为 nil未初始化或在 currWindowStart 之前视为已过期。
func quotaWindowExpired(start *time.Time, currWindowStart time.Time) bool {
if start == nil {
return true
}
return start.Before(currWindowStart)
}
// monthlyQuotaWindowExpired 判断 30 天滚动月度窗口是否已过期。
// 过期条件now - start >= 30×24h与订阅模式 NeedsMonthlyReset 语义一致)。
// start 为 nil 时视为已过期(未初始化窗口)。
func monthlyQuotaWindowExpired(start *time.Time, now time.Time) bool {
if start == nil {
return true
}
return now.Sub(*start) >= 30*24*time.Hour
}

View File

@ -74,7 +74,7 @@ func newBillingServiceForRPM(t *testing.T, cache UserRPMCache, rateRepo UserGrou
t.Helper()
// 用 nil BillingCache 走 "无缓存" 分支,避免 CheckBillingEligibility 副作用。
// 我们只直接测 checkRPM。
svc := NewBillingCacheService(nil, nil, nil, nil, cache, rateRepo, &config.Config{})
svc := NewBillingCacheService(nil, nil, nil, nil, cache, rateRepo, &config.Config{}, nil)
t.Cleanup(svc.Stop)
return svc
}

View File

@ -67,6 +67,22 @@ func (s *billingCacheMissStub) InvalidateAPIKeyRateLimit(ctx context.Context, ke
return nil
}
func (s *billingCacheMissStub) GetUserPlatformQuotaCache(ctx context.Context, userID int64, platform string) (*UserPlatformQuotaCacheEntry, bool, error) {
return nil, false, nil
}
func (s *billingCacheMissStub) SetUserPlatformQuotaCache(ctx context.Context, userID int64, platform string, entry *UserPlatformQuotaCacheEntry, ttl time.Duration) error {
return nil
}
func (s *billingCacheMissStub) DeleteUserPlatformQuotaCache(ctx context.Context, userID int64, platform string) error {
return nil
}
func (s *billingCacheMissStub) IncrUserPlatformQuotaUsageCache(ctx context.Context, userID int64, platform string, cost float64, ttl time.Duration) error {
return nil
}
type balanceLoadUserRepoStub struct {
mockUserRepo
calls atomic.Int64
@ -100,7 +116,7 @@ func TestBillingCacheServiceGetUserBalance_Singleflight(t *testing.T) {
delay: 80 * time.Millisecond,
balance: 12.34,
}
svc := NewBillingCacheService(cache, userRepo, nil, nil, nil, nil, &config.Config{})
svc := NewBillingCacheService(cache, userRepo, nil, nil, nil, nil, &config.Config{}, nil)
t.Cleanup(svc.Stop)
const goroutines = 16

View File

@ -68,9 +68,25 @@ func (b *billingCacheWorkerStub) InvalidateAPIKeyRateLimit(ctx context.Context,
return nil
}
func (b *billingCacheWorkerStub) GetUserPlatformQuotaCache(ctx context.Context, userID int64, platform string) (*UserPlatformQuotaCacheEntry, bool, error) {
return nil, false, nil
}
func (b *billingCacheWorkerStub) SetUserPlatformQuotaCache(ctx context.Context, userID int64, platform string, entry *UserPlatformQuotaCacheEntry, ttl time.Duration) error {
return nil
}
func (b *billingCacheWorkerStub) DeleteUserPlatformQuotaCache(ctx context.Context, userID int64, platform string) error {
return nil
}
func (b *billingCacheWorkerStub) IncrUserPlatformQuotaUsageCache(ctx context.Context, userID int64, platform string, cost float64, ttl time.Duration) error {
return nil
}
func TestBillingCacheServiceQueueHighLoad(t *testing.T) {
cache := &billingCacheWorkerStub{}
svc := NewBillingCacheService(cache, nil, nil, nil, nil, nil, &config.Config{})
svc := NewBillingCacheService(cache, nil, nil, nil, nil, nil, &config.Config{}, nil)
t.Cleanup(svc.Stop)
start := time.Now()
@ -92,7 +108,7 @@ func TestBillingCacheServiceQueueHighLoad(t *testing.T) {
func TestBillingCacheServiceEnqueueAfterStopReturnsFalse(t *testing.T) {
cache := &billingCacheWorkerStub{}
svc := NewBillingCacheService(cache, nil, nil, nil, nil, nil, &config.Config{})
svc := NewBillingCacheService(cache, nil, nil, nil, nil, nil, &config.Config{}, nil)
svc.Stop()
enqueued := svc.enqueueCacheWrite(cacheWriteTask{

View File

@ -0,0 +1,595 @@
//go:build unit
package service
import (
"context"
"errors"
"sync"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
)
// fakeIncrCache 仅记录 IncrUserPlatformQuotaUsageCache 被调用的参数。
type fakeIncrCache struct {
BillingCache
calls []incrCall
}
type incrCall struct {
userID int64
platform string
cost float64
ttl time.Duration
}
func (f *fakeIncrCache) IncrUserPlatformQuotaUsageCache(ctx context.Context, userID int64, platform string, cost float64, ttl time.Duration) error {
f.calls = append(f.calls, incrCall{userID, platform, cost, ttl})
return nil
}
// IncrementUserPlatformQuotaUsage 已改为同步直写,不再走 worker。
// 测试验证:同步调用立即调到 cache.IncrUserPlatformQuotaUsageCache。
func TestIncrementUserPlatformQuotaUsage_SyncCallsCache(t *testing.T) {
fake := &fakeIncrCache{}
cfg := &config.Config{}
cfg.Billing.UserPlatformQuotaCacheTTLSeconds = 120
s := &BillingCacheService{
cache: fake,
cfg: cfg,
}
s.IncrementUserPlatformQuotaUsage(101, "anthropic", 0.25)
s.IncrementUserPlatformQuotaUsage(101, "openai", 0.50)
if len(fake.calls) != 2 {
t.Fatalf("expected 2 incr calls, got %d", len(fake.calls))
}
if fake.calls[0] != (incrCall{101, "anthropic", 0.25, 120 * time.Second}) {
t.Errorf("call[0] = %+v", fake.calls[0])
}
if fake.calls[1] != (incrCall{101, "openai", 0.50, 120 * time.Second}) {
t.Errorf("call[1] = %+v", fake.calls[1])
}
}
// ── T6 tests: checkUserPlatformQuotaEligibility ──────────────────────────────
// fakeQuotaRepo 实现 UserPlatformQuotaRepository 最小子集
type fakeQuotaRepo struct {
rec *UserPlatformQuotaRecord
}
func (f *fakeQuotaRepo) GetByUserPlatform(_ context.Context, _ int64, _ string) (*UserPlatformQuotaRecord, error) {
return f.rec, nil
}
func (f *fakeQuotaRepo) BulkInsertInitial(_ context.Context, _ []UserPlatformQuotaRecord) error {
return nil
}
func (f *fakeQuotaRepo) IncrementUsageWithReset(_ context.Context, _ int64, _ string, _ float64, _ time.Time) error {
return nil
}
func (f *fakeQuotaRepo) ListByUser(_ context.Context, _ int64) ([]UserPlatformQuotaRecord, error) {
return nil, nil
}
func (f *fakeQuotaRepo) UpsertForUser(_ context.Context, _ int64, _ []UserPlatformQuotaRecord) error {
return nil
}
func (f *fakeQuotaRepo) ResetExpiredWindow(_ context.Context, _ int64, _ string, _ string, _ time.Time) error {
return nil
}
// fakeFullCache 同时支持 Get + Set + Incr + Delete。
// mu 保护 entry 和 deleteCalls防止异步 goroutine 与主 goroutine 之间的 data race。
type fakeFullCache struct {
BillingCache
mu sync.Mutex
entry *UserPlatformQuotaCacheEntry
deleteCalls int
}
// getDeleteCalls 线程安全地读取 deleteCalls。
func (f *fakeFullCache) getDeleteCalls() int {
f.mu.Lock()
defer f.mu.Unlock()
return f.deleteCalls
}
// getEntry 线程安全地读取 entry。
func (f *fakeFullCache) getEntry() *UserPlatformQuotaCacheEntry {
f.mu.Lock()
defer f.mu.Unlock()
return f.entry
}
func (f *fakeFullCache) GetUserPlatformQuotaCache(_ context.Context, _ int64, _ string) (*UserPlatformQuotaCacheEntry, bool, error) {
f.mu.Lock()
defer f.mu.Unlock()
if f.entry == nil {
return nil, false, nil
}
return f.entry, true, nil
}
func (f *fakeFullCache) SetUserPlatformQuotaCache(_ context.Context, _ int64, _ string, e *UserPlatformQuotaCacheEntry, _ time.Duration) error {
f.mu.Lock()
defer f.mu.Unlock()
f.entry = e
return nil
}
func (f *fakeFullCache) DeleteUserPlatformQuotaCache(_ context.Context, _ int64, _ string) error {
f.mu.Lock()
defer f.mu.Unlock()
f.deleteCalls++
f.entry = nil
return nil
}
func newServiceForPreflight(t *testing.T, repo UserPlatformQuotaRepository, cache BillingCache) *BillingCacheService {
t.Helper()
cfg := &config.Config{}
cfg.Billing.UserPlatformQuotaCacheTTLSeconds = 60
return &BillingCacheService{
cache: cache,
cfg: cfg,
userPlatformQuotaRepo: repo,
}
}
// currentDayStart 返回全局时区当天 0 点(与生产 timezone.StartOfDay 同口径,确保窗口有效)。
func currentDayStart() *time.Time {
s := timezone.StartOfDay(time.Now())
return &s
}
func TestCheckUserPlatformQuotaEligibility_AllowsWhenUnderLimit(t *testing.T) {
daily := 10.0
repo := &fakeQuotaRepo{rec: &UserPlatformQuotaRecord{
UserID: 1, Platform: "anthropic", DailyLimitUSD: &daily,
}}
cache := &fakeFullCache{entry: &UserPlatformQuotaCacheEntry{
DailyUsageUSD: 4.5,
DailyLimitUSD: &daily,
DailyWindowStart: currentDayStart(),
SchemaVersion: UserPlatformQuotaCacheSchemaV1,
}}
s := newServiceForPreflight(t, repo, cache)
if err := s.checkUserPlatformQuotaEligibility(context.Background(), 1, "anthropic"); err != nil {
t.Errorf("expected nil, got %v", err)
}
}
func TestCheckUserPlatformQuotaEligibility_DailyExhausted(t *testing.T) {
daily := 5.0
repo := &fakeQuotaRepo{rec: &UserPlatformQuotaRecord{
UserID: 1, Platform: "anthropic", DailyLimitUSD: &daily,
}}
cache := &fakeFullCache{entry: &UserPlatformQuotaCacheEntry{
DailyUsageUSD: 5.0,
DailyLimitUSD: &daily,
DailyWindowStart: currentDayStart(),
SchemaVersion: UserPlatformQuotaCacheSchemaV1,
}}
s := newServiceForPreflight(t, repo, cache)
err := s.checkUserPlatformQuotaEligibility(context.Background(), 1, "anthropic")
if !errors.Is(err, ErrUserPlatformDailyQuotaExhausted) {
t.Errorf("expected ErrUserPlatformDailyQuotaExhausted, got %v", err)
}
}
func TestCheckUserPlatformQuotaEligibility_NilLimitMeansUnlimited(t *testing.T) {
repo := &fakeQuotaRepo{rec: &UserPlatformQuotaRecord{
UserID: 1, Platform: "anthropic",
}}
cache := &fakeFullCache{entry: &UserPlatformQuotaCacheEntry{
DailyUsageUSD: 999,
DailyWindowStart: currentDayStart(),
SchemaVersion: UserPlatformQuotaCacheSchemaV1,
// DailyLimitUSD nil → 无限额
}}
s := newServiceForPreflight(t, repo, cache)
if err := s.checkUserPlatformQuotaEligibility(context.Background(), 1, "anthropic"); err != nil {
t.Errorf("nil limits should be unlimited, got %v", err)
}
}
func TestCheckUserPlatformQuotaEligibility_ZeroLimitImmediateBlock(t *testing.T) {
zero := 0.0
repo := &fakeQuotaRepo{rec: &UserPlatformQuotaRecord{
UserID: 1, Platform: "anthropic", DailyLimitUSD: &zero,
}}
cache := &fakeFullCache{entry: &UserPlatformQuotaCacheEntry{
DailyUsageUSD: 0,
DailyLimitUSD: &zero,
DailyWindowStart: currentDayStart(),
SchemaVersion: UserPlatformQuotaCacheSchemaV1,
}}
s := newServiceForPreflight(t, repo, cache)
err := s.checkUserPlatformQuotaEligibility(context.Background(), 1, "anthropic")
if !errors.Is(err, ErrUserPlatformDailyQuotaExhausted) {
t.Errorf("expected daily exhausted for limit=0, got %v", err)
}
}
func TestCheckUserPlatformQuotaEligibility_NoRecordMeansUnlimited(t *testing.T) {
repo := &fakeQuotaRepo{rec: nil}
cache := &fakeFullCache{}
s := newServiceForPreflight(t, repo, cache)
if err := s.checkUserPlatformQuotaEligibility(context.Background(), 1, "anthropic"); err != nil {
t.Errorf("no record = unlimited, got %v", err)
}
}
// TestCheckUserPlatformQuotaEligibility_OldSchemaCacheMissTriggersDB 验证旧版 entrySchemaVersion=0
// 触发 DB 回退路径,并在 DB 数据判断配额是否超限。
// DB record 需设置有效的 window_start否则 quotaWindowExpired 会将 usage 归零nil 窗口视为已过期)。
func TestCheckUserPlatformQuotaEligibility_OldSchemaCacheMissTriggersDB(t *testing.T) {
daily := 5.0
dayStart := currentDayStart()
repo := &fakeQuotaRepo{rec: &UserPlatformQuotaRecord{
UserID: 1, Platform: "anthropic", DailyLimitUSD: &daily, DailyUsageUSD: 6.0,
DailyWindowStart: dayStart,
}}
// SchemaVersion=0旧 entry应走 DB 路径
cache := &fakeFullCache{entry: &UserPlatformQuotaCacheEntry{DailyUsageUSD: 1.0}}
s := newServiceForPreflight(t, repo, cache)
err := s.checkUserPlatformQuotaEligibility(context.Background(), 1, "anthropic")
if !errors.Is(err, ErrUserPlatformDailyQuotaExhausted) {
t.Errorf("旧版 entry 应走 DB 路径并报 daily exhausted, got %v", err)
}
}
// TestCheckUserPlatformQuotaEligibility_WindowExpiredInCache 验证 cache HIT 时若窗口已过期usage 归零,用户放行。
func TestCheckUserPlatformQuotaEligibility_WindowExpiredInCache(t *testing.T) {
daily := 5.0
past := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) // 远古窗口起始,肯定已过期
repo := &fakeQuotaRepo{rec: &UserPlatformQuotaRecord{
UserID: 1, Platform: "anthropic", DailyLimitUSD: &daily,
}}
cache := &fakeFullCache{entry: &UserPlatformQuotaCacheEntry{
DailyUsageUSD: 10.0, // 超限,但窗口已过期
DailyLimitUSD: &daily,
DailyWindowStart: &past,
SchemaVersion: UserPlatformQuotaCacheSchemaV1,
}}
s := newServiceForPreflight(t, repo, cache)
err := s.checkUserPlatformQuotaEligibility(context.Background(), 1, "anthropic")
if err != nil {
t.Errorf("过期窗口应归零放行, got %v", err)
}
}
// TestCheckUserPlatformQuotaEligibility_WindowExpiredRefreshesCache 验证:
// V1 HIT 路径检测到窗口过期时,用 reset 后的 entry 同步覆盖 Redis(而非 Delete):
// 1. 当前请求以本地清零值判断 → 放行
// 2. cache entry 被替换为新 entry: usage 清零 + window_start 更新到当前窗口
// limit 保留;这样并发 IncrUserPlatformQuotaUsage 的 Lua INCR 可正确累加到新窗口。
func TestCheckUserPlatformQuotaEligibility_WindowExpiredRefreshesCache(t *testing.T) {
daily := 5.0
// 远古窗口起始,确保 quotaWindowExpired 返回 true
past := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)
repo := &fakeQuotaRepo{rec: &UserPlatformQuotaRecord{
UserID: 1, Platform: "anthropic", DailyLimitUSD: &daily,
}}
cache := &fakeFullCache{entry: &UserPlatformQuotaCacheEntry{
DailyUsageUSD: 10.0, // 超限,但窗口已过期 → 应被本地清零后放行
DailyLimitUSD: &daily,
DailyWindowStart: &past,
SchemaVersion: UserPlatformQuotaCacheSchemaV1,
}}
s := newServiceForPreflight(t, repo, cache)
// 本次 check 应放行(本地清零后 usage=0 < limit=5)
err := s.checkUserPlatformQuotaEligibility(context.Background(), 1, "anthropic")
if err != nil {
t.Errorf("过期窗口应归零放行, got %v", err)
}
// 验证 cache entry 已被刷新:usage 清零、limit 保留、window_start 更新到当前窗口
refreshed := cache.getEntry()
if refreshed == nil {
t.Fatal("窗口过期后 cache entry 不应为 nil(应被 SetCache 覆盖,而非 Delete)")
}
if refreshed.DailyUsageUSD != 0 {
t.Errorf("刷新后 DailyUsageUSD = %v, want 0", refreshed.DailyUsageUSD)
}
if refreshed.DailyLimitUSD == nil || *refreshed.DailyLimitUSD != daily {
t.Errorf("刷新后 DailyLimitUSD = %v, want %v(保留)", refreshed.DailyLimitUSD, daily)
}
if refreshed.SchemaVersion != UserPlatformQuotaCacheSchemaV1 {
t.Errorf("刷新后 SchemaVersion = %d, want V1", refreshed.SchemaVersion)
}
if refreshed.DailyWindowStart == nil || refreshed.DailyWindowStart.Equal(past) {
t.Errorf("刷新后 DailyWindowStart = %v, 应更新到当前窗口而非保留 past=%v", refreshed.DailyWindowStart, past)
}
}
// ── T5 tests: QueueUpdateUserPlatformQuotaUsage ───────────────────────────────
// ── C-NEW-1: monthlyQuotaWindowExpired 30 天滚动测试 ─────────────────────────
func TestMonthlyQuotaWindowExpired_NilStart(t *testing.T) {
if !monthlyQuotaWindowExpired(nil, time.Now().UTC()) {
t.Error("nil start should be considered expired")
}
}
func TestMonthlyQuotaWindowExpired_Expired(t *testing.T) {
start := time.Now().UTC().Add(-30 * 24 * time.Hour)
if !monthlyQuotaWindowExpired(&start, time.Now().UTC()) {
t.Error("start exactly 30 days ago should be expired")
}
}
func TestMonthlyQuotaWindowExpired_Active(t *testing.T) {
start := time.Now().UTC().Add(-29 * 24 * time.Hour)
if monthlyQuotaWindowExpired(&start, time.Now().UTC()) {
t.Error("start 29 days ago should NOT be expired")
}
}
// TestMonthlyQuotaWindowExpired_CrossMonthBoundary 验证跨自然月时 30 天未满不视为过期。
func TestMonthlyQuotaWindowExpired_CrossMonthBoundary(t *testing.T) {
// 窗口起始 4 月 20 日5 月 1 日只过了 11 天,不足 30 天
start := time.Date(2026, 4, 20, 0, 0, 0, 0, time.UTC)
now := time.Date(2026, 5, 1, 0, 0, 0, 0, time.UTC)
if monthlyQuotaWindowExpired(&start, now) {
t.Error("11 days into window should NOT be expired (30-day rolling, not calendar month)")
}
}
// TestNextMonthlyResetFrom 验证 30 天滚动重置时间计算。
func TestNextMonthlyResetFrom_WithStart(t *testing.T) {
start := time.Date(2026, 5, 1, 10, 0, 0, 0, time.UTC)
want := start.Add(30 * 24 * time.Hour)
now := time.Date(2026, 5, 22, 0, 0, 0, 0, time.UTC)
got := nextMonthlyResetFrom(&start, now)
if !got.Equal(want) {
t.Errorf("nextMonthlyResetFrom = %v, want %v", got, want)
}
}
func TestNextMonthlyResetFrom_NilStart(t *testing.T) {
now := time.Date(2026, 5, 22, 0, 0, 0, 0, time.UTC)
got := nextMonthlyResetFrom(nil, now)
want := now.Add(30 * 24 * time.Hour)
if !got.Equal(want) {
t.Errorf("nextMonthlyResetFrom(nil) = %v, want now+30d=%v", got, want)
}
}
func TestNextMonthlyResetFrom_NilStart_NotEqualToNow(t *testing.T) {
now := time.Date(2026, 5, 22, 12, 0, 0, 0, time.UTC)
got := nextMonthlyResetFrom(nil, now)
want := now.Add(30 * 24 * time.Hour)
if !got.Equal(want) {
t.Errorf("nextMonthlyResetFrom(nil) = %v, want %v (now+30d)", got, want)
}
if got.Equal(now) {
t.Error("nextMonthlyResetFrom(nil) must not return now (should be now+30d)")
}
}
// TestNextMonthlyResetFrom_ExpiredStart 验证窗口已过期now-start >= 30d
// 下次重置时间为 now+30d而非 start+30d后者已是过去时间会让 Retry-After 落回
// fallback 并触发客户端紧凑重试)。
func TestNextMonthlyResetFrom_ExpiredStart(t *testing.T) {
start := time.Date(2026, 3, 1, 0, 0, 0, 0, time.UTC)
now := time.Date(2026, 5, 1, 0, 0, 0, 0, time.UTC) // 距 start 61 天,已过期
got := nextMonthlyResetFrom(&start, now)
want := now.Add(30 * 24 * time.Hour)
if !got.Equal(want) {
t.Errorf("nextMonthlyResetFrom(expired) = %v, want now+30d=%v", got, want)
}
if !got.After(now) {
t.Error("expired window 的下次重置必须在 now 之后,不能是过去时间")
}
}
func TestIncrementUserPlatformQuotaUsage_GuardsAgainstEmpty(t *testing.T) {
fake := &fakeIncrCache{}
cfg := &config.Config{}
cfg.Billing.UserPlatformQuotaCacheTTLSeconds = 60
s := &BillingCacheService{
cache: fake,
cfg: cfg,
}
s.IncrementUserPlatformQuotaUsage(1, "", 0.5) // empty platform → noop
s.IncrementUserPlatformQuotaUsage(1, "openai", 0) // zero cost → noop
s.IncrementUserPlatformQuotaUsage(1, "openai", -0.1) // negative → noop
if len(fake.calls) != 0 {
t.Errorf("expected 0 calls (all guarded), got %d", len(fake.calls))
}
}
// ── C-NEW-2: 订阅模式豁免 user×platform quota 检查 ──────────────────────────
// 通过直接调用 checkUserPlatformQuotaEligibility 验证:
// 1. standard 模式下 limit=0 → 拦截
// 2. 订阅模式豁免通过 isSubscriptionMode 守卫体现 — 逻辑已在 CheckBillingEligibility 里加 !isSubscriptionMode 条件
// 此处用单元测试直接验证底层 checkUserPlatformQuotaEligibility 的行为quota 超限确实拦截),
// 而 subscription bypass 逻辑则在 CheckBillingEligibility 中通过条件判断保证,不绕过 sub eligibility 内部复杂依赖。
// fakeZeroQuotaCache 模拟 cache 命中且 daily limit=0quota 耗尽)。
type fakeZeroQuotaCache struct {
BillingCache
called bool
}
func (f *fakeZeroQuotaCache) GetUserPlatformQuotaCache(_ context.Context, _ int64, _ string) (*UserPlatformQuotaCacheEntry, bool, error) {
f.called = true
daily := 0.0
entry := &UserPlatformQuotaCacheEntry{
DailyUsageUSD: 0,
DailyLimitUSD: &daily,
DailyWindowStart: func() *time.Time { t := time.Now().UTC(); return &t }(),
SchemaVersion: UserPlatformQuotaCacheSchemaV1,
}
return entry, true, nil
}
func (f *fakeZeroQuotaCache) DeleteUserPlatformQuotaCache(_ context.Context, _ int64, _ string) error {
return nil
}
// SetUserPlatformQuotaCache 在 weekly/monthly window_start 为 nil 时,checkUserPlatform...
// 会触发"窗口过期 → SetCache 刷新"分支。fake 用 noop 避免 panic。
func (f *fakeZeroQuotaCache) SetUserPlatformQuotaCache(_ context.Context, _ int64, _ string, _ *UserPlatformQuotaCacheEntry, _ time.Duration) error {
return nil
}
// GetSubscriptionCache 返回有效订阅active、未过期、usage 远低于 limit
// 用于支持 checkSubscriptionEligibility 通过,以便验证 quota 检查不被触发。
func (f *fakeZeroQuotaCache) GetSubscriptionCache(_ context.Context, _ int64, _ int64) (*SubscriptionCacheData, error) {
return &SubscriptionCacheData{
Status: SubscriptionStatusActive,
ExpiresAt: time.Now().Add(30 * 24 * time.Hour),
DailyUsage: 0,
WeeklyUsage: 0,
MonthlyUsage: 0,
}, nil
}
func (f *fakeZeroQuotaCache) GetUserBalanceCache(_ context.Context, _ int64) (float64, bool, error) {
return 100.0, true, nil
}
// TestCheckUserPlatformQuotaEligibility_StandardMode_BlocksWhenLimitZero 验证:
// standard 模式下 limit=0 的 platform quota 确实会被拦截(守卫底层逻辑正确)。
func TestCheckUserPlatformQuotaEligibility_StandardMode_BlocksWhenLimitZero(t *testing.T) {
fake := &fakeZeroQuotaCache{}
cfg := &config.Config{}
cfg.Billing.UserPlatformQuotaCacheTTLSeconds = 60
s := &BillingCacheService{
cache: fake,
cfg: cfg,
userPlatformQuotaRepo: &fakeQuotaRepo{},
}
err := s.checkUserPlatformQuotaEligibility(context.Background(), 1, "anthropic")
if !errors.Is(err, ErrUserPlatformDailyQuotaExhausted) {
t.Errorf("standard mode with limit=0 should return ErrUserPlatformDailyQuotaExhausted, got: %v", err)
}
if !fake.called {
t.Error("GetUserPlatformQuotaCache should have been called in standard mode")
}
}
// TestCheckBillingEligibility_SubscriptionMode_BypassesPlatformQuota 验证C-NEW-2
// 订阅模式用户不受 user×platform quota 拦截GetUserPlatformQuotaCache 不应被调用。
func TestCheckBillingEligibility_SubscriptionMode_BypassesPlatformQuota(t *testing.T) {
fake := &fakeZeroQuotaCache{} // GetUserPlatformQuotaCache 返回 limit=0若被调用则拦截
cfg := &config.Config{}
cfg.Billing.UserPlatformQuotaCacheTTLSeconds = 60
s := &BillingCacheService{
cache: fake,
cfg: cfg,
userPlatformQuotaRepo: &fakeQuotaRepo{},
}
subGroup := &Group{
ID: 10,
SubscriptionType: "subscription",
Status: "active",
// 无 DailyLimitUSD → checkSubscriptionEligibility 不会因超限失败
}
sub := &UserSubscription{Status: "active"}
user := &User{ID: 42}
err := s.CheckBillingEligibility(context.Background(), user, nil, subGroup, sub, "anthropic")
// 订阅模式下不应收到任何 user×platform quota 错误
if errors.Is(err, ErrUserPlatformDailyQuotaExhausted) ||
errors.Is(err, ErrUserPlatformWeeklyQuotaExhausted) ||
errors.Is(err, ErrUserPlatformMonthlyQuotaExhausted) {
t.Errorf("subscription mode should bypass user×platform quota, got: %v", err)
}
// GetUserPlatformQuotaCache 不应被调用
if fake.called {
t.Error("GetUserPlatformQuotaCache must NOT be called in subscription mode (C-NEW-2)")
}
}
// TestCheckBillingEligibility_NonSubscriptionGroup_AppliesQuota 验证:
// 非订阅模式group=nil用户 platform quota 超限时被拦截quota cache 被查询。
func TestCheckBillingEligibility_NonSubscriptionGroup_AppliesQuota(t *testing.T) {
called := &fakeZeroQuotaCache{}
cfg := &config.Config{}
cfg.Billing.UserPlatformQuotaCacheTTLSeconds = 60
s := &BillingCacheService{
cache: called,
cfg: cfg,
userPlatformQuotaRepo: &fakeQuotaRepo{},
}
err := s.checkUserPlatformQuotaEligibility(context.Background(), 99, "openai")
if !errors.Is(err, ErrUserPlatformDailyQuotaExhausted) {
t.Errorf("non-subscription mode quota check should block, got: %v", err)
}
if !called.called {
t.Error("GetUserPlatformQuotaCache should be consulted in non-subscription mode")
}
}
// ── B-3: monthlyQuotaWindowExpired 30 天边界表驱动测试 ────────────────────────
// 覆盖 4 个必须场景:
// 1. 恰好 30 天 → expired
// 2. 30*24h - 1ns → not expired
// 3. 跨月末2024-02-28 → 2024-03-29T00:00:01Z→ expired
// 4. 跨年2024-12-15 → 2025-01-14T00:00:01Z→ expired
//
// repo 层 monthlyMaybeReset 不可导出,通过 service 层 monthlyQuotaWindowExpired 间接覆盖。
func TestMonthlyQuotaWindowExpired_BoundaryTable(t *testing.T) {
const thirtyDays = 30 * 24 * time.Hour
cases := []struct {
name string
start time.Time
now time.Time
expired bool
}{
{
name: "exactly 30 days → expired",
start: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
now: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC).Add(thirtyDays),
expired: true,
},
{
name: "30d minus 1ns → not expired",
start: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
now: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC).Add(thirtyDays - 1),
expired: false,
},
{
name: "cross month-end (Feb→Mar, 29d+1s) → expired",
start: time.Date(2024, 2, 28, 0, 0, 0, 0, time.UTC),
now: time.Date(2024, 3, 29, 0, 0, 1, 0, time.UTC),
expired: true,
},
{
name: "cross year boundary (Dec→Jan, 30d+1s) → expired",
start: time.Date(2024, 12, 15, 0, 0, 0, 0, time.UTC),
now: time.Date(2025, 1, 14, 0, 0, 1, 0, time.UTC),
expired: true,
},
}
for _, tc := range cases {
tc := tc // capture range variable
t.Run(tc.name, func(t *testing.T) {
got := monthlyQuotaWindowExpired(&tc.start, tc.now)
if got != tc.expired {
t.Errorf("monthlyQuotaWindowExpired(start=%v, now=%v) = %v, want %v",
tc.start, tc.now, got, tc.expired)
}
})
}
}

View File

@ -6,6 +6,7 @@ import (
"fmt"
"log"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
)
@ -20,6 +21,32 @@ type APIKeyRateLimitCacheData struct {
Window7d int64 `json:"window_7d"`
}
// UserPlatformQuotaCacheEntry Redis hash 反序列化结果。
//
// SchemaVersion 用于向后兼容:
// - 0旧 entry无 SchemaVersion 字段)→ 视为 cache MISS强制 refresh
// - 1当前版本→ 包含 limits 和 window_start可免 DB 查询
//
// limit 字段为 nil 表示"无限额"DB 中对应列为 NULL
const UserPlatformQuotaCacheSchemaV1 = int64(1)
type UserPlatformQuotaCacheEntry struct {
DailyUsageUSD float64
WeeklyUsageUSD float64
MonthlyUsageUSD float64
Version int64
SchemaVersion int64
// 以下字段仅在 SchemaVersion >= 1 时有效
DailyLimitUSD *float64
WeeklyLimitUSD *float64
MonthlyLimitUSD *float64
DailyWindowStart *time.Time
WeeklyWindowStart *time.Time
MonthlyWindowStart *time.Time
}
// BillingCache defines cache operations for billing service
type BillingCache interface {
// Balance operations
@ -39,6 +66,13 @@ type BillingCache interface {
SetAPIKeyRateLimit(ctx context.Context, keyID int64, data *APIKeyRateLimitCacheData) error
UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error
InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error
// user × platform quota 缓存
GetUserPlatformQuotaCache(ctx context.Context, userID int64, platform string) (*UserPlatformQuotaCacheEntry, bool, error)
SetUserPlatformQuotaCache(ctx context.Context, userID int64, platform string, entry *UserPlatformQuotaCacheEntry, ttl time.Duration) error
DeleteUserPlatformQuotaCache(ctx context.Context, userID int64, platform string) error
// IncrUserPlatformQuotaUsageCache 在缓存命中时累加用量缓存未命中key 不存在)静默返回 nil。
IncrUserPlatformQuotaUsageCache(ctx context.Context, userID int64, platform string, cost float64, ttl time.Duration) error
}
// ModelPricing 模型价格配置per-token价格与LiteLLM格式一致

View File

@ -1,6 +1,10 @@
package service
import "github.com/Wei-Shaw/sub2api/internal/domain"
import (
"fmt"
"github.com/Wei-Shaw/sub2api/internal/domain"
)
// Status constants
const (
@ -39,6 +43,26 @@ const (
PlatformAntigravity = domain.PlatformAntigravity
)
// AllowedQuotaPlatforms 是允许设置 user × platform quota 的平台列表(单一权威来源)。
// ent/schema/user_platform_quota.go 的 Validate 函数独立维护(构建期约束),
// 若新增平台需同步修改该 schema。
var AllowedQuotaPlatforms = []string{
PlatformAnthropic,
PlatformOpenAI,
PlatformGemini,
PlatformAntigravity,
}
// IsAllowedQuotaPlatform 报告 s 是否为合法的 quota platform 标识。
func IsAllowedQuotaPlatform(s string) bool {
for _, p := range AllowedQuotaPlatforms {
if p == s {
return true
}
}
return false
}
// Account type constants
const (
AccountTypeOAuth = domain.AccountTypeOAuth // OAuth类型账号full scope: profile + inference
@ -424,5 +448,15 @@ const (
SettingKeyWebSearchEmulationConfig = "web_search_emulation_config" // JSON 配置
)
// SettingKeyDefaultPlatformQuotas —— 系统全局:每用户 × 平台日/周/月 USD 上限JSON
// 值为 map[platform]{daily,weekly,monthly}null/缺省 = 不限制0 = 禁用;>0 = USD 上限。
const SettingKeyDefaultPlatformQuotas = "default_platform_quotas"
// SettingKeyAuthSourcePlatformQuotas 返回某 auth source 的 platform quota JSON key。
// 形如 auth_source_default_{source}_platform_quotas
func SettingKeyAuthSourcePlatformQuotas(source string) string {
return fmt.Sprintf("auth_source_default_%s_platform_quotas", source)
}
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
const AdminAPIKeyPrefix = "admin-"

View File

@ -0,0 +1,23 @@
//go:build unit
package service
import "testing"
// TestSettingKeyDefaultPlatformQuotas 验证新的系统层 JSON key 常量值正确。
func TestSettingKeyDefaultPlatformQuotas(t *testing.T) {
if SettingKeyDefaultPlatformQuotas != "default_platform_quotas" {
t.Errorf("SettingKeyDefaultPlatformQuotas = %q, want %q",
SettingKeyDefaultPlatformQuotas, "default_platform_quotas")
}
}
// TestSettingKeyAuthSourcePlatformQuotas 验证新的 auth-source JSON key 函数返回值正确。
func TestSettingKeyAuthSourcePlatformQuotas(t *testing.T) {
if got := SettingKeyAuthSourcePlatformQuotas("email"); got != "auth_source_default_email_platform_quotas" {
t.Fatalf("got %q, want %q", got, "auth_source_default_email_platform_quotas")
}
if got := SettingKeyAuthSourcePlatformQuotas("dingtalk"); got != "auth_source_default_dingtalk_platform_quotas" {
t.Fatalf("got %q, want %q", got, "auth_source_default_dingtalk_platform_quotas")
}
}

View File

@ -44,6 +44,7 @@ func newGatewayRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo
nil,
nil,
nil,
nil, // userPlatformQuotaRepo
)
}

View File

@ -95,6 +95,16 @@ var (
modelsListCacheHitTotal atomic.Int64
modelsListCacheMissTotal atomic.Int64
modelsListCacheStoreTotal atomic.Int64
// userPlatformQuotaDBIncrErrorTotal 统计 finalizePostUsageBilling 异步 goroutine
// 中 IncrementUsageWithReset 失败次数。Redis 已成功累加 + DB 写失败意味着
// Redis cache TTL 过期或被清后该笔 cost 会丢失(与实际消费偏差)。
// oncall 通过 GatewayUserPlatformQuotaIncrStats() 暴露给 ops 面板做阈值告警。
userPlatformQuotaDBIncrErrorTotal atomic.Int64
// userPlatformQuotaDBIncrLegacyErrorTotal 统计 legacy postUsageBilling
// applyUsageBilling 在 repo==nil 时 fallback路径下的失败次数
// 与 DB Incr 失败分开计数,便于区分"主路径暂时故障"vs"基础设施长期未配齐"。
userPlatformQuotaDBIncrLegacyErrorTotal atomic.Int64
)
func GatewayWindowCostPrefetchStats() (cacheHit, cacheMiss, batchSQL, fallback, errCount int64) {
@ -117,6 +127,15 @@ func GatewayModelsListCacheStats() (cacheHit, cacheMiss, store int64) {
return modelsListCacheHitTotal.Load(), modelsListCacheMissTotal.Load(), modelsListCacheStoreTotal.Load()
}
// GatewayUserPlatformQuotaIncrStats 返回 (mainPathErr, legacyPathErr)。
// mainPathErrfinalizePostUsageBilling 异步 goroutine 写 DB 失败累计次数;
// legacyPathErrpostUsageBilling fallback 路径写 DB 失败累计次数。
// ops 监控面板可以按"持续上升斜率"做告警阈值。
func GatewayUserPlatformQuotaIncrStats() (mainPathErr, legacyPathErr int64) {
return userPlatformQuotaDBIncrErrorTotal.Load(),
userPlatformQuotaDBIncrLegacyErrorTotal.Load()
}
func openAIStreamEventIsTerminal(data string) bool {
trimmed := strings.TrimSpace(data)
if trimmed == "" {
@ -575,6 +594,7 @@ type GatewayService struct {
debugGatewayBodyFile atomic.Pointer[os.File] // non-nil when SUB2API_DEBUG_GATEWAY_BODY is set
tlsFPProfileService *TLSFingerprintProfileService
balanceNotifyService *BalanceNotifyService
userPlatformQuotaRepo UserPlatformQuotaRepository
}
// NewGatewayService creates a new GatewayService
@ -605,41 +625,43 @@ func NewGatewayService(
channelService *ChannelService,
resolver *ModelPricingResolver,
balanceNotifyService *BalanceNotifyService,
userPlatformQuotaRepo UserPlatformQuotaRepository,
) *GatewayService {
userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg)
modelsListTTL := resolveModelsListCacheTTL(cfg)
svc := &GatewayService{
accountRepo: accountRepo,
groupRepo: groupRepo,
usageLogRepo: usageLogRepo,
usageBillingRepo: usageBillingRepo,
userRepo: userRepo,
userSubRepo: userSubRepo,
userGroupRateRepo: userGroupRateRepo,
cache: cache,
digestStore: digestStore,
cfg: cfg,
schedulerSnapshot: schedulerSnapshot,
concurrencyService: concurrencyService,
billingService: billingService,
rateLimitService: rateLimitService,
billingCacheService: billingCacheService,
identityService: identityService,
httpUpstream: httpUpstream,
deferredService: deferredService,
claudeTokenProvider: claudeTokenProvider,
sessionLimitCache: sessionLimitCache,
rpmCache: rpmCache,
userGroupRateCache: gocache.New(userGroupRateTTL, time.Minute),
settingService: settingService,
modelsListCache: gocache.New(modelsListTTL, time.Minute),
modelsListCacheTTL: modelsListTTL,
responseHeaderFilter: compileResponseHeaderFilter(cfg),
tlsFPProfileService: tlsFPProfileService,
channelService: channelService,
resolver: resolver,
balanceNotifyService: balanceNotifyService,
accountRepo: accountRepo,
groupRepo: groupRepo,
usageLogRepo: usageLogRepo,
usageBillingRepo: usageBillingRepo,
userRepo: userRepo,
userSubRepo: userSubRepo,
userGroupRateRepo: userGroupRateRepo,
cache: cache,
digestStore: digestStore,
cfg: cfg,
schedulerSnapshot: schedulerSnapshot,
concurrencyService: concurrencyService,
billingService: billingService,
rateLimitService: rateLimitService,
billingCacheService: billingCacheService,
identityService: identityService,
httpUpstream: httpUpstream,
deferredService: deferredService,
claudeTokenProvider: claudeTokenProvider,
sessionLimitCache: sessionLimitCache,
rpmCache: rpmCache,
userGroupRateCache: gocache.New(userGroupRateTTL, time.Minute),
settingService: settingService,
modelsListCache: gocache.New(modelsListTTL, time.Minute),
modelsListCacheTTL: modelsListTTL,
responseHeaderFilter: compileResponseHeaderFilter(cfg),
tlsFPProfileService: tlsFPProfileService,
channelService: channelService,
resolver: resolver,
balanceNotifyService: balanceNotifyService,
userPlatformQuotaRepo: userPlatformQuotaRepo,
}
svc.userGroupRateResolver = newUserGroupRateResolver(
userGroupRateRepo,
@ -7949,6 +7971,7 @@ type RecordUsageInput struct {
RequestPayloadHash string // 请求体语义哈希,用于降低 request_id 误复用时的静默误去重风险
ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换)
APIKeyService APIKeyQuotaUpdater // 可选用于更新API Key配额
QuotaPlatform string // user×platform 配额计量平台handler 在请求 ctx 内经 QuotaPlatform() 算定后传入(后扣运行在 worker 池 background ctx 上,取不到 ForcePlatform
ChannelUsageFields // 渠道映射信息(由 handler 在 Forward 前解析)
}
@ -7978,6 +8001,31 @@ type postUsageBillingParams struct {
IsSubscriptionBill bool
AccountRateMultiplier float64
APIKeyService APIKeyQuotaUpdater
Platform string // 来自 APIKey 关联 Group 的平台标识
}
// PlatformFromAPIKey 从 APIKey 关联的 Group 推导 platform 名称。
// apiKey 为 nil 或 Group 信息缺失时返回空串(调用方据此 short-circuit quota 累加)。
// 导出供 handler 层调用。
func PlatformFromAPIKey(apiKey *APIKey) string {
if apiKey == nil || apiKey.Group == nil {
return ""
}
return apiKey.Group.Platform
}
// QuotaPlatform 返回 user×platform 配额计量使用的平台标识。
// 强制平台路由(如 /antigravity优先按 ctx 中的 ForcePlatform 计量,否则回退到
// APIKey 关联 Group 的平台。
//
// 注意:必须用带 ForcePlatform 的请求 context 调用(如 handler 的 c.Request.Context())。
// 后扣运行在 worker 池的 background ctx 上没有 ForcePlatform因此后扣平台由 handler
// 预先算定、经 RecordUsageInput.QuotaPlatform 传入,不要在后扣链路用 worker ctx 调用本函数。
func QuotaPlatform(ctx context.Context, apiKey *APIKey) string {
if fp, ok := ctx.Value(ctxkey.ForcePlatform).(string); ok && fp != "" {
return fp
}
return PlatformFromAPIKey(apiKey)
}
func (p *postUsageBillingParams) shouldDeductAPIKeyQuota() bool {
@ -8036,6 +8084,21 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill
}
}
// Platform quota DB-only 累加(与 finalizePostUsageBilling 行为对齐的兜底):
// - 仅对 standard余额模式生效订阅模式豁免
// - 直接走 DB不经 Redis Incr 队列legacy 路径在 repo==nil仓库未注入
// 时被触发,此时整套 billing repo 都不可用,没有"双队列"风险
// - 失败仅记 ALERT log + counter不阻断主扣费流程与正常路径一致
//
// 历史背景:原 legacy path 完全跳过此累加,导致部署中如果 repo 偶然为 nil
// 时用户消费可绕过 platform quota存在静默资金风险。
if !p.IsSubscriptionBill && p.Platform != "" && cost.ActualCost > 0 && p.User != nil && deps.userPlatformQuotaRepo != nil {
if err := deps.userPlatformQuotaRepo.IncrementUsageWithReset(billingCtx, p.User.ID, p.Platform, cost.ActualCost, time.Now().UTC()); err != nil {
userPlatformQuotaDBIncrLegacyErrorTotal.Add(1)
logger.LegacyPrintf("service.gateway", "ALERT: legacy incr user platform quota DB failed user=%d platform=%s cost=%f: %v", p.User.ID, p.Platform, cost.ActualCost, err)
}
}
// NOTE: finalizePostUsageBilling is NOT called here to avoid double-queuing
// cache updates. The legacy path does DB writes directly; the finalize path
// does cache queue + notifications. Notifications are dispatched separately
@ -8159,11 +8222,11 @@ func applyUsageBilling(ctx context.Context, requestID string, usageLog *UsageLog
}
}
finalizePostUsageBilling(p, deps, result)
finalizePostUsageBilling(billingCtx, p, deps, result)
return true, nil
}
func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps, result *UsageBillingApplyResult) {
func finalizePostUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *billingDeps, result *UsageBillingApplyResult) {
if p == nil || p.Cost == nil || deps == nil {
return
}
@ -8182,6 +8245,32 @@ func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps, resu
deps.deferredService.ScheduleLastUsedUpdate(p.Account.ID)
// Platform quota 累加:仅在 standard余额模式生效订阅模式豁免
// Redis 同步写 + DB 异步持久化:
// - Redis 同步:确保下次 preflight 立即看到最新 usage,把 TOCTOU 超支窗口
// 限制在并发 in-flight 请求数量内(旧实现的异步入队会让超支无限累积直到 worker 处理)
// - DB 异步:在独立 goroutine 中走 detached context,失败用 ALERT log 触发 oncall 对账
if !p.IsSubscriptionBill && p.Platform != "" && p.Cost.ActualCost > 0 && p.User != nil && deps.userPlatformQuotaRepo != nil {
deps.billingCacheService.IncrementUserPlatformQuotaUsage(p.User.ID, p.Platform, p.Cost.ActualCost)
dbCtx, dbCancel := detachUpstreamContext(ctx)
userID, platform, cost := p.User.ID, p.Platform, p.Cost.ActualCost
go func() {
defer func() {
if r := recover(); r != nil {
logger.LegacyPrintf("service.gateway", "ALERT: panic in user platform quota incr goroutine user=%d platform=%s: %v", userID, platform, r)
}
}()
defer dbCancel()
if err := deps.userPlatformQuotaRepo.IncrementUsageWithReset(dbCtx, userID, platform, cost, time.Now().UTC()); err != nil {
// 失败计数器:暴露给 GatewayUserPlatformQuotaIncrStats(),由 ops 面板做斜率告警。
userPlatformQuotaDBIncrErrorTotal.Add(1)
// ALERT 级别:DB 持久化失败意味着 Redis cache 失效后该笔 cost 永久丢失,
// 用户配额视图与实际消费会偏差,oncall 需要据此对账或人工补录。
logger.LegacyPrintf("service.gateway", "ALERT: incr user platform quota DB failed user=%d platform=%s cost=%f: %v", userID, platform, cost, err)
}
}()
}
// Notification checks run async — all parameters are already captured,
// no dependency on the request context or upstream connection.
go notifyBalanceLow(p, deps, result)
@ -8287,22 +8376,24 @@ func detachUpstreamContext(ctx context.Context) (context.Context, context.Cancel
// billingDeps 扣费逻辑依赖的服务(由各 gateway service 提供)
type billingDeps struct {
accountRepo AccountRepository
userRepo UserRepository
userSubRepo UserSubscriptionRepository
billingCacheService *BillingCacheService
deferredService *DeferredService
balanceNotifyService *BalanceNotifyService
accountRepo AccountRepository
userRepo UserRepository
userSubRepo UserSubscriptionRepository
billingCacheService *BillingCacheService
deferredService *DeferredService
balanceNotifyService *BalanceNotifyService
userPlatformQuotaRepo UserPlatformQuotaRepository
}
func (s *GatewayService) billingDeps() *billingDeps {
return &billingDeps{
accountRepo: s.accountRepo,
userRepo: s.userRepo,
userSubRepo: s.userSubRepo,
billingCacheService: s.billingCacheService,
deferredService: s.deferredService,
balanceNotifyService: s.balanceNotifyService,
accountRepo: s.accountRepo,
userRepo: s.userRepo,
userSubRepo: s.userSubRepo,
billingCacheService: s.billingCacheService,
deferredService: s.deferredService,
balanceNotifyService: s.balanceNotifyService,
userPlatformQuotaRepo: s.userPlatformQuotaRepo,
}
}
@ -8360,6 +8451,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
RequestPayloadHash: input.RequestPayloadHash,
ForceCacheBilling: input.ForceCacheBilling,
APIKeyService: input.APIKeyService,
QuotaPlatform: input.QuotaPlatform,
ChannelUsageFields: input.ChannelUsageFields,
}, &recordUsageOpts{
EnableClaudePath: true,
@ -8382,6 +8474,7 @@ type RecordUsageLongContextInput struct {
LongContextMultiplier float64 // 超出阈值部分的倍率(如 2.0
ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换)
APIKeyService APIKeyQuotaUpdater // API Key 配额服务(可选)
QuotaPlatform string // user×platform 配额计量平台handler 在请求 ctx 内经 QuotaPlatform() 算定后传入(后扣运行在 worker 池 background ctx 上,取不到 ForcePlatform
ChannelUsageFields // 渠道映射信息(由 handler 在 Forward 前解析)
}
@ -8401,6 +8494,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
RequestPayloadHash: input.RequestPayloadHash,
ForceCacheBilling: input.ForceCacheBilling,
APIKeyService: input.APIKeyService,
QuotaPlatform: input.QuotaPlatform,
ChannelUsageFields: input.ChannelUsageFields,
}, &recordUsageOpts{
LongContextThreshold: input.LongContextThreshold,
@ -8422,6 +8516,7 @@ type recordUsageCoreInput struct {
RequestPayloadHash string
ForceCacheBilling bool
APIKeyService APIKeyQuotaUpdater
QuotaPlatform string
ChannelUsageFields
}
@ -8519,6 +8614,13 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
return nil
}
// 配额平台由 handler 在请求 ctx 内经 QuotaPlatform() 算定并通过 input 传入;
// 后扣运行在 worker 池的 background ctx 上,无法再从 ctx 取 ForcePlatform。
// 缺省(未设置)时回退到分组平台,保持对其它调用方的兼容。
quotaPlatform := input.QuotaPlatform
if quotaPlatform == "" {
quotaPlatform = PlatformFromAPIKey(apiKey)
}
requestID := usageLog.RequestID
_, billingErr := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{
Cost: cost,
@ -8530,6 +8632,7 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
IsSubscriptionBill: isSubscriptionBilling,
AccountRateMultiplier: accountRateMultiplier,
APIKeyService: input.APIKeyService,
Platform: quotaPlatform,
}, s.billingDeps(), s.usageBillingRepo)
if billingErr != nil {

View File

@ -155,6 +155,7 @@ func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo U
nil,
nil,
nil,
nil, // userPlatformQuotaRepo
)
svc.userGroupRateResolver = newUserGroupRateResolver(
rateRepo,

View File

@ -343,6 +343,7 @@ type OpenAIGatewayService struct {
channelService *ChannelService
balanceNotifyService *BalanceNotifyService
settingService *SettingService
userPlatformQuotaRepo UserPlatformQuotaRepository
openaiWSPoolOnce sync.Once
openaiWSStateStoreOnce sync.Once
@ -387,6 +388,7 @@ func NewOpenAIGatewayService(
channelService *ChannelService,
balanceNotifyService *BalanceNotifyService,
settingService *SettingService,
userPlatformQuotaRepo UserPlatformQuotaRepository,
) *OpenAIGatewayService {
svc := &OpenAIGatewayService{
accountRepo: accountRepo,
@ -418,6 +420,7 @@ func NewOpenAIGatewayService(
channelService: channelService,
balanceNotifyService: balanceNotifyService,
settingService: settingService,
userPlatformQuotaRepo: userPlatformQuotaRepo,
responseHeaderFilter: compileResponseHeaderFilter(cfg),
codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval),
}
@ -523,12 +526,13 @@ func (s *OpenAIGatewayService) getCodexSnapshotThrottle() *accountWriteThrottle
func (s *OpenAIGatewayService) billingDeps() *billingDeps {
return &billingDeps{
accountRepo: s.accountRepo,
userRepo: s.userRepo,
userSubRepo: s.userSubRepo,
billingCacheService: s.billingCacheService,
deferredService: s.deferredService,
balanceNotifyService: s.balanceNotifyService,
accountRepo: s.accountRepo,
userRepo: s.userRepo,
userSubRepo: s.userSubRepo,
billingCacheService: s.billingCacheService,
deferredService: s.deferredService,
balanceNotifyService: s.balanceNotifyService,
userPlatformQuotaRepo: s.userPlatformQuotaRepo,
}
}
@ -5634,6 +5638,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
IsSubscriptionBill: isSubscriptionBilling,
AccountRateMultiplier: accountRateMultiplier,
APIKeyService: input.APIKeyService,
Platform: PlatformFromAPIKey(apiKey),
}, s.billingDeps(), s.usageBillingRepo)
return err
}()

View File

@ -619,6 +619,7 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) {
nil,
nil,
nil,
nil, // userPlatformQuotaRepo
)
decision := svc.getOpenAIWSProtocolResolver().Resolve(nil)

View File

@ -0,0 +1,81 @@
//go:build unit
package service
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
)
func TestPlatformFromAPIKey_NilSafe(t *testing.T) {
if got := PlatformFromAPIKey(nil); got != "" {
t.Errorf("nil APIKey should yield empty string, got %q", got)
}
}
func TestPlatformFromAPIKey_NilGroup(t *testing.T) {
k := &APIKey{Group: nil}
if got := PlatformFromAPIKey(k); got != "" {
t.Errorf("APIKey with nil Group should yield empty string, got %q", got)
}
}
func TestPlatformFromAPIKey_DerivesFromGroup(t *testing.T) {
tests := []struct {
name string
platform string
}{
{"anthropic", "anthropic"},
{"openai", "openai"},
{"gemini", "gemini"},
{"antigravity", "antigravity"},
{"empty", ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
k := &APIKey{
Group: &Group{Platform: tt.platform},
}
got := PlatformFromAPIKey(k)
if got != tt.platform {
t.Errorf("PlatformFromAPIKey(%q) = %q, want %q", tt.platform, got, tt.platform)
}
})
}
}
// TestQuotaPlatform 锁定配额计量口径ForcePlatform 路由(如 /antigravity按 ForcePlatform 计,
// 否则回退到 Group 平台。preflight 与 post-billing 共用此口径,保证一致。
func TestQuotaPlatform(t *testing.T) {
apiKey := &APIKey{Group: &Group{Platform: PlatformAnthropic}}
t.Run("no force platform falls back to group platform", func(t *testing.T) {
if got := QuotaPlatform(context.Background(), apiKey); got != PlatformAnthropic {
t.Errorf("QuotaPlatform without force = %q, want %q", got, PlatformAnthropic)
}
})
t.Run("force platform overrides group platform", func(t *testing.T) {
ctx := context.WithValue(context.Background(), ctxkey.ForcePlatform, PlatformAntigravity)
if got := QuotaPlatform(ctx, apiKey); got != PlatformAntigravity {
t.Errorf("QuotaPlatform with force = %q, want %q", got, PlatformAntigravity)
}
})
t.Run("empty force platform falls back to group platform", func(t *testing.T) {
ctx := context.WithValue(context.Background(), ctxkey.ForcePlatform, "")
if got := QuotaPlatform(ctx, apiKey); got != PlatformAnthropic {
t.Errorf("QuotaPlatform with empty force = %q, want %q", got, PlatformAnthropic)
}
})
t.Run("nil api key with force platform returns force platform", func(t *testing.T) {
ctx := context.WithValue(context.Background(), ctxkey.ForcePlatform, PlatformAntigravity)
if got := QuotaPlatform(ctx, nil); got != PlatformAntigravity {
t.Errorf("QuotaPlatform(nil) with force = %q, want %q", got, PlatformAntigravity)
}
})
}

View File

@ -165,12 +165,20 @@ type SettingService struct {
openAICodexUASF singleflight.Group
}
// DefaultPlatformQuotaSetting 单 platform 三档限额nil = 沿用上层0 = 显式禁用;>0 = 上限)
type DefaultPlatformQuotaSetting struct {
DailyLimitUSD *float64 `json:"daily"`
WeeklyLimitUSD *float64 `json:"weekly"`
MonthlyLimitUSD *float64 `json:"monthly"`
}
type ProviderDefaultGrantSettings struct {
Balance float64
Concurrency int
Subscriptions []DefaultSubscriptionSetting
GrantOnSignup bool
GrantOnFirstBind bool
PlatformQuotas map[string]*DefaultPlatformQuotaSetting // key = platform name
}
type AuthSourceDefaultSettings struct {
@ -185,62 +193,80 @@ type AuthSourceDefaultSettings struct {
}
type authSourceDefaultKeySet struct {
// source 是 auth source 标识(如 "email"、"github"),仅用于 parse 时
// slog.Warn 诊断输出,不再参与 key 拼接platformQuotas 字段已存完整 key
source string
balance string
concurrency string
subscriptions string
grantOnSignup string
grantOnFirstBind string
platformQuotas string // SettingKeyAuthSourcePlatformQuotas(source)
}
var (
emailAuthSourceDefaultKeys = authSourceDefaultKeySet{
source: "email",
balance: SettingKeyAuthSourceDefaultEmailBalance,
concurrency: SettingKeyAuthSourceDefaultEmailConcurrency,
subscriptions: SettingKeyAuthSourceDefaultEmailSubscriptions,
grantOnSignup: SettingKeyAuthSourceDefaultEmailGrantOnSignup,
grantOnFirstBind: SettingKeyAuthSourceDefaultEmailGrantOnFirstBind,
platformQuotas: SettingKeyAuthSourcePlatformQuotas("email"),
}
linuxDoAuthSourceDefaultKeys = authSourceDefaultKeySet{
source: "linuxdo",
balance: SettingKeyAuthSourceDefaultLinuxDoBalance,
concurrency: SettingKeyAuthSourceDefaultLinuxDoConcurrency,
subscriptions: SettingKeyAuthSourceDefaultLinuxDoSubscriptions,
grantOnSignup: SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup,
grantOnFirstBind: SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind,
platformQuotas: SettingKeyAuthSourcePlatformQuotas("linuxdo"),
}
oidcAuthSourceDefaultKeys = authSourceDefaultKeySet{
source: "oidc",
balance: SettingKeyAuthSourceDefaultOIDCBalance,
concurrency: SettingKeyAuthSourceDefaultOIDCConcurrency,
subscriptions: SettingKeyAuthSourceDefaultOIDCSubscriptions,
grantOnSignup: SettingKeyAuthSourceDefaultOIDCGrantOnSignup,
grantOnFirstBind: SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind,
platformQuotas: SettingKeyAuthSourcePlatformQuotas("oidc"),
}
weChatAuthSourceDefaultKeys = authSourceDefaultKeySet{
source: "wechat",
balance: SettingKeyAuthSourceDefaultWeChatBalance,
concurrency: SettingKeyAuthSourceDefaultWeChatConcurrency,
subscriptions: SettingKeyAuthSourceDefaultWeChatSubscriptions,
grantOnSignup: SettingKeyAuthSourceDefaultWeChatGrantOnSignup,
grantOnFirstBind: SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind,
platformQuotas: SettingKeyAuthSourcePlatformQuotas("wechat"),
}
gitHubAuthSourceDefaultKeys = authSourceDefaultKeySet{
source: "github",
balance: SettingKeyAuthSourceDefaultGitHubBalance,
concurrency: SettingKeyAuthSourceDefaultGitHubConcurrency,
subscriptions: SettingKeyAuthSourceDefaultGitHubSubscriptions,
grantOnSignup: SettingKeyAuthSourceDefaultGitHubGrantOnSignup,
grantOnFirstBind: SettingKeyAuthSourceDefaultGitHubGrantOnFirstBind,
platformQuotas: SettingKeyAuthSourcePlatformQuotas("github"),
}
googleAuthSourceDefaultKeys = authSourceDefaultKeySet{
source: "google",
balance: SettingKeyAuthSourceDefaultGoogleBalance,
concurrency: SettingKeyAuthSourceDefaultGoogleConcurrency,
subscriptions: SettingKeyAuthSourceDefaultGoogleSubscriptions,
grantOnSignup: SettingKeyAuthSourceDefaultGoogleGrantOnSignup,
grantOnFirstBind: SettingKeyAuthSourceDefaultGoogleGrantOnFirstBind,
platformQuotas: SettingKeyAuthSourcePlatformQuotas("google"),
}
dingTalkAuthSourceDefaultKeys = authSourceDefaultKeySet{
source: "dingtalk",
balance: SettingKeyAuthSourceDefaultDingTalkBalance,
concurrency: SettingKeyAuthSourceDefaultDingTalkConcurrency,
subscriptions: SettingKeyAuthSourceDefaultDingTalkSubscriptions,
grantOnSignup: SettingKeyAuthSourceDefaultDingTalkGrantOnSignup,
grantOnFirstBind: SettingKeyAuthSourceDefaultDingTalkGrantOnFirstBind,
platformQuotas: SettingKeyAuthSourcePlatformQuotas("dingtalk"),
}
)
@ -1804,9 +1830,41 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
updates[SettingKeyAccountQuotaNotifyEnabled] = strconv.FormatBool(settings.AccountQuotaNotifyEnabled)
updates[SettingKeyAccountQuotaNotifyEmails] = MarshalNotifyEmails(settings.AccountQuotaNotifyEmails)
// 系统全局 platform quota整体替换语义null/缺省 = 不限制)。
if settings.DefaultPlatformQuotas != nil {
if err := validateDefaultPlatformQuotaMap(settings.DefaultPlatformQuotas); err != nil {
return nil, err
}
blob, err := json.Marshal(settings.DefaultPlatformQuotas)
if err != nil {
return nil, fmt.Errorf("marshal default platform quotas: %w", err)
}
updates[SettingKeyDefaultPlatformQuotas] = string(blob)
}
return updates, nil
}
// validateDefaultPlatformQuotaMap 校验 platform quota map 的合法性:
// 平台名须在 AllowedQuotaPlatforms 白名单内,每个非 nil 上限须 finite 且 >= 0。
// 系统层和 auth-source 层共用此 helper。
func validateDefaultPlatformQuotaMap(m map[string]*DefaultPlatformQuotaSetting) error {
for platform, pq := range m {
if !IsAllowedQuotaPlatform(platform) {
return infraerrors.BadRequest("INVALID_DEFAULT_PLATFORM_QUOTA", fmt.Sprintf("unknown platform %q", platform))
}
if pq == nil {
continue
}
for _, v := range []*float64{pq.DailyLimitUSD, pq.WeeklyLimitUSD, pq.MonthlyLimitUSD} {
if v != nil && (*v < 0 || math.IsNaN(*v) || math.IsInf(*v, 0)) {
return infraerrors.BadRequest("INVALID_DEFAULT_PLATFORM_QUOTA", "platform quota limit must be a finite non-negative number")
}
}
}
return nil
}
func (s *SettingService) buildAuthSourceDefaultUpdates(ctx context.Context, settings *AuthSourceDefaultSettings) (map[string]string, error) {
if settings == nil {
return nil, nil
@ -1826,6 +1884,26 @@ func (s *SettingService) buildAuthSourceDefaultUpdates(ctx context.Context, sett
}
}
// 校验各 auth source 的 platform quota map改动 C对等系统层校验
for _, pgs := range []struct {
name string
pq map[string]*DefaultPlatformQuotaSetting
}{
{"email", settings.Email.PlatformQuotas},
{"linuxdo", settings.LinuxDo.PlatformQuotas},
{"oidc", settings.OIDC.PlatformQuotas},
{"wechat", settings.WeChat.PlatformQuotas},
{"github", settings.GitHub.PlatformQuotas},
{"google", settings.Google.PlatformQuotas},
{"dingtalk", settings.DingTalk.PlatformQuotas},
} {
if pgs.pq != nil {
if err := validateDefaultPlatformQuotaMap(pgs.pq); err != nil {
return nil, err
}
}
}
updates := make(map[string]string, 36)
writeProviderDefaultGrantUpdates(updates, emailAuthSourceDefaultKeys, settings.Email)
writeProviderDefaultGrantUpdates(updates, linuxDoAuthSourceDefaultKeys, settings.LinuxDo)
@ -2386,6 +2464,13 @@ func (s *SettingService) GetAuthSourceDefaultSettings(ctx context.Context) (*Aut
SettingKeyAuthSourceDefaultDingTalkSubscriptions,
SettingKeyAuthSourceDefaultDingTalkGrantOnSignup,
SettingKeyAuthSourceDefaultDingTalkGrantOnFirstBind,
SettingKeyAuthSourcePlatformQuotas("email"),
SettingKeyAuthSourcePlatformQuotas("linuxdo"),
SettingKeyAuthSourcePlatformQuotas("oidc"),
SettingKeyAuthSourcePlatformQuotas("wechat"),
SettingKeyAuthSourcePlatformQuotas("github"),
SettingKeyAuthSourcePlatformQuotas("google"),
SettingKeyAuthSourcePlatformQuotas("dingtalk"),
SettingKeyForceEmailOnThirdPartySignup,
}
@ -3179,6 +3264,16 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
result.AccountQuotaNotifyEmails = []NotifyEmailEntry{}
}
// 系统层默认 platform quota修复 Bug BparseSettings 不填充导致回显恒为 nil
if raw := settings[SettingKeyDefaultPlatformQuotas]; raw != "" {
parsed := map[string]*DefaultPlatformQuotaSetting{}
if err := json.Unmarshal([]byte(raw), &parsed); err != nil {
slog.Warn("[Setting] parseSettings: unmarshal default_platform_quotas failed", "error", err)
} else {
result.DefaultPlatformQuotas = parsed
}
}
return result
}
@ -3271,6 +3366,15 @@ func parseProviderDefaultGrantSettings(settings map[string]string, keys authSour
result.GrantOnFirstBind = raw == "true"
}
if raw := settings[keys.platformQuotas]; raw != "" {
parsed := map[string]*DefaultPlatformQuotaSetting{}
if err := json.Unmarshal([]byte(raw), &parsed); err != nil {
slog.Warn("[Setting] parseProviderDefaultGrantSettings: unmarshal auth source platform quotas failed", "source", keys.source, "error", err)
} else {
result.PlatformQuotas = parsed
}
}
return result
}
@ -3289,6 +3393,17 @@ func writeProviderDefaultGrantUpdates(updates map[string]string, keys authSource
updates[keys.subscriptions] = string(raw)
updates[keys.grantOnSignup] = strconv.FormatBool(settings.GrantOnSignup)
updates[keys.grantOnFirstBind] = strconv.FormatBool(settings.GrantOnFirstBind)
// auth source platform quota整体替换语义。
// nil = 请求未携带该字段,跳过写入以保留既有配置(与系统层 buildSystemSettingsUpdates 的
// DefaultPlatformQuotas nil 守卫一致);非 nil含空 map才整体替换。二者语义不可混同。
if keys.platformQuotas != "" && settings.PlatformQuotas != nil {
blob, err := json.Marshal(settings.PlatformQuotas)
if err != nil {
blob = []byte("{}")
}
updates[keys.platformQuotas] = string(blob)
}
}
func mergeProviderDefaultGrantSettings(globalDefaults ProviderDefaultGrantSettings, providerDefaults ProviderDefaultGrantSettings) ProviderDefaultGrantSettings {
@ -4493,3 +4608,63 @@ func (s *SettingService) SetStreamTimeoutSettings(ctx context.Context, settings
return s.settingRepo.Set(ctx, SettingKeyStreamTimeoutSettings, string(data))
}
// GetDefaultPlatformQuotas 读取系统全局 platform quota JSON key返回 4 platform x 3 window 的设置。
// 永远返回包含全部 4 platform key 的 map值可能为零值/nil 字段,表示"上层未配置 = 不限制")。
//
// 使用单个 JSON keydefault_platform_quotas一次 DB roundtrip消除旧 12-KV 格式的 N+1 问题。
// 容错语义:取值失败或 unmarshal 失败 → 返回补齐 4 key 的空 mapfail-open注册不被阻断
func (s *SettingService) GetDefaultPlatformQuotas(ctx context.Context) (map[string]*DefaultPlatformQuotaSetting, error) {
out := map[string]*DefaultPlatformQuotaSetting{
"anthropic": {},
"openai": {},
"gemini": {},
"antigravity": {},
}
raw, err := s.settingRepo.GetValue(ctx, SettingKeyDefaultPlatformQuotas)
if err != nil || raw == "" {
return out, nil // 无配置 = 全部不限制
}
parsed := map[string]*DefaultPlatformQuotaSetting{}
if err := json.Unmarshal([]byte(raw), &parsed); err != nil {
slog.Warn("[Setting] unmarshal default_platform_quotas failed (fail-open)", "error", err)
return out, nil
}
for _, platform := range AllowedQuotaPlatforms {
if v := parsed[platform]; v != nil {
out[platform] = v
}
}
return out, nil // 补齐 4 platform key保持与旧实现一致的下游契约
}
// GetAuthSourcePlatformQuotas 读取指定 auth source 的 platform quota 覆盖仅返回有配置的平台override 语义)。
func (s *SettingService) GetAuthSourcePlatformQuotas(ctx context.Context, source string) map[string]*DefaultPlatformQuotaSetting {
out := map[string]*DefaultPlatformQuotaSetting{}
raw, err := s.settingRepo.GetValue(ctx, SettingKeyAuthSourcePlatformQuotas(source))
if err != nil || raw == "" {
return out // 无 override
}
if err := json.Unmarshal([]byte(raw), &out); err != nil {
slog.Warn("[Setting] unmarshal auth source platform quotas failed (fail-open)", "source", source, "error", err)
return map[string]*DefaultPlatformQuotaSetting{}
}
return out // 仅含已配置平台,保持 override 语义
}
// mergePlatformQuotaDefaults 按字段级 patchsrc 中非 nil 字段覆盖 dst。
// 区分 nil"未配置",保留 dstvs &0.0"显式禁用",覆盖 dst 为 0
func mergePlatformQuotaDefaults(dst, src *DefaultPlatformQuotaSetting) {
if src == nil || dst == nil {
return
}
if src.DailyLimitUSD != nil {
dst.DailyLimitUSD = src.DailyLimitUSD
}
if src.WeeklyLimitUSD != nil {
dst.WeeklyLimitUSD = src.WeeklyLimitUSD
}
if src.MonthlyLimitUSD != nil {
dst.MonthlyLimitUSD = src.MonthlyLimitUSD
}
}

Some files were not shown because too many files have changed in this diff Show More