diff --git a/backend/cmd/jwtgen/main.go b/backend/cmd/jwtgen/main.go index 9386678d..fc8a00c4 100644 --- a/backend/cmd/jwtgen/main.go +++ b/backend/cmd/jwtgen/main.go @@ -33,7 +33,7 @@ func main() { }() userRepo := repository.NewUserRepository(client, sqlDB) - authService := service.NewAuthService(client, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil) + authService := service.NewAuthService(client, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil, nil) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index e39a645a..465f5e25 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -63,7 +63,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { apiKeyRepository := repository.NewAPIKeyRepository(client, db) userRPMCache := repository.NewUserRPMCache(redisClient) userGroupRateRepository := repository.NewUserGroupRateRepository(db) - billingCacheService := service.ProvideBillingCacheService(billingCache, userRepository, userSubscriptionRepository, apiKeyRepository, userRPMCache, userGroupRateRepository, configConfig) + userPlatformQuotaRepository := repository.NewUserPlatformQuotaRepository(client) + serviceUserPlatformQuotaRepository := repository.NewUserPlatformQuotaServiceAdapter(userPlatformQuotaRepository) + billingCacheService := service.ProvideBillingCacheService(billingCache, userRepository, userSubscriptionRepository, apiKeyRepository, userRPMCache, userGroupRateRepository, configConfig, serviceUserPlatformQuotaRepository) apiKeyCache := repository.NewAPIKeyCache(redisClient) apiKeyService := service.ProvideAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, userGroupRateRepository, apiKeyCache, configConfig, billingCacheService) apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService) @@ -71,7 +73,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig) affiliateRepository := repository.NewAffiliateRepository(client, db) affiliateService := service.NewAffiliateService(affiliateRepository, settingService, apiKeyAuthCacheInvalidator, billingCacheService) - authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService, affiliateService) + authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService, affiliateService, serviceUserPlatformQuotaRepository) userService := service.NewUserService(userRepository, settingRepository, apiKeyAuthCacheInvalidator, billingCache) redeemCache := repository.NewRedeemCache(redisClient) redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator, affiliateService) @@ -85,7 +87,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { userAttributeValueRepository := repository.NewUserAttributeValueRepository(client) userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository) authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, redeemService, totpService, userAttributeService) - userHandler := handler.NewUserHandler(userService, authService, emailService, emailCache, affiliateService) + userHandler := handler.NewUserHandler(userService, authService, emailService, emailCache, affiliateService, serviceUserPlatformQuotaRepository) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageLogRepository := repository.NewUsageLogRepository(client, db) usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator) @@ -143,9 +145,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { modelPricingResolver := service.NewModelPricingResolver(channelService, billingService) notificationEmailService := service.NewNotificationEmailService(settingRepository, emailService) balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository, notificationEmailService) - openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService, settingService) + openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService, settingService, serviceUserPlatformQuotaRepository) adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, userRPMCache, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository, privacyClientFactory, openAIGatewayService) - adminUserHandler := admin.NewUserHandler(adminService, concurrencyService) + adminUserHandler := admin.NewUserHandler(adminService, concurrencyService, serviceUserPlatformQuotaRepository, billingCache) sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig) rpmCache := repository.NewRPMCache(redisClient) groupCapacityService := service.NewGroupCapacityService(accountRepository, groupRepository, concurrencyService, sessionLimitCache, rpmCache) @@ -190,7 +192,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { opsRepository := repository.NewOpsRepository(db) identityService := service.NewIdentityService(identityCache) digestSessionStore := service.NewDigestSessionStore() - gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService) + gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService, serviceUserPlatformQuotaRepository) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository) opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink) diff --git a/backend/cmd/server/wire_gen_test.go b/backend/cmd/server/wire_gen_test.go index 5ccd67fb..a44b2d5c 100644 --- a/backend/cmd/server/wire_gen_test.go +++ b/backend/cmd/server/wire_gen_test.go @@ -43,7 +43,7 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) { subscriptionExpirySvc := service.NewSubscriptionExpiryService(nil, time.Second) pricingSvc := service.NewPricingService(cfg, nil) emailQueueSvc := service.NewEmailQueueService(nil, 1) - billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg) + billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg, nil) idempotencyCleanupSvc := service.NewIdempotencyCleanupService(nil, cfg) schedulerSnapshotSvc := service.NewSchedulerSnapshotService(nil, nil, nil, nil, cfg) opsSystemLogSinkSvc := service.NewOpsSystemLogSink(nil) diff --git a/backend/ent/client.go b/backend/ent/client.go index df20ddfa..06b53c78 100644 --- a/backend/ent/client.go +++ b/backend/ent/client.go @@ -48,6 +48,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/userallowedgroup" "github.com/Wei-Shaw/sub2api/ent/userattributedefinition" "github.com/Wei-Shaw/sub2api/ent/userattributevalue" + "github.com/Wei-Shaw/sub2api/ent/userplatformquota" "github.com/Wei-Shaw/sub2api/ent/usersubscription" stdsql "database/sql" @@ -124,6 +125,8 @@ type Client struct { UserAttributeDefinition *UserAttributeDefinitionClient // UserAttributeValue is the client for interacting with the UserAttributeValue builders. UserAttributeValue *UserAttributeValueClient + // UserPlatformQuota is the client for interacting with the UserPlatformQuota builders. + UserPlatformQuota *UserPlatformQuotaClient // UserSubscription is the client for interacting with the UserSubscription builders. UserSubscription *UserSubscriptionClient } @@ -170,6 +173,7 @@ func (c *Client) init() { c.UserAllowedGroup = NewUserAllowedGroupClient(c.config) c.UserAttributeDefinition = NewUserAttributeDefinitionClient(c.config) c.UserAttributeValue = NewUserAttributeValueClient(c.config) + c.UserPlatformQuota = NewUserPlatformQuotaClient(c.config) c.UserSubscription = NewUserSubscriptionClient(c.config) } @@ -296,6 +300,7 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) { UserAllowedGroup: NewUserAllowedGroupClient(cfg), UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg), UserAttributeValue: NewUserAttributeValueClient(cfg), + UserPlatformQuota: NewUserPlatformQuotaClient(cfg), UserSubscription: NewUserSubscriptionClient(cfg), }, nil } @@ -349,6 +354,7 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) UserAllowedGroup: NewUserAllowedGroupClient(cfg), UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg), UserAttributeValue: NewUserAttributeValueClient(cfg), + UserPlatformQuota: NewUserPlatformQuotaClient(cfg), UserSubscription: NewUserSubscriptionClient(cfg), }, nil } @@ -388,7 +394,7 @@ func (c *Client) Use(hooks ...Hook) { c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, c.SubscriptionPlan, c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, - c.UserSubscription, + c.UserPlatformQuota, c.UserSubscription, } { n.Use(hooks...) } @@ -407,7 +413,7 @@ func (c *Client) Intercept(interceptors ...Interceptor) { c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, c.SubscriptionPlan, c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, - c.UserSubscription, + c.UserPlatformQuota, c.UserSubscription, } { n.Intercept(interceptors...) } @@ -482,6 +488,8 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) { return c.UserAttributeDefinition.mutate(ctx, m) case *UserAttributeValueMutation: return c.UserAttributeValue.mutate(ctx, m) + case *UserPlatformQuotaMutation: + return c.UserPlatformQuota.mutate(ctx, m) case *UserSubscriptionMutation: return c.UserSubscription.mutate(ctx, m) default: @@ -5341,6 +5349,22 @@ func (c *UserClient) QueryPendingAuthSessions(_m *User) *PendingAuthSessionQuery return query } +// QueryPlatformQuotas queries the platform_quotas edge of a User. +func (c *UserClient) QueryPlatformQuotas(_m *User) *UserPlatformQuotaQuery { + query := (&UserPlatformQuotaClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(user.Table, user.FieldID, id), + sqlgraph.To(userplatformquota.Table, userplatformquota.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, user.PlatformQuotasTable, user.PlatformQuotasColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + // QueryUserAllowedGroups queries the user_allowed_groups edge of a User. func (c *UserClient) QueryUserAllowedGroups(_m *User) *UserAllowedGroupQuery { query := (&UserAllowedGroupClient{config: c.config}).Query() @@ -5816,6 +5840,157 @@ func (c *UserAttributeValueClient) mutate(ctx context.Context, m *UserAttributeV } } +// UserPlatformQuotaClient is a client for the UserPlatformQuota schema. +type UserPlatformQuotaClient struct { + config +} + +// NewUserPlatformQuotaClient returns a client for the UserPlatformQuota from the given config. +func NewUserPlatformQuotaClient(c config) *UserPlatformQuotaClient { + return &UserPlatformQuotaClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `userplatformquota.Hooks(f(g(h())))`. +func (c *UserPlatformQuotaClient) Use(hooks ...Hook) { + c.hooks.UserPlatformQuota = append(c.hooks.UserPlatformQuota, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `userplatformquota.Intercept(f(g(h())))`. +func (c *UserPlatformQuotaClient) Intercept(interceptors ...Interceptor) { + c.inters.UserPlatformQuota = append(c.inters.UserPlatformQuota, interceptors...) +} + +// Create returns a builder for creating a UserPlatformQuota entity. +func (c *UserPlatformQuotaClient) Create() *UserPlatformQuotaCreate { + mutation := newUserPlatformQuotaMutation(c.config, OpCreate) + return &UserPlatformQuotaCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of UserPlatformQuota entities. +func (c *UserPlatformQuotaClient) CreateBulk(builders ...*UserPlatformQuotaCreate) *UserPlatformQuotaCreateBulk { + return &UserPlatformQuotaCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *UserPlatformQuotaClient) MapCreateBulk(slice any, setFunc func(*UserPlatformQuotaCreate, int)) *UserPlatformQuotaCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &UserPlatformQuotaCreateBulk{err: fmt.Errorf("calling to UserPlatformQuotaClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*UserPlatformQuotaCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &UserPlatformQuotaCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for UserPlatformQuota. +func (c *UserPlatformQuotaClient) Update() *UserPlatformQuotaUpdate { + mutation := newUserPlatformQuotaMutation(c.config, OpUpdate) + return &UserPlatformQuotaUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *UserPlatformQuotaClient) UpdateOne(_m *UserPlatformQuota) *UserPlatformQuotaUpdateOne { + mutation := newUserPlatformQuotaMutation(c.config, OpUpdateOne, withUserPlatformQuota(_m)) + return &UserPlatformQuotaUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *UserPlatformQuotaClient) UpdateOneID(id int64) *UserPlatformQuotaUpdateOne { + mutation := newUserPlatformQuotaMutation(c.config, OpUpdateOne, withUserPlatformQuotaID(id)) + return &UserPlatformQuotaUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for UserPlatformQuota. +func (c *UserPlatformQuotaClient) Delete() *UserPlatformQuotaDelete { + mutation := newUserPlatformQuotaMutation(c.config, OpDelete) + return &UserPlatformQuotaDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *UserPlatformQuotaClient) DeleteOne(_m *UserPlatformQuota) *UserPlatformQuotaDeleteOne { + return c.DeleteOneID(_m.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *UserPlatformQuotaClient) DeleteOneID(id int64) *UserPlatformQuotaDeleteOne { + builder := c.Delete().Where(userplatformquota.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &UserPlatformQuotaDeleteOne{builder} +} + +// Query returns a query builder for UserPlatformQuota. +func (c *UserPlatformQuotaClient) Query() *UserPlatformQuotaQuery { + return &UserPlatformQuotaQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeUserPlatformQuota}, + inters: c.Interceptors(), + } +} + +// Get returns a UserPlatformQuota entity by its id. +func (c *UserPlatformQuotaClient) Get(ctx context.Context, id int64) (*UserPlatformQuota, error) { + return c.Query().Where(userplatformquota.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *UserPlatformQuotaClient) GetX(ctx context.Context, id int64) *UserPlatformQuota { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// QueryUser queries the user edge of a UserPlatformQuota. +func (c *UserPlatformQuotaClient) QueryUser(_m *UserPlatformQuota) *UserQuery { + query := (&UserClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(userplatformquota.Table, userplatformquota.FieldID, id), + sqlgraph.To(user.Table, user.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, userplatformquota.UserTable, userplatformquota.UserColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// Hooks returns the client hooks. +func (c *UserPlatformQuotaClient) Hooks() []Hook { + hooks := c.hooks.UserPlatformQuota + return append(hooks[:len(hooks):len(hooks)], userplatformquota.Hooks[:]...) +} + +// Interceptors returns the client interceptors. +func (c *UserPlatformQuotaClient) Interceptors() []Interceptor { + inters := c.inters.UserPlatformQuota + return append(inters[:len(inters):len(inters)], userplatformquota.Interceptors[:]...) +} + +func (c *UserPlatformQuotaClient) mutate(ctx context.Context, m *UserPlatformQuotaMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&UserPlatformQuotaCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&UserPlatformQuotaUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&UserPlatformQuotaUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&UserPlatformQuotaDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown UserPlatformQuota mutation op: %q", m.Op()) + } +} + // UserSubscriptionClient is a client for the UserSubscription schema. type UserSubscriptionClient struct { config @@ -6025,7 +6200,8 @@ type ( PaymentOrder, PaymentProviderInstance, PendingAuthSession, PromoCode, PromoCodeUsage, Proxy, RedeemCode, SecuritySecret, Setting, SubscriptionPlan, TLSFingerprintProfile, UsageCleanupTask, UsageLog, User, UserAllowedGroup, - UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Hook + UserAttributeDefinition, UserAttributeValue, UserPlatformQuota, + UserSubscription []ent.Hook } inters struct { APIKey, Account, AccountGroup, Announcement, AnnouncementRead, AuthIdentity, @@ -6035,7 +6211,8 @@ type ( PaymentOrder, PaymentProviderInstance, PendingAuthSession, PromoCode, PromoCodeUsage, Proxy, RedeemCode, SecuritySecret, Setting, SubscriptionPlan, TLSFingerprintProfile, UsageCleanupTask, UsageLog, User, UserAllowedGroup, - UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Interceptor + UserAttributeDefinition, UserAttributeValue, UserPlatformQuota, + UserSubscription []ent.Interceptor } ) diff --git a/backend/ent/ent.go b/backend/ent/ent.go index c9fcc314..33d36e70 100644 --- a/backend/ent/ent.go +++ b/backend/ent/ent.go @@ -45,6 +45,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/userallowedgroup" "github.com/Wei-Shaw/sub2api/ent/userattributedefinition" "github.com/Wei-Shaw/sub2api/ent/userattributevalue" + "github.com/Wei-Shaw/sub2api/ent/userplatformquota" "github.com/Wei-Shaw/sub2api/ent/usersubscription" ) @@ -139,6 +140,7 @@ func checkColumn(t, c string) error { userallowedgroup.Table: userallowedgroup.ValidColumn, userattributedefinition.Table: userattributedefinition.ValidColumn, userattributevalue.Table: userattributevalue.ValidColumn, + userplatformquota.Table: userplatformquota.ValidColumn, usersubscription.Table: usersubscription.ValidColumn, }) }) diff --git a/backend/ent/hook/hook.go b/backend/ent/hook/hook.go index 414eba24..71bfd3b8 100644 --- a/backend/ent/hook/hook.go +++ b/backend/ent/hook/hook.go @@ -405,6 +405,18 @@ func (f UserAttributeValueFunc) Mutate(ctx context.Context, m ent.Mutation) (ent return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.UserAttributeValueMutation", m) } +// The UserPlatformQuotaFunc type is an adapter to allow the use of ordinary +// function as UserPlatformQuota mutator. +type UserPlatformQuotaFunc func(context.Context, *ent.UserPlatformQuotaMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f UserPlatformQuotaFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.UserPlatformQuotaMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.UserPlatformQuotaMutation", m) +} + // The UserSubscriptionFunc type is an adapter to allow the use of ordinary // function as UserSubscription mutator. type UserSubscriptionFunc func(context.Context, *ent.UserSubscriptionMutation) (ent.Value, error) diff --git a/backend/ent/intercept/intercept.go b/backend/ent/intercept/intercept.go index 95b68e09..5d86e25b 100644 --- a/backend/ent/intercept/intercept.go +++ b/backend/ent/intercept/intercept.go @@ -42,6 +42,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/userallowedgroup" "github.com/Wei-Shaw/sub2api/ent/userattributedefinition" "github.com/Wei-Shaw/sub2api/ent/userattributevalue" + "github.com/Wei-Shaw/sub2api/ent/userplatformquota" "github.com/Wei-Shaw/sub2api/ent/usersubscription" ) @@ -992,6 +993,33 @@ func (f TraverseUserAttributeValue) Traverse(ctx context.Context, q ent.Query) e return fmt.Errorf("unexpected query type %T. expect *ent.UserAttributeValueQuery", q) } +// The UserPlatformQuotaFunc type is an adapter to allow the use of ordinary function as a Querier. +type UserPlatformQuotaFunc func(context.Context, *ent.UserPlatformQuotaQuery) (ent.Value, error) + +// Query calls f(ctx, q). +func (f UserPlatformQuotaFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) { + if q, ok := q.(*ent.UserPlatformQuotaQuery); ok { + return f(ctx, q) + } + return nil, fmt.Errorf("unexpected query type %T. expect *ent.UserPlatformQuotaQuery", q) +} + +// The TraverseUserPlatformQuota type is an adapter to allow the use of ordinary function as Traverser. +type TraverseUserPlatformQuota func(context.Context, *ent.UserPlatformQuotaQuery) error + +// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. +func (f TraverseUserPlatformQuota) Intercept(next ent.Querier) ent.Querier { + return next +} + +// Traverse calls f(ctx, q). +func (f TraverseUserPlatformQuota) Traverse(ctx context.Context, q ent.Query) error { + if q, ok := q.(*ent.UserPlatformQuotaQuery); ok { + return f(ctx, q) + } + return fmt.Errorf("unexpected query type %T. expect *ent.UserPlatformQuotaQuery", q) +} + // The UserSubscriptionFunc type is an adapter to allow the use of ordinary function as a Querier. type UserSubscriptionFunc func(context.Context, *ent.UserSubscriptionQuery) (ent.Value, error) @@ -1088,6 +1116,8 @@ func NewQuery(q ent.Query) (Query, error) { return &query[*ent.UserAttributeDefinitionQuery, predicate.UserAttributeDefinition, userattributedefinition.OrderOption]{typ: ent.TypeUserAttributeDefinition, tq: q}, nil case *ent.UserAttributeValueQuery: return &query[*ent.UserAttributeValueQuery, predicate.UserAttributeValue, userattributevalue.OrderOption]{typ: ent.TypeUserAttributeValue, tq: q}, nil + case *ent.UserPlatformQuotaQuery: + return &query[*ent.UserPlatformQuotaQuery, predicate.UserPlatformQuota, userplatformquota.OrderOption]{typ: ent.TypeUserPlatformQuota, tq: q}, nil case *ent.UserSubscriptionQuery: return &query[*ent.UserSubscriptionQuery, predicate.UserSubscription, usersubscription.OrderOption]{typ: ent.TypeUserSubscription, tq: q}, nil default: diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index b1731a35..447f71ef 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -1612,6 +1612,53 @@ var ( }, }, } + // UserPlatformQuotasColumns holds the columns for the "user_platform_quotas" table. + UserPlatformQuotasColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt64, Increment: true}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "deleted_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "platform", Type: field.TypeString, Size: 32}, + {Name: "daily_limit_usd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,10)"}}, + {Name: "weekly_limit_usd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,10)"}}, + {Name: "monthly_limit_usd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,10)"}}, + {Name: "daily_usage_usd", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,10)"}}, + {Name: "weekly_usage_usd", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,10)"}}, + {Name: "monthly_usage_usd", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,10)"}}, + {Name: "daily_window_start", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "weekly_window_start", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "monthly_window_start", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "user_id", Type: field.TypeInt64}, + } + // UserPlatformQuotasTable holds the schema information for the "user_platform_quotas" table. + UserPlatformQuotasTable = &schema.Table{ + Name: "user_platform_quotas", + Columns: UserPlatformQuotasColumns, + PrimaryKey: []*schema.Column{UserPlatformQuotasColumns[0]}, + ForeignKeys: []*schema.ForeignKey{ + { + Symbol: "user_platform_quotas_users_platform_quotas", + Columns: []*schema.Column{UserPlatformQuotasColumns[14]}, + RefColumns: []*schema.Column{UsersColumns[0]}, + OnDelete: schema.NoAction, + }, + }, + Indexes: []*schema.Index{ + { + Name: "userplatformquota_user_id_platform", + Unique: true, + Columns: []*schema.Column{UserPlatformQuotasColumns[14], UserPlatformQuotasColumns[4]}, + Annotation: &entsql.IndexAnnotation{ + Where: "deleted_at IS NULL", + }, + }, + { + Name: "userplatformquota_user_id", + Unique: false, + Columns: []*schema.Column{UserPlatformQuotasColumns[14]}, + }, + }, + } // UserSubscriptionsColumns holds the columns for the "user_subscriptions" table. UserSubscriptionsColumns = []*schema.Column{ {Name: "id", Type: field.TypeInt64, Increment: true}, @@ -1736,6 +1783,7 @@ var ( UserAllowedGroupsTable, UserAttributeDefinitionsTable, UserAttributeValuesTable, + UserPlatformQuotasTable, UserSubscriptionsTable, } ) @@ -1869,6 +1917,10 @@ func init() { UserAttributeValuesTable.Annotation = &entsql.Annotation{ Table: "user_attribute_values", } + UserPlatformQuotasTable.ForeignKeys[0].RefTable = UsersTable + UserPlatformQuotasTable.Annotation = &entsql.Annotation{ + Table: "user_platform_quotas", + } UserSubscriptionsTable.ForeignKeys[0].RefTable = GroupsTable UserSubscriptionsTable.ForeignKeys[1].RefTable = UsersTable UserSubscriptionsTable.ForeignKeys[2].RefTable = UsersTable diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index af0edc68..2e8fa7f4 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -46,6 +46,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/userallowedgroup" "github.com/Wei-Shaw/sub2api/ent/userattributedefinition" "github.com/Wei-Shaw/sub2api/ent/userattributevalue" + "github.com/Wei-Shaw/sub2api/ent/userplatformquota" "github.com/Wei-Shaw/sub2api/ent/usersubscription" "github.com/Wei-Shaw/sub2api/internal/domain" ) @@ -92,6 +93,7 @@ const ( TypeUserAllowedGroup = "UserAllowedGroup" TypeUserAttributeDefinition = "UserAttributeDefinition" TypeUserAttributeValue = "UserAttributeValue" + TypeUserPlatformQuota = "UserPlatformQuota" TypeUserSubscription = "UserSubscription" ) @@ -38160,6 +38162,9 @@ type UserMutation struct { pending_auth_sessions map[int64]struct{} removedpending_auth_sessions map[int64]struct{} clearedpending_auth_sessions bool + platform_quotas map[int64]struct{} + removedplatform_quotas map[int64]struct{} + clearedplatform_quotas bool done bool oldValue func(context.Context) (*User, error) predicates []predicate.User @@ -39918,6 +39923,60 @@ func (m *UserMutation) ResetPendingAuthSessions() { m.removedpending_auth_sessions = nil } +// AddPlatformQuotaIDs adds the "platform_quotas" edge to the UserPlatformQuota entity by ids. +func (m *UserMutation) AddPlatformQuotaIDs(ids ...int64) { + if m.platform_quotas == nil { + m.platform_quotas = make(map[int64]struct{}) + } + for i := range ids { + m.platform_quotas[ids[i]] = struct{}{} + } +} + +// ClearPlatformQuotas clears the "platform_quotas" edge to the UserPlatformQuota entity. +func (m *UserMutation) ClearPlatformQuotas() { + m.clearedplatform_quotas = true +} + +// PlatformQuotasCleared reports if the "platform_quotas" edge to the UserPlatformQuota entity was cleared. +func (m *UserMutation) PlatformQuotasCleared() bool { + return m.clearedplatform_quotas +} + +// RemovePlatformQuotaIDs removes the "platform_quotas" edge to the UserPlatformQuota entity by IDs. +func (m *UserMutation) RemovePlatformQuotaIDs(ids ...int64) { + if m.removedplatform_quotas == nil { + m.removedplatform_quotas = make(map[int64]struct{}) + } + for i := range ids { + delete(m.platform_quotas, ids[i]) + m.removedplatform_quotas[ids[i]] = struct{}{} + } +} + +// RemovedPlatformQuotas returns the removed IDs of the "platform_quotas" edge to the UserPlatformQuota entity. +func (m *UserMutation) RemovedPlatformQuotasIDs() (ids []int64) { + for id := range m.removedplatform_quotas { + ids = append(ids, id) + } + return +} + +// PlatformQuotasIDs returns the "platform_quotas" edge IDs in the mutation. +func (m *UserMutation) PlatformQuotasIDs() (ids []int64) { + for id := range m.platform_quotas { + ids = append(ids, id) + } + return +} + +// ResetPlatformQuotas resets all changes to the "platform_quotas" edge. +func (m *UserMutation) ResetPlatformQuotas() { + m.platform_quotas = nil + m.clearedplatform_quotas = false + m.removedplatform_quotas = nil +} + // Where appends a list predicates to the UserMutation builder. func (m *UserMutation) Where(ps ...predicate.User) { m.predicates = append(m.predicates, ps...) @@ -40527,7 +40586,7 @@ func (m *UserMutation) ResetField(name string) error { // AddedEdges returns all edge names that were set/added in this mutation. func (m *UserMutation) AddedEdges() []string { - edges := make([]string, 0, 12) + edges := make([]string, 0, 13) if m.api_keys != nil { edges = append(edges, user.EdgeAPIKeys) } @@ -40564,6 +40623,9 @@ func (m *UserMutation) AddedEdges() []string { if m.pending_auth_sessions != nil { edges = append(edges, user.EdgePendingAuthSessions) } + if m.platform_quotas != nil { + edges = append(edges, user.EdgePlatformQuotas) + } return edges } @@ -40643,13 +40705,19 @@ func (m *UserMutation) AddedIDs(name string) []ent.Value { ids = append(ids, id) } return ids + case user.EdgePlatformQuotas: + ids := make([]ent.Value, 0, len(m.platform_quotas)) + for id := range m.platform_quotas { + ids = append(ids, id) + } + return ids } return nil } // RemovedEdges returns all edge names that were removed in this mutation. func (m *UserMutation) RemovedEdges() []string { - edges := make([]string, 0, 12) + edges := make([]string, 0, 13) if m.removedapi_keys != nil { edges = append(edges, user.EdgeAPIKeys) } @@ -40686,6 +40754,9 @@ func (m *UserMutation) RemovedEdges() []string { if m.removedpending_auth_sessions != nil { edges = append(edges, user.EdgePendingAuthSessions) } + if m.removedplatform_quotas != nil { + edges = append(edges, user.EdgePlatformQuotas) + } return edges } @@ -40765,13 +40836,19 @@ func (m *UserMutation) RemovedIDs(name string) []ent.Value { ids = append(ids, id) } return ids + case user.EdgePlatformQuotas: + ids := make([]ent.Value, 0, len(m.removedplatform_quotas)) + for id := range m.removedplatform_quotas { + ids = append(ids, id) + } + return ids } return nil } // ClearedEdges returns all edge names that were cleared in this mutation. func (m *UserMutation) ClearedEdges() []string { - edges := make([]string, 0, 12) + edges := make([]string, 0, 13) if m.clearedapi_keys { edges = append(edges, user.EdgeAPIKeys) } @@ -40808,6 +40885,9 @@ func (m *UserMutation) ClearedEdges() []string { if m.clearedpending_auth_sessions { edges = append(edges, user.EdgePendingAuthSessions) } + if m.clearedplatform_quotas { + edges = append(edges, user.EdgePlatformQuotas) + } return edges } @@ -40839,6 +40919,8 @@ func (m *UserMutation) EdgeCleared(name string) bool { return m.clearedauth_identities case user.EdgePendingAuthSessions: return m.clearedpending_auth_sessions + case user.EdgePlatformQuotas: + return m.clearedplatform_quotas } return false } @@ -40891,6 +40973,9 @@ func (m *UserMutation) ResetEdge(name string) error { case user.EdgePendingAuthSessions: m.ResetPendingAuthSessions() return nil + case user.EdgePlatformQuotas: + m.ResetPlatformQuotas() + return nil } return fmt.Errorf("unknown User edge %s", name) } @@ -43111,6 +43196,1428 @@ func (m *UserAttributeValueMutation) ResetEdge(name string) error { return fmt.Errorf("unknown UserAttributeValue edge %s", name) } +// UserPlatformQuotaMutation represents an operation that mutates the UserPlatformQuota nodes in the graph. +type UserPlatformQuotaMutation struct { + config + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + deleted_at *time.Time + platform *string + daily_limit_usd *float64 + adddaily_limit_usd *float64 + weekly_limit_usd *float64 + addweekly_limit_usd *float64 + monthly_limit_usd *float64 + addmonthly_limit_usd *float64 + daily_usage_usd *float64 + adddaily_usage_usd *float64 + weekly_usage_usd *float64 + addweekly_usage_usd *float64 + monthly_usage_usd *float64 + addmonthly_usage_usd *float64 + daily_window_start *time.Time + weekly_window_start *time.Time + monthly_window_start *time.Time + clearedFields map[string]struct{} + user *int64 + cleareduser bool + done bool + oldValue func(context.Context) (*UserPlatformQuota, error) + predicates []predicate.UserPlatformQuota +} + +var _ ent.Mutation = (*UserPlatformQuotaMutation)(nil) + +// userplatformquotaOption allows management of the mutation configuration using functional options. +type userplatformquotaOption func(*UserPlatformQuotaMutation) + +// newUserPlatformQuotaMutation creates new mutation for the UserPlatformQuota entity. +func newUserPlatformQuotaMutation(c config, op Op, opts ...userplatformquotaOption) *UserPlatformQuotaMutation { + m := &UserPlatformQuotaMutation{ + config: c, + op: op, + typ: TypeUserPlatformQuota, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withUserPlatformQuotaID sets the ID field of the mutation. +func withUserPlatformQuotaID(id int64) userplatformquotaOption { + return func(m *UserPlatformQuotaMutation) { + var ( + err error + once sync.Once + value *UserPlatformQuota + ) + m.oldValue = func(ctx context.Context) (*UserPlatformQuota, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().UserPlatformQuota.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withUserPlatformQuota sets the old UserPlatformQuota of the mutation. +func withUserPlatformQuota(node *UserPlatformQuota) userplatformquotaOption { + return func(m *UserPlatformQuotaMutation) { + m.oldValue = func(context.Context) (*UserPlatformQuota, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m UserPlatformQuotaMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m UserPlatformQuotaMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *UserPlatformQuotaMutation) ID() (id int64, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *UserPlatformQuotaMutation) IDs(ctx context.Context) ([]int64, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int64{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().UserPlatformQuota.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetCreatedAt sets the "created_at" field. +func (m *UserPlatformQuotaMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *UserPlatformQuotaMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return + } + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the UserPlatformQuota entity. +// If the UserPlatformQuota object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserPlatformQuotaMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + } + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *UserPlatformQuotaMutation) ResetCreatedAt() { + m.created_at = nil +} + +// SetUpdatedAt sets the "updated_at" field. +func (m *UserPlatformQuotaMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t +} + +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *UserPlatformQuotaMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at + if v == nil { + return + } + return *v, true +} + +// OldUpdatedAt returns the old "updated_at" field's value of the UserPlatformQuota entity. +// If the UserPlatformQuota object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserPlatformQuotaMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + } + return oldValue.UpdatedAt, nil +} + +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *UserPlatformQuotaMutation) ResetUpdatedAt() { + m.updated_at = nil +} + +// SetDeletedAt sets the "deleted_at" field. +func (m *UserPlatformQuotaMutation) SetDeletedAt(t time.Time) { + m.deleted_at = &t +} + +// DeletedAt returns the value of the "deleted_at" field in the mutation. +func (m *UserPlatformQuotaMutation) DeletedAt() (r time.Time, exists bool) { + v := m.deleted_at + if v == nil { + return + } + return *v, true +} + +// OldDeletedAt returns the old "deleted_at" field's value of the UserPlatformQuota entity. +// If the UserPlatformQuota object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserPlatformQuotaMutation) OldDeletedAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDeletedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDeletedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDeletedAt: %w", err) + } + return oldValue.DeletedAt, nil +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (m *UserPlatformQuotaMutation) ClearDeletedAt() { + m.deleted_at = nil + m.clearedFields[userplatformquota.FieldDeletedAt] = struct{}{} +} + +// DeletedAtCleared returns if the "deleted_at" field was cleared in this mutation. +func (m *UserPlatformQuotaMutation) DeletedAtCleared() bool { + _, ok := m.clearedFields[userplatformquota.FieldDeletedAt] + return ok +} + +// ResetDeletedAt resets all changes to the "deleted_at" field. +func (m *UserPlatformQuotaMutation) ResetDeletedAt() { + m.deleted_at = nil + delete(m.clearedFields, userplatformquota.FieldDeletedAt) +} + +// SetUserID sets the "user_id" field. +func (m *UserPlatformQuotaMutation) SetUserID(i int64) { + m.user = &i +} + +// UserID returns the value of the "user_id" field in the mutation. +func (m *UserPlatformQuotaMutation) UserID() (r int64, exists bool) { + v := m.user + if v == nil { + return + } + return *v, true +} + +// OldUserID returns the old "user_id" field's value of the UserPlatformQuota entity. +// If the UserPlatformQuota object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserPlatformQuotaMutation) OldUserID(ctx context.Context) (v int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUserID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUserID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUserID: %w", err) + } + return oldValue.UserID, nil +} + +// ResetUserID resets all changes to the "user_id" field. +func (m *UserPlatformQuotaMutation) ResetUserID() { + m.user = nil +} + +// SetPlatform sets the "platform" field. +func (m *UserPlatformQuotaMutation) SetPlatform(s string) { + m.platform = &s +} + +// Platform returns the value of the "platform" field in the mutation. +func (m *UserPlatformQuotaMutation) Platform() (r string, exists bool) { + v := m.platform + if v == nil { + return + } + return *v, true +} + +// OldPlatform returns the old "platform" field's value of the UserPlatformQuota entity. +// If the UserPlatformQuota object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserPlatformQuotaMutation) OldPlatform(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPlatform is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPlatform requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPlatform: %w", err) + } + return oldValue.Platform, nil +} + +// ResetPlatform resets all changes to the "platform" field. +func (m *UserPlatformQuotaMutation) ResetPlatform() { + m.platform = nil +} + +// SetDailyLimitUsd sets the "daily_limit_usd" field. +func (m *UserPlatformQuotaMutation) SetDailyLimitUsd(f float64) { + m.daily_limit_usd = &f + m.adddaily_limit_usd = nil +} + +// DailyLimitUsd returns the value of the "daily_limit_usd" field in the mutation. +func (m *UserPlatformQuotaMutation) DailyLimitUsd() (r float64, exists bool) { + v := m.daily_limit_usd + if v == nil { + return + } + return *v, true +} + +// OldDailyLimitUsd returns the old "daily_limit_usd" field's value of the UserPlatformQuota entity. +// If the UserPlatformQuota object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserPlatformQuotaMutation) OldDailyLimitUsd(ctx context.Context) (v *float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDailyLimitUsd is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDailyLimitUsd requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDailyLimitUsd: %w", err) + } + return oldValue.DailyLimitUsd, nil +} + +// AddDailyLimitUsd adds f to the "daily_limit_usd" field. +func (m *UserPlatformQuotaMutation) AddDailyLimitUsd(f float64) { + if m.adddaily_limit_usd != nil { + *m.adddaily_limit_usd += f + } else { + m.adddaily_limit_usd = &f + } +} + +// AddedDailyLimitUsd returns the value that was added to the "daily_limit_usd" field in this mutation. +func (m *UserPlatformQuotaMutation) AddedDailyLimitUsd() (r float64, exists bool) { + v := m.adddaily_limit_usd + if v == nil { + return + } + return *v, true +} + +// ClearDailyLimitUsd clears the value of the "daily_limit_usd" field. +func (m *UserPlatformQuotaMutation) ClearDailyLimitUsd() { + m.daily_limit_usd = nil + m.adddaily_limit_usd = nil + m.clearedFields[userplatformquota.FieldDailyLimitUsd] = struct{}{} +} + +// DailyLimitUsdCleared returns if the "daily_limit_usd" field was cleared in this mutation. +func (m *UserPlatformQuotaMutation) DailyLimitUsdCleared() bool { + _, ok := m.clearedFields[userplatformquota.FieldDailyLimitUsd] + return ok +} + +// ResetDailyLimitUsd resets all changes to the "daily_limit_usd" field. +func (m *UserPlatformQuotaMutation) ResetDailyLimitUsd() { + m.daily_limit_usd = nil + m.adddaily_limit_usd = nil + delete(m.clearedFields, userplatformquota.FieldDailyLimitUsd) +} + +// SetWeeklyLimitUsd sets the "weekly_limit_usd" field. +func (m *UserPlatformQuotaMutation) SetWeeklyLimitUsd(f float64) { + m.weekly_limit_usd = &f + m.addweekly_limit_usd = nil +} + +// WeeklyLimitUsd returns the value of the "weekly_limit_usd" field in the mutation. +func (m *UserPlatformQuotaMutation) WeeklyLimitUsd() (r float64, exists bool) { + v := m.weekly_limit_usd + if v == nil { + return + } + return *v, true +} + +// OldWeeklyLimitUsd returns the old "weekly_limit_usd" field's value of the UserPlatformQuota entity. +// If the UserPlatformQuota object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserPlatformQuotaMutation) OldWeeklyLimitUsd(ctx context.Context) (v *float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldWeeklyLimitUsd is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldWeeklyLimitUsd requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldWeeklyLimitUsd: %w", err) + } + return oldValue.WeeklyLimitUsd, nil +} + +// AddWeeklyLimitUsd adds f to the "weekly_limit_usd" field. +func (m *UserPlatformQuotaMutation) AddWeeklyLimitUsd(f float64) { + if m.addweekly_limit_usd != nil { + *m.addweekly_limit_usd += f + } else { + m.addweekly_limit_usd = &f + } +} + +// AddedWeeklyLimitUsd returns the value that was added to the "weekly_limit_usd" field in this mutation. +func (m *UserPlatformQuotaMutation) AddedWeeklyLimitUsd() (r float64, exists bool) { + v := m.addweekly_limit_usd + if v == nil { + return + } + return *v, true +} + +// ClearWeeklyLimitUsd clears the value of the "weekly_limit_usd" field. +func (m *UserPlatformQuotaMutation) ClearWeeklyLimitUsd() { + m.weekly_limit_usd = nil + m.addweekly_limit_usd = nil + m.clearedFields[userplatformquota.FieldWeeklyLimitUsd] = struct{}{} +} + +// WeeklyLimitUsdCleared returns if the "weekly_limit_usd" field was cleared in this mutation. +func (m *UserPlatformQuotaMutation) WeeklyLimitUsdCleared() bool { + _, ok := m.clearedFields[userplatformquota.FieldWeeklyLimitUsd] + return ok +} + +// ResetWeeklyLimitUsd resets all changes to the "weekly_limit_usd" field. +func (m *UserPlatformQuotaMutation) ResetWeeklyLimitUsd() { + m.weekly_limit_usd = nil + m.addweekly_limit_usd = nil + delete(m.clearedFields, userplatformquota.FieldWeeklyLimitUsd) +} + +// SetMonthlyLimitUsd sets the "monthly_limit_usd" field. +func (m *UserPlatformQuotaMutation) SetMonthlyLimitUsd(f float64) { + m.monthly_limit_usd = &f + m.addmonthly_limit_usd = nil +} + +// MonthlyLimitUsd returns the value of the "monthly_limit_usd" field in the mutation. +func (m *UserPlatformQuotaMutation) MonthlyLimitUsd() (r float64, exists bool) { + v := m.monthly_limit_usd + if v == nil { + return + } + return *v, true +} + +// OldMonthlyLimitUsd returns the old "monthly_limit_usd" field's value of the UserPlatformQuota entity. +// If the UserPlatformQuota object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserPlatformQuotaMutation) OldMonthlyLimitUsd(ctx context.Context) (v *float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldMonthlyLimitUsd is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldMonthlyLimitUsd requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldMonthlyLimitUsd: %w", err) + } + return oldValue.MonthlyLimitUsd, nil +} + +// AddMonthlyLimitUsd adds f to the "monthly_limit_usd" field. +func (m *UserPlatformQuotaMutation) AddMonthlyLimitUsd(f float64) { + if m.addmonthly_limit_usd != nil { + *m.addmonthly_limit_usd += f + } else { + m.addmonthly_limit_usd = &f + } +} + +// AddedMonthlyLimitUsd returns the value that was added to the "monthly_limit_usd" field in this mutation. +func (m *UserPlatformQuotaMutation) AddedMonthlyLimitUsd() (r float64, exists bool) { + v := m.addmonthly_limit_usd + if v == nil { + return + } + return *v, true +} + +// ClearMonthlyLimitUsd clears the value of the "monthly_limit_usd" field. +func (m *UserPlatformQuotaMutation) ClearMonthlyLimitUsd() { + m.monthly_limit_usd = nil + m.addmonthly_limit_usd = nil + m.clearedFields[userplatformquota.FieldMonthlyLimitUsd] = struct{}{} +} + +// MonthlyLimitUsdCleared returns if the "monthly_limit_usd" field was cleared in this mutation. +func (m *UserPlatformQuotaMutation) MonthlyLimitUsdCleared() bool { + _, ok := m.clearedFields[userplatformquota.FieldMonthlyLimitUsd] + return ok +} + +// ResetMonthlyLimitUsd resets all changes to the "monthly_limit_usd" field. +func (m *UserPlatformQuotaMutation) ResetMonthlyLimitUsd() { + m.monthly_limit_usd = nil + m.addmonthly_limit_usd = nil + delete(m.clearedFields, userplatformquota.FieldMonthlyLimitUsd) +} + +// SetDailyUsageUsd sets the "daily_usage_usd" field. +func (m *UserPlatformQuotaMutation) SetDailyUsageUsd(f float64) { + m.daily_usage_usd = &f + m.adddaily_usage_usd = nil +} + +// DailyUsageUsd returns the value of the "daily_usage_usd" field in the mutation. +func (m *UserPlatformQuotaMutation) DailyUsageUsd() (r float64, exists bool) { + v := m.daily_usage_usd + if v == nil { + return + } + return *v, true +} + +// OldDailyUsageUsd returns the old "daily_usage_usd" field's value of the UserPlatformQuota entity. +// If the UserPlatformQuota object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserPlatformQuotaMutation) OldDailyUsageUsd(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDailyUsageUsd is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDailyUsageUsd requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDailyUsageUsd: %w", err) + } + return oldValue.DailyUsageUsd, nil +} + +// AddDailyUsageUsd adds f to the "daily_usage_usd" field. +func (m *UserPlatformQuotaMutation) AddDailyUsageUsd(f float64) { + if m.adddaily_usage_usd != nil { + *m.adddaily_usage_usd += f + } else { + m.adddaily_usage_usd = &f + } +} + +// AddedDailyUsageUsd returns the value that was added to the "daily_usage_usd" field in this mutation. +func (m *UserPlatformQuotaMutation) AddedDailyUsageUsd() (r float64, exists bool) { + v := m.adddaily_usage_usd + if v == nil { + return + } + return *v, true +} + +// ResetDailyUsageUsd resets all changes to the "daily_usage_usd" field. +func (m *UserPlatformQuotaMutation) ResetDailyUsageUsd() { + m.daily_usage_usd = nil + m.adddaily_usage_usd = nil +} + +// SetWeeklyUsageUsd sets the "weekly_usage_usd" field. +func (m *UserPlatformQuotaMutation) SetWeeklyUsageUsd(f float64) { + m.weekly_usage_usd = &f + m.addweekly_usage_usd = nil +} + +// WeeklyUsageUsd returns the value of the "weekly_usage_usd" field in the mutation. +func (m *UserPlatformQuotaMutation) WeeklyUsageUsd() (r float64, exists bool) { + v := m.weekly_usage_usd + if v == nil { + return + } + return *v, true +} + +// OldWeeklyUsageUsd returns the old "weekly_usage_usd" field's value of the UserPlatformQuota entity. +// If the UserPlatformQuota object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserPlatformQuotaMutation) OldWeeklyUsageUsd(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldWeeklyUsageUsd is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldWeeklyUsageUsd requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldWeeklyUsageUsd: %w", err) + } + return oldValue.WeeklyUsageUsd, nil +} + +// AddWeeklyUsageUsd adds f to the "weekly_usage_usd" field. +func (m *UserPlatformQuotaMutation) AddWeeklyUsageUsd(f float64) { + if m.addweekly_usage_usd != nil { + *m.addweekly_usage_usd += f + } else { + m.addweekly_usage_usd = &f + } +} + +// AddedWeeklyUsageUsd returns the value that was added to the "weekly_usage_usd" field in this mutation. +func (m *UserPlatformQuotaMutation) AddedWeeklyUsageUsd() (r float64, exists bool) { + v := m.addweekly_usage_usd + if v == nil { + return + } + return *v, true +} + +// ResetWeeklyUsageUsd resets all changes to the "weekly_usage_usd" field. +func (m *UserPlatformQuotaMutation) ResetWeeklyUsageUsd() { + m.weekly_usage_usd = nil + m.addweekly_usage_usd = nil +} + +// SetMonthlyUsageUsd sets the "monthly_usage_usd" field. +func (m *UserPlatformQuotaMutation) SetMonthlyUsageUsd(f float64) { + m.monthly_usage_usd = &f + m.addmonthly_usage_usd = nil +} + +// MonthlyUsageUsd returns the value of the "monthly_usage_usd" field in the mutation. +func (m *UserPlatformQuotaMutation) MonthlyUsageUsd() (r float64, exists bool) { + v := m.monthly_usage_usd + if v == nil { + return + } + return *v, true +} + +// OldMonthlyUsageUsd returns the old "monthly_usage_usd" field's value of the UserPlatformQuota entity. +// If the UserPlatformQuota object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserPlatformQuotaMutation) OldMonthlyUsageUsd(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldMonthlyUsageUsd is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldMonthlyUsageUsd requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldMonthlyUsageUsd: %w", err) + } + return oldValue.MonthlyUsageUsd, nil +} + +// AddMonthlyUsageUsd adds f to the "monthly_usage_usd" field. +func (m *UserPlatformQuotaMutation) AddMonthlyUsageUsd(f float64) { + if m.addmonthly_usage_usd != nil { + *m.addmonthly_usage_usd += f + } else { + m.addmonthly_usage_usd = &f + } +} + +// AddedMonthlyUsageUsd returns the value that was added to the "monthly_usage_usd" field in this mutation. +func (m *UserPlatformQuotaMutation) AddedMonthlyUsageUsd() (r float64, exists bool) { + v := m.addmonthly_usage_usd + if v == nil { + return + } + return *v, true +} + +// ResetMonthlyUsageUsd resets all changes to the "monthly_usage_usd" field. +func (m *UserPlatformQuotaMutation) ResetMonthlyUsageUsd() { + m.monthly_usage_usd = nil + m.addmonthly_usage_usd = nil +} + +// SetDailyWindowStart sets the "daily_window_start" field. +func (m *UserPlatformQuotaMutation) SetDailyWindowStart(t time.Time) { + m.daily_window_start = &t +} + +// DailyWindowStart returns the value of the "daily_window_start" field in the mutation. +func (m *UserPlatformQuotaMutation) DailyWindowStart() (r time.Time, exists bool) { + v := m.daily_window_start + if v == nil { + return + } + return *v, true +} + +// OldDailyWindowStart returns the old "daily_window_start" field's value of the UserPlatformQuota entity. +// If the UserPlatformQuota object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserPlatformQuotaMutation) OldDailyWindowStart(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDailyWindowStart is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDailyWindowStart requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDailyWindowStart: %w", err) + } + return oldValue.DailyWindowStart, nil +} + +// ClearDailyWindowStart clears the value of the "daily_window_start" field. +func (m *UserPlatformQuotaMutation) ClearDailyWindowStart() { + m.daily_window_start = nil + m.clearedFields[userplatformquota.FieldDailyWindowStart] = struct{}{} +} + +// DailyWindowStartCleared returns if the "daily_window_start" field was cleared in this mutation. +func (m *UserPlatformQuotaMutation) DailyWindowStartCleared() bool { + _, ok := m.clearedFields[userplatformquota.FieldDailyWindowStart] + return ok +} + +// ResetDailyWindowStart resets all changes to the "daily_window_start" field. +func (m *UserPlatformQuotaMutation) ResetDailyWindowStart() { + m.daily_window_start = nil + delete(m.clearedFields, userplatformquota.FieldDailyWindowStart) +} + +// SetWeeklyWindowStart sets the "weekly_window_start" field. +func (m *UserPlatformQuotaMutation) SetWeeklyWindowStart(t time.Time) { + m.weekly_window_start = &t +} + +// WeeklyWindowStart returns the value of the "weekly_window_start" field in the mutation. +func (m *UserPlatformQuotaMutation) WeeklyWindowStart() (r time.Time, exists bool) { + v := m.weekly_window_start + if v == nil { + return + } + return *v, true +} + +// OldWeeklyWindowStart returns the old "weekly_window_start" field's value of the UserPlatformQuota entity. +// If the UserPlatformQuota object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserPlatformQuotaMutation) OldWeeklyWindowStart(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldWeeklyWindowStart is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldWeeklyWindowStart requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldWeeklyWindowStart: %w", err) + } + return oldValue.WeeklyWindowStart, nil +} + +// ClearWeeklyWindowStart clears the value of the "weekly_window_start" field. +func (m *UserPlatformQuotaMutation) ClearWeeklyWindowStart() { + m.weekly_window_start = nil + m.clearedFields[userplatformquota.FieldWeeklyWindowStart] = struct{}{} +} + +// WeeklyWindowStartCleared returns if the "weekly_window_start" field was cleared in this mutation. +func (m *UserPlatformQuotaMutation) WeeklyWindowStartCleared() bool { + _, ok := m.clearedFields[userplatformquota.FieldWeeklyWindowStart] + return ok +} + +// ResetWeeklyWindowStart resets all changes to the "weekly_window_start" field. +func (m *UserPlatformQuotaMutation) ResetWeeklyWindowStart() { + m.weekly_window_start = nil + delete(m.clearedFields, userplatformquota.FieldWeeklyWindowStart) +} + +// SetMonthlyWindowStart sets the "monthly_window_start" field. +func (m *UserPlatformQuotaMutation) SetMonthlyWindowStart(t time.Time) { + m.monthly_window_start = &t +} + +// MonthlyWindowStart returns the value of the "monthly_window_start" field in the mutation. +func (m *UserPlatformQuotaMutation) MonthlyWindowStart() (r time.Time, exists bool) { + v := m.monthly_window_start + if v == nil { + return + } + return *v, true +} + +// OldMonthlyWindowStart returns the old "monthly_window_start" field's value of the UserPlatformQuota entity. +// If the UserPlatformQuota object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserPlatformQuotaMutation) OldMonthlyWindowStart(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldMonthlyWindowStart is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldMonthlyWindowStart requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldMonthlyWindowStart: %w", err) + } + return oldValue.MonthlyWindowStart, nil +} + +// ClearMonthlyWindowStart clears the value of the "monthly_window_start" field. +func (m *UserPlatformQuotaMutation) ClearMonthlyWindowStart() { + m.monthly_window_start = nil + m.clearedFields[userplatformquota.FieldMonthlyWindowStart] = struct{}{} +} + +// MonthlyWindowStartCleared returns if the "monthly_window_start" field was cleared in this mutation. +func (m *UserPlatformQuotaMutation) MonthlyWindowStartCleared() bool { + _, ok := m.clearedFields[userplatformquota.FieldMonthlyWindowStart] + return ok +} + +// ResetMonthlyWindowStart resets all changes to the "monthly_window_start" field. +func (m *UserPlatformQuotaMutation) ResetMonthlyWindowStart() { + m.monthly_window_start = nil + delete(m.clearedFields, userplatformquota.FieldMonthlyWindowStart) +} + +// ClearUser clears the "user" edge to the User entity. +func (m *UserPlatformQuotaMutation) ClearUser() { + m.cleareduser = true + m.clearedFields[userplatformquota.FieldUserID] = struct{}{} +} + +// UserCleared reports if the "user" edge to the User entity was cleared. +func (m *UserPlatformQuotaMutation) UserCleared() bool { + return m.cleareduser +} + +// UserIDs returns the "user" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// UserID instead. It exists only for internal usage by the builders. +func (m *UserPlatformQuotaMutation) UserIDs() (ids []int64) { + if id := m.user; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetUser resets all changes to the "user" edge. +func (m *UserPlatformQuotaMutation) ResetUser() { + m.user = nil + m.cleareduser = false +} + +// Where appends a list predicates to the UserPlatformQuotaMutation builder. +func (m *UserPlatformQuotaMutation) Where(ps ...predicate.UserPlatformQuota) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the UserPlatformQuotaMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *UserPlatformQuotaMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.UserPlatformQuota, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *UserPlatformQuotaMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *UserPlatformQuotaMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (UserPlatformQuota). +func (m *UserPlatformQuotaMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *UserPlatformQuotaMutation) Fields() []string { + fields := make([]string, 0, 14) + if m.created_at != nil { + fields = append(fields, userplatformquota.FieldCreatedAt) + } + if m.updated_at != nil { + fields = append(fields, userplatformquota.FieldUpdatedAt) + } + if m.deleted_at != nil { + fields = append(fields, userplatformquota.FieldDeletedAt) + } + if m.user != nil { + fields = append(fields, userplatformquota.FieldUserID) + } + if m.platform != nil { + fields = append(fields, userplatformquota.FieldPlatform) + } + if m.daily_limit_usd != nil { + fields = append(fields, userplatformquota.FieldDailyLimitUsd) + } + if m.weekly_limit_usd != nil { + fields = append(fields, userplatformquota.FieldWeeklyLimitUsd) + } + if m.monthly_limit_usd != nil { + fields = append(fields, userplatformquota.FieldMonthlyLimitUsd) + } + if m.daily_usage_usd != nil { + fields = append(fields, userplatformquota.FieldDailyUsageUsd) + } + if m.weekly_usage_usd != nil { + fields = append(fields, userplatformquota.FieldWeeklyUsageUsd) + } + if m.monthly_usage_usd != nil { + fields = append(fields, userplatformquota.FieldMonthlyUsageUsd) + } + if m.daily_window_start != nil { + fields = append(fields, userplatformquota.FieldDailyWindowStart) + } + if m.weekly_window_start != nil { + fields = append(fields, userplatformquota.FieldWeeklyWindowStart) + } + if m.monthly_window_start != nil { + fields = append(fields, userplatformquota.FieldMonthlyWindowStart) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *UserPlatformQuotaMutation) Field(name string) (ent.Value, bool) { + switch name { + case userplatformquota.FieldCreatedAt: + return m.CreatedAt() + case userplatformquota.FieldUpdatedAt: + return m.UpdatedAt() + case userplatformquota.FieldDeletedAt: + return m.DeletedAt() + case userplatformquota.FieldUserID: + return m.UserID() + case userplatformquota.FieldPlatform: + return m.Platform() + case userplatformquota.FieldDailyLimitUsd: + return m.DailyLimitUsd() + case userplatformquota.FieldWeeklyLimitUsd: + return m.WeeklyLimitUsd() + case userplatformquota.FieldMonthlyLimitUsd: + return m.MonthlyLimitUsd() + case userplatformquota.FieldDailyUsageUsd: + return m.DailyUsageUsd() + case userplatformquota.FieldWeeklyUsageUsd: + return m.WeeklyUsageUsd() + case userplatformquota.FieldMonthlyUsageUsd: + return m.MonthlyUsageUsd() + case userplatformquota.FieldDailyWindowStart: + return m.DailyWindowStart() + case userplatformquota.FieldWeeklyWindowStart: + return m.WeeklyWindowStart() + case userplatformquota.FieldMonthlyWindowStart: + return m.MonthlyWindowStart() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *UserPlatformQuotaMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case userplatformquota.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case userplatformquota.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + case userplatformquota.FieldDeletedAt: + return m.OldDeletedAt(ctx) + case userplatformquota.FieldUserID: + return m.OldUserID(ctx) + case userplatformquota.FieldPlatform: + return m.OldPlatform(ctx) + case userplatformquota.FieldDailyLimitUsd: + return m.OldDailyLimitUsd(ctx) + case userplatformquota.FieldWeeklyLimitUsd: + return m.OldWeeklyLimitUsd(ctx) + case userplatformquota.FieldMonthlyLimitUsd: + return m.OldMonthlyLimitUsd(ctx) + case userplatformquota.FieldDailyUsageUsd: + return m.OldDailyUsageUsd(ctx) + case userplatformquota.FieldWeeklyUsageUsd: + return m.OldWeeklyUsageUsd(ctx) + case userplatformquota.FieldMonthlyUsageUsd: + return m.OldMonthlyUsageUsd(ctx) + case userplatformquota.FieldDailyWindowStart: + return m.OldDailyWindowStart(ctx) + case userplatformquota.FieldWeeklyWindowStart: + return m.OldWeeklyWindowStart(ctx) + case userplatformquota.FieldMonthlyWindowStart: + return m.OldMonthlyWindowStart(ctx) + } + return nil, fmt.Errorf("unknown UserPlatformQuota field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *UserPlatformQuotaMutation) SetField(name string, value ent.Value) error { + switch name { + case userplatformquota.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case userplatformquota.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + case userplatformquota.FieldDeletedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDeletedAt(v) + return nil + case userplatformquota.FieldUserID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUserID(v) + return nil + case userplatformquota.FieldPlatform: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPlatform(v) + return nil + case userplatformquota.FieldDailyLimitUsd: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDailyLimitUsd(v) + return nil + case userplatformquota.FieldWeeklyLimitUsd: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetWeeklyLimitUsd(v) + return nil + case userplatformquota.FieldMonthlyLimitUsd: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetMonthlyLimitUsd(v) + return nil + case userplatformquota.FieldDailyUsageUsd: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDailyUsageUsd(v) + return nil + case userplatformquota.FieldWeeklyUsageUsd: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetWeeklyUsageUsd(v) + return nil + case userplatformquota.FieldMonthlyUsageUsd: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetMonthlyUsageUsd(v) + return nil + case userplatformquota.FieldDailyWindowStart: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDailyWindowStart(v) + return nil + case userplatformquota.FieldWeeklyWindowStart: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetWeeklyWindowStart(v) + return nil + case userplatformquota.FieldMonthlyWindowStart: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetMonthlyWindowStart(v) + return nil + } + return fmt.Errorf("unknown UserPlatformQuota field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *UserPlatformQuotaMutation) AddedFields() []string { + var fields []string + if m.adddaily_limit_usd != nil { + fields = append(fields, userplatformquota.FieldDailyLimitUsd) + } + if m.addweekly_limit_usd != nil { + fields = append(fields, userplatformquota.FieldWeeklyLimitUsd) + } + if m.addmonthly_limit_usd != nil { + fields = append(fields, userplatformquota.FieldMonthlyLimitUsd) + } + if m.adddaily_usage_usd != nil { + fields = append(fields, userplatformquota.FieldDailyUsageUsd) + } + if m.addweekly_usage_usd != nil { + fields = append(fields, userplatformquota.FieldWeeklyUsageUsd) + } + if m.addmonthly_usage_usd != nil { + fields = append(fields, userplatformquota.FieldMonthlyUsageUsd) + } + return fields +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *UserPlatformQuotaMutation) AddedField(name string) (ent.Value, bool) { + switch name { + case userplatformquota.FieldDailyLimitUsd: + return m.AddedDailyLimitUsd() + case userplatformquota.FieldWeeklyLimitUsd: + return m.AddedWeeklyLimitUsd() + case userplatformquota.FieldMonthlyLimitUsd: + return m.AddedMonthlyLimitUsd() + case userplatformquota.FieldDailyUsageUsd: + return m.AddedDailyUsageUsd() + case userplatformquota.FieldWeeklyUsageUsd: + return m.AddedWeeklyUsageUsd() + case userplatformquota.FieldMonthlyUsageUsd: + return m.AddedMonthlyUsageUsd() + } + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *UserPlatformQuotaMutation) AddField(name string, value ent.Value) error { + switch name { + case userplatformquota.FieldDailyLimitUsd: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddDailyLimitUsd(v) + return nil + case userplatformquota.FieldWeeklyLimitUsd: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddWeeklyLimitUsd(v) + return nil + case userplatformquota.FieldMonthlyLimitUsd: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddMonthlyLimitUsd(v) + return nil + case userplatformquota.FieldDailyUsageUsd: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddDailyUsageUsd(v) + return nil + case userplatformquota.FieldWeeklyUsageUsd: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddWeeklyUsageUsd(v) + return nil + case userplatformquota.FieldMonthlyUsageUsd: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddMonthlyUsageUsd(v) + return nil + } + return fmt.Errorf("unknown UserPlatformQuota numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *UserPlatformQuotaMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(userplatformquota.FieldDeletedAt) { + fields = append(fields, userplatformquota.FieldDeletedAt) + } + if m.FieldCleared(userplatformquota.FieldDailyLimitUsd) { + fields = append(fields, userplatformquota.FieldDailyLimitUsd) + } + if m.FieldCleared(userplatformquota.FieldWeeklyLimitUsd) { + fields = append(fields, userplatformquota.FieldWeeklyLimitUsd) + } + if m.FieldCleared(userplatformquota.FieldMonthlyLimitUsd) { + fields = append(fields, userplatformquota.FieldMonthlyLimitUsd) + } + if m.FieldCleared(userplatformquota.FieldDailyWindowStart) { + fields = append(fields, userplatformquota.FieldDailyWindowStart) + } + if m.FieldCleared(userplatformquota.FieldWeeklyWindowStart) { + fields = append(fields, userplatformquota.FieldWeeklyWindowStart) + } + if m.FieldCleared(userplatformquota.FieldMonthlyWindowStart) { + fields = append(fields, userplatformquota.FieldMonthlyWindowStart) + } + return fields +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *UserPlatformQuotaMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *UserPlatformQuotaMutation) ClearField(name string) error { + switch name { + case userplatformquota.FieldDeletedAt: + m.ClearDeletedAt() + return nil + case userplatformquota.FieldDailyLimitUsd: + m.ClearDailyLimitUsd() + return nil + case userplatformquota.FieldWeeklyLimitUsd: + m.ClearWeeklyLimitUsd() + return nil + case userplatformquota.FieldMonthlyLimitUsd: + m.ClearMonthlyLimitUsd() + return nil + case userplatformquota.FieldDailyWindowStart: + m.ClearDailyWindowStart() + return nil + case userplatformquota.FieldWeeklyWindowStart: + m.ClearWeeklyWindowStart() + return nil + case userplatformquota.FieldMonthlyWindowStart: + m.ClearMonthlyWindowStart() + return nil + } + return fmt.Errorf("unknown UserPlatformQuota nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *UserPlatformQuotaMutation) ResetField(name string) error { + switch name { + case userplatformquota.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case userplatformquota.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case userplatformquota.FieldDeletedAt: + m.ResetDeletedAt() + return nil + case userplatformquota.FieldUserID: + m.ResetUserID() + return nil + case userplatformquota.FieldPlatform: + m.ResetPlatform() + return nil + case userplatformquota.FieldDailyLimitUsd: + m.ResetDailyLimitUsd() + return nil + case userplatformquota.FieldWeeklyLimitUsd: + m.ResetWeeklyLimitUsd() + return nil + case userplatformquota.FieldMonthlyLimitUsd: + m.ResetMonthlyLimitUsd() + return nil + case userplatformquota.FieldDailyUsageUsd: + m.ResetDailyUsageUsd() + return nil + case userplatformquota.FieldWeeklyUsageUsd: + m.ResetWeeklyUsageUsd() + return nil + case userplatformquota.FieldMonthlyUsageUsd: + m.ResetMonthlyUsageUsd() + return nil + case userplatformquota.FieldDailyWindowStart: + m.ResetDailyWindowStart() + return nil + case userplatformquota.FieldWeeklyWindowStart: + m.ResetWeeklyWindowStart() + return nil + case userplatformquota.FieldMonthlyWindowStart: + m.ResetMonthlyWindowStart() + return nil + } + return fmt.Errorf("unknown UserPlatformQuota field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *UserPlatformQuotaMutation) AddedEdges() []string { + edges := make([]string, 0, 1) + if m.user != nil { + edges = append(edges, userplatformquota.EdgeUser) + } + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *UserPlatformQuotaMutation) AddedIDs(name string) []ent.Value { + switch name { + case userplatformquota.EdgeUser: + if id := m.user; id != nil { + return []ent.Value{*id} + } + } + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *UserPlatformQuotaMutation) RemovedEdges() []string { + edges := make([]string, 0, 1) + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *UserPlatformQuotaMutation) RemovedIDs(name string) []ent.Value { + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *UserPlatformQuotaMutation) ClearedEdges() []string { + edges := make([]string, 0, 1) + if m.cleareduser { + edges = append(edges, userplatformquota.EdgeUser) + } + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *UserPlatformQuotaMutation) EdgeCleared(name string) bool { + switch name { + case userplatformquota.EdgeUser: + return m.cleareduser + } + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *UserPlatformQuotaMutation) ClearEdge(name string) error { + switch name { + case userplatformquota.EdgeUser: + m.ClearUser() + return nil + } + return fmt.Errorf("unknown UserPlatformQuota unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *UserPlatformQuotaMutation) ResetEdge(name string) error { + switch name { + case userplatformquota.EdgeUser: + m.ResetUser() + return nil + } + return fmt.Errorf("unknown UserPlatformQuota edge %s", name) +} + // UserSubscriptionMutation represents an operation that mutates the UserSubscription nodes in the graph. type UserSubscriptionMutation struct { config diff --git a/backend/ent/predicate/predicate.go b/backend/ent/predicate/predicate.go index dc86471e..ab4d7d18 100644 --- a/backend/ent/predicate/predicate.go +++ b/backend/ent/predicate/predicate.go @@ -105,5 +105,8 @@ type UserAttributeDefinition func(*sql.Selector) // UserAttributeValue is the predicate function for userattributevalue builders. type UserAttributeValue func(*sql.Selector) +// UserPlatformQuota is the predicate function for userplatformquota builders. +type UserPlatformQuota func(*sql.Selector) + // UserSubscription is the predicate function for usersubscription builders. type UserSubscription func(*sql.Selector) diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index 6d541e2f..aa6130f0 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -39,6 +39,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/userallowedgroup" "github.com/Wei-Shaw/sub2api/ent/userattributedefinition" "github.com/Wei-Shaw/sub2api/ent/userattributevalue" + "github.com/Wei-Shaw/sub2api/ent/userplatformquota" "github.com/Wei-Shaw/sub2api/ent/usersubscription" "github.com/Wei-Shaw/sub2api/internal/domain" ) @@ -1997,6 +1998,56 @@ func init() { userattributevalueDescValue := userattributevalueFields[2].Descriptor() // userattributevalue.DefaultValue holds the default value on creation for the value field. userattributevalue.DefaultValue = userattributevalueDescValue.Default.(string) + userplatformquotaMixin := schema.UserPlatformQuota{}.Mixin() + userplatformquotaMixinHooks1 := userplatformquotaMixin[1].Hooks() + userplatformquota.Hooks[0] = userplatformquotaMixinHooks1[0] + userplatformquotaMixinInters1 := userplatformquotaMixin[1].Interceptors() + userplatformquota.Interceptors[0] = userplatformquotaMixinInters1[0] + userplatformquotaMixinFields0 := userplatformquotaMixin[0].Fields() + _ = userplatformquotaMixinFields0 + userplatformquotaFields := schema.UserPlatformQuota{}.Fields() + _ = userplatformquotaFields + // userplatformquotaDescCreatedAt is the schema descriptor for created_at field. + userplatformquotaDescCreatedAt := userplatformquotaMixinFields0[0].Descriptor() + // userplatformquota.DefaultCreatedAt holds the default value on creation for the created_at field. + userplatformquota.DefaultCreatedAt = userplatformquotaDescCreatedAt.Default.(func() time.Time) + // userplatformquotaDescUpdatedAt is the schema descriptor for updated_at field. + userplatformquotaDescUpdatedAt := userplatformquotaMixinFields0[1].Descriptor() + // userplatformquota.DefaultUpdatedAt holds the default value on creation for the updated_at field. + userplatformquota.DefaultUpdatedAt = userplatformquotaDescUpdatedAt.Default.(func() time.Time) + // userplatformquota.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. + userplatformquota.UpdateDefaultUpdatedAt = userplatformquotaDescUpdatedAt.UpdateDefault.(func() time.Time) + // userplatformquotaDescPlatform is the schema descriptor for platform field. + userplatformquotaDescPlatform := userplatformquotaFields[1].Descriptor() + // userplatformquota.PlatformValidator is a validator for the "platform" field. It is called by the builders before save. + userplatformquota.PlatformValidator = func() func(string) error { + validators := userplatformquotaDescPlatform.Validators + fns := [...]func(string) error{ + validators[0].(func(string) error), + validators[1].(func(string) error), + validators[2].(func(string) error), + } + return func(platform string) error { + for _, fn := range fns { + if err := fn(platform); err != nil { + return err + } + } + return nil + } + }() + // userplatformquotaDescDailyUsageUsd is the schema descriptor for daily_usage_usd field. + userplatformquotaDescDailyUsageUsd := userplatformquotaFields[5].Descriptor() + // userplatformquota.DefaultDailyUsageUsd holds the default value on creation for the daily_usage_usd field. + userplatformquota.DefaultDailyUsageUsd = userplatformquotaDescDailyUsageUsd.Default.(float64) + // userplatformquotaDescWeeklyUsageUsd is the schema descriptor for weekly_usage_usd field. + userplatformquotaDescWeeklyUsageUsd := userplatformquotaFields[6].Descriptor() + // userplatformquota.DefaultWeeklyUsageUsd holds the default value on creation for the weekly_usage_usd field. + userplatformquota.DefaultWeeklyUsageUsd = userplatformquotaDescWeeklyUsageUsd.Default.(float64) + // userplatformquotaDescMonthlyUsageUsd is the schema descriptor for monthly_usage_usd field. + userplatformquotaDescMonthlyUsageUsd := userplatformquotaFields[7].Descriptor() + // userplatformquota.DefaultMonthlyUsageUsd holds the default value on creation for the monthly_usage_usd field. + userplatformquota.DefaultMonthlyUsageUsd = userplatformquotaDescMonthlyUsageUsd.Default.(float64) usersubscriptionMixin := schema.UserSubscription{}.Mixin() usersubscriptionMixinHooks1 := usersubscriptionMixin[1].Hooks() usersubscription.Hooks[0] = usersubscriptionMixinHooks1[0] diff --git a/backend/ent/schema/user.go b/backend/ent/schema/user.go index c6e04273..127b5af9 100644 --- a/backend/ent/schema/user.go +++ b/backend/ent/schema/user.go @@ -131,6 +131,7 @@ func (User) Edges() []ent.Edge { edge.To("auth_identities", AuthIdentity.Type). Annotations(entsql.OnDelete(entsql.Cascade)), edge.To("pending_auth_sessions", PendingAuthSession.Type), + edge.To("platform_quotas", UserPlatformQuota.Type), } } diff --git a/backend/ent/schema/user_platform_quota.go b/backend/ent/schema/user_platform_quota.go new file mode 100644 index 00000000..8fd8acc0 --- /dev/null +++ b/backend/ent/schema/user_platform_quota.go @@ -0,0 +1,113 @@ +package schema + +import ( + "fmt" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/entsql" + "entgo.io/ent/schema" + "entgo.io/ent/schema/edge" + "entgo.io/ent/schema/field" + "entgo.io/ent/schema/index" + "github.com/Wei-Shaw/sub2api/ent/schema/mixins" +) + +// UserPlatformQuota holds the schema definition for per-user per-platform quota. +type UserPlatformQuota struct { + ent.Schema +} + +func (UserPlatformQuota) Annotations() []schema.Annotation { + return []schema.Annotation{ + entsql.Annotation{Table: "user_platform_quotas"}, + } +} + +func (UserPlatformQuota) Mixin() []ent.Mixin { + return []ent.Mixin{ + mixins.TimeMixin{}, + mixins.SoftDeleteMixin{}, + } +} + +func (UserPlatformQuota) Fields() []ent.Field { + return []ent.Field{ + field.Int64("user_id"), + field.String("platform"). + MaxLen(32). + NotEmpty(). + Validate(func(s string) error { + // 注意:平台列表的单一权威源为 service.AllowedQuotaPlatforms; + // 此处为 ent 构建期约束,需与 service.AllowedQuotaPlatforms 保持同步。 + switch s { + case "anthropic", "openai", "gemini", "antigravity": + return nil + default: + return fmt.Errorf("platform %q is not allowed", s) + } + }), + + // 日 / 周 / 月 USD 上限: + // nil / not set → 无限额(完全放行) + // 0 → 完全禁用(任何请求都会被拒绝,因为 usage >= 0 恒成立) + // > 0 → USD 限额上限 + field.Float("daily_limit_usd"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,10)"}), + field.Float("weekly_limit_usd"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,10)"}), + field.Float("monthly_limit_usd"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,10)"}), + + // 当前窗口已用量(USD,preflight 时与 limit 比较) + field.Float("daily_usage_usd"). + Default(0). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,10)"}), + field.Float("weekly_usage_usd"). + Default(0). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,10)"}), + field.Float("monthly_usage_usd"). + Default(0). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,10)"}), + + // 窗口起点(NULL = 首次还未初始化,由 InitWindowStarts 用 COALESCE 兜底) + field.Time("daily_window_start"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + field.Time("weekly_window_start"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + field.Time("monthly_window_start"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + } +} + +func (UserPlatformQuota) Edges() []ent.Edge { + return []ent.Edge{ + edge.From("user", User.Type). + Ref("platform_quotas"). + Field("user_id"). + Unique(). + Required(), + } +} + +func (UserPlatformQuota) Indexes() []ent.Index { + return []ent.Index{ + // 软删除友好:只对未删记录唯一 + index.Fields("user_id", "platform"). + Unique(). + Annotations(entsql.IndexWhere("deleted_at IS NULL")), + index.Fields("user_id"), + } +} diff --git a/backend/ent/tx.go b/backend/ent/tx.go index 611028e9..846cfcd4 100644 --- a/backend/ent/tx.go +++ b/backend/ent/tx.go @@ -80,6 +80,8 @@ type Tx struct { UserAttributeDefinition *UserAttributeDefinitionClient // UserAttributeValue is the client for interacting with the UserAttributeValue builders. UserAttributeValue *UserAttributeValueClient + // UserPlatformQuota is the client for interacting with the UserPlatformQuota builders. + UserPlatformQuota *UserPlatformQuotaClient // UserSubscription is the client for interacting with the UserSubscription builders. UserSubscription *UserSubscriptionClient @@ -246,6 +248,7 @@ func (tx *Tx) init() { tx.UserAllowedGroup = NewUserAllowedGroupClient(tx.config) tx.UserAttributeDefinition = NewUserAttributeDefinitionClient(tx.config) tx.UserAttributeValue = NewUserAttributeValueClient(tx.config) + tx.UserPlatformQuota = NewUserPlatformQuotaClient(tx.config) tx.UserSubscription = NewUserSubscriptionClient(tx.config) } diff --git a/backend/ent/user.go b/backend/ent/user.go index 06670444..486f2f64 100644 --- a/backend/ent/user.go +++ b/backend/ent/user.go @@ -95,11 +95,13 @@ type UserEdges struct { AuthIdentities []*AuthIdentity `json:"auth_identities,omitempty"` // PendingAuthSessions holds the value of the pending_auth_sessions edge. PendingAuthSessions []*PendingAuthSession `json:"pending_auth_sessions,omitempty"` + // PlatformQuotas holds the value of the platform_quotas edge. + PlatformQuotas []*UserPlatformQuota `json:"platform_quotas,omitempty"` // UserAllowedGroups holds the value of the user_allowed_groups edge. UserAllowedGroups []*UserAllowedGroup `json:"user_allowed_groups,omitempty"` // loadedTypes holds the information for reporting if a // type was loaded (or requested) in eager-loading or not. - loadedTypes [13]bool + loadedTypes [14]bool } // APIKeysOrErr returns the APIKeys value or an error if the edge @@ -210,10 +212,19 @@ func (e UserEdges) PendingAuthSessionsOrErr() ([]*PendingAuthSession, error) { return nil, &NotLoadedError{edge: "pending_auth_sessions"} } +// PlatformQuotasOrErr returns the PlatformQuotas value or an error if the edge +// was not loaded in eager-loading. +func (e UserEdges) PlatformQuotasOrErr() ([]*UserPlatformQuota, error) { + if e.loadedTypes[12] { + return e.PlatformQuotas, nil + } + return nil, &NotLoadedError{edge: "platform_quotas"} +} + // UserAllowedGroupsOrErr returns the UserAllowedGroups value or an error if the edge // was not loaded in eager-loading. func (e UserEdges) UserAllowedGroupsOrErr() ([]*UserAllowedGroup, error) { - if e.loadedTypes[12] { + if e.loadedTypes[13] { return e.UserAllowedGroups, nil } return nil, &NotLoadedError{edge: "user_allowed_groups"} @@ -472,6 +483,11 @@ func (_m *User) QueryPendingAuthSessions() *PendingAuthSessionQuery { return NewUserClient(_m.config).QueryPendingAuthSessions(_m) } +// QueryPlatformQuotas queries the "platform_quotas" edge of the User entity. +func (_m *User) QueryPlatformQuotas() *UserPlatformQuotaQuery { + return NewUserClient(_m.config).QueryPlatformQuotas(_m) +} + // QueryUserAllowedGroups queries the "user_allowed_groups" edge of the User entity. func (_m *User) QueryUserAllowedGroups() *UserAllowedGroupQuery { return NewUserClient(_m.config).QueryUserAllowedGroups(_m) diff --git a/backend/ent/user/user.go b/backend/ent/user/user.go index e11a8a32..ff40445b 100644 --- a/backend/ent/user/user.go +++ b/backend/ent/user/user.go @@ -85,6 +85,8 @@ const ( EdgeAuthIdentities = "auth_identities" // EdgePendingAuthSessions holds the string denoting the pending_auth_sessions edge name in mutations. EdgePendingAuthSessions = "pending_auth_sessions" + // EdgePlatformQuotas holds the string denoting the platform_quotas edge name in mutations. + EdgePlatformQuotas = "platform_quotas" // EdgeUserAllowedGroups holds the string denoting the user_allowed_groups edge name in mutations. EdgeUserAllowedGroups = "user_allowed_groups" // Table holds the table name of the user in the database. @@ -171,6 +173,13 @@ const ( PendingAuthSessionsInverseTable = "pending_auth_sessions" // PendingAuthSessionsColumn is the table column denoting the pending_auth_sessions relation/edge. PendingAuthSessionsColumn = "target_user_id" + // PlatformQuotasTable is the table that holds the platform_quotas relation/edge. + PlatformQuotasTable = "user_platform_quotas" + // PlatformQuotasInverseTable is the table name for the UserPlatformQuota entity. + // It exists in this package in order to avoid circular dependency with the "userplatformquota" package. + PlatformQuotasInverseTable = "user_platform_quotas" + // PlatformQuotasColumn is the table column denoting the platform_quotas relation/edge. + PlatformQuotasColumn = "user_id" // UserAllowedGroupsTable is the table that holds the user_allowed_groups relation/edge. UserAllowedGroupsTable = "user_allowed_groups" // UserAllowedGroupsInverseTable is the table name for the UserAllowedGroup entity. @@ -569,6 +578,20 @@ func ByPendingAuthSessions(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOpti } } +// ByPlatformQuotasCount orders the results by platform_quotas count. +func ByPlatformQuotasCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newPlatformQuotasStep(), opts...) + } +} + +// ByPlatformQuotas orders the results by platform_quotas terms. +func ByPlatformQuotas(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newPlatformQuotasStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} + // ByUserAllowedGroupsCount orders the results by user_allowed_groups count. func ByUserAllowedGroupsCount(opts ...sql.OrderTermOption) OrderOption { return func(s *sql.Selector) { @@ -666,6 +689,13 @@ func newPendingAuthSessionsStep() *sqlgraph.Step { sqlgraph.Edge(sqlgraph.O2M, false, PendingAuthSessionsTable, PendingAuthSessionsColumn), ) } +func newPlatformQuotasStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(PlatformQuotasInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, PlatformQuotasTable, PlatformQuotasColumn), + ) +} func newUserAllowedGroupsStep() *sqlgraph.Step { return sqlgraph.NewStep( sqlgraph.From(Table, FieldID), diff --git a/backend/ent/user/where.go b/backend/ent/user/where.go index 05d3b35b..a18cf497 100644 --- a/backend/ent/user/where.go +++ b/backend/ent/user/where.go @@ -1616,6 +1616,29 @@ func HasPendingAuthSessionsWith(preds ...predicate.PendingAuthSession) predicate }) } +// HasPlatformQuotas applies the HasEdge predicate on the "platform_quotas" edge. +func HasPlatformQuotas() predicate.User { + return predicate.User(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, PlatformQuotasTable, PlatformQuotasColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasPlatformQuotasWith applies the HasEdge predicate on the "platform_quotas" edge with a given conditions (other predicates). +func HasPlatformQuotasWith(preds ...predicate.UserPlatformQuota) predicate.User { + return predicate.User(func(s *sql.Selector) { + step := newPlatformQuotasStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + // HasUserAllowedGroups applies the HasEdge predicate on the "user_allowed_groups" edge. func HasUserAllowedGroups() predicate.User { return predicate.User(func(s *sql.Selector) { diff --git a/backend/ent/user_create.go b/backend/ent/user_create.go index b4161128..92f1bd5e 100644 --- a/backend/ent/user_create.go +++ b/backend/ent/user_create.go @@ -22,6 +22,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/userattributevalue" + "github.com/Wei-Shaw/sub2api/ent/userplatformquota" "github.com/Wei-Shaw/sub2api/ent/usersubscription" ) @@ -519,6 +520,21 @@ func (_c *UserCreate) AddPendingAuthSessions(v ...*PendingAuthSession) *UserCrea return _c.AddPendingAuthSessionIDs(ids...) } +// AddPlatformQuotaIDs adds the "platform_quotas" edge to the UserPlatformQuota entity by IDs. +func (_c *UserCreate) AddPlatformQuotaIDs(ids ...int64) *UserCreate { + _c.mutation.AddPlatformQuotaIDs(ids...) + return _c +} + +// AddPlatformQuotas adds the "platform_quotas" edges to the UserPlatformQuota entity. +func (_c *UserCreate) AddPlatformQuotas(v ...*UserPlatformQuota) *UserCreate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _c.AddPlatformQuotaIDs(ids...) +} + // Mutation returns the UserMutation object of the builder. func (_c *UserCreate) Mutation() *UserMutation { return _c.mutation @@ -1023,6 +1039,22 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) { } _spec.Edges = append(_spec.Edges, edge) } + if nodes := _c.mutation.PlatformQuotasIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.PlatformQuotasTable, + Columns: []string{user.PlatformQuotasColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } return _node, _spec } diff --git a/backend/ent/user_query.go b/backend/ent/user_query.go index f1ee5cfe..86f4eccd 100644 --- a/backend/ent/user_query.go +++ b/backend/ent/user_query.go @@ -26,6 +26,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/userallowedgroup" "github.com/Wei-Shaw/sub2api/ent/userattributevalue" + "github.com/Wei-Shaw/sub2api/ent/userplatformquota" "github.com/Wei-Shaw/sub2api/ent/usersubscription" ) @@ -48,6 +49,7 @@ type UserQuery struct { withPaymentOrders *PaymentOrderQuery withAuthIdentities *AuthIdentityQuery withPendingAuthSessions *PendingAuthSessionQuery + withPlatformQuotas *UserPlatformQuotaQuery withUserAllowedGroups *UserAllowedGroupQuery modifiers []func(*sql.Selector) // intermediate query (i.e. traversal path). @@ -350,6 +352,28 @@ func (_q *UserQuery) QueryPendingAuthSessions() *PendingAuthSessionQuery { return query } +// QueryPlatformQuotas chains the current query on the "platform_quotas" edge. +func (_q *UserQuery) QueryPlatformQuotas() *UserPlatformQuotaQuery { + query := (&UserPlatformQuotaClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(user.Table, user.FieldID, selector), + sqlgraph.To(userplatformquota.Table, userplatformquota.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, user.PlatformQuotasTable, user.PlatformQuotasColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + // QueryUserAllowedGroups chains the current query on the "user_allowed_groups" edge. func (_q *UserQuery) QueryUserAllowedGroups() *UserAllowedGroupQuery { query := (&UserAllowedGroupClient{config: _q.config}).Query() @@ -576,6 +600,7 @@ func (_q *UserQuery) Clone() *UserQuery { withPaymentOrders: _q.withPaymentOrders.Clone(), withAuthIdentities: _q.withAuthIdentities.Clone(), withPendingAuthSessions: _q.withPendingAuthSessions.Clone(), + withPlatformQuotas: _q.withPlatformQuotas.Clone(), withUserAllowedGroups: _q.withUserAllowedGroups.Clone(), // clone intermediate query. sql: _q.sql.Clone(), @@ -715,6 +740,17 @@ func (_q *UserQuery) WithPendingAuthSessions(opts ...func(*PendingAuthSessionQue return _q } +// WithPlatformQuotas tells the query-builder to eager-load the nodes that are connected to +// the "platform_quotas" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *UserQuery) WithPlatformQuotas(opts ...func(*UserPlatformQuotaQuery)) *UserQuery { + query := (&UserPlatformQuotaClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withPlatformQuotas = query + return _q +} + // WithUserAllowedGroups tells the query-builder to eager-load the nodes that are connected to // the "user_allowed_groups" edge. The optional arguments are used to configure the query builder of the edge. func (_q *UserQuery) WithUserAllowedGroups(opts ...func(*UserAllowedGroupQuery)) *UserQuery { @@ -804,7 +840,7 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e var ( nodes = []*User{} _spec = _q.querySpec() - loadedTypes = [13]bool{ + loadedTypes = [14]bool{ _q.withAPIKeys != nil, _q.withRedeemCodes != nil, _q.withSubscriptions != nil, @@ -817,6 +853,7 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e _q.withPaymentOrders != nil, _q.withAuthIdentities != nil, _q.withPendingAuthSessions != nil, + _q.withPlatformQuotas != nil, _q.withUserAllowedGroups != nil, } ) @@ -929,6 +966,13 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e return nil, err } } + if query := _q.withPlatformQuotas; query != nil { + if err := _q.loadPlatformQuotas(ctx, query, nodes, + func(n *User) { n.Edges.PlatformQuotas = []*UserPlatformQuota{} }, + func(n *User, e *UserPlatformQuota) { n.Edges.PlatformQuotas = append(n.Edges.PlatformQuotas, e) }); err != nil { + return nil, err + } + } if query := _q.withUserAllowedGroups; query != nil { if err := _q.loadUserAllowedGroups(ctx, query, nodes, func(n *User) { n.Edges.UserAllowedGroups = []*UserAllowedGroup{} }, @@ -1339,6 +1383,36 @@ func (_q *UserQuery) loadPendingAuthSessions(ctx context.Context, query *Pending } return nil } +func (_q *UserQuery) loadPlatformQuotas(ctx context.Context, query *UserPlatformQuotaQuery, nodes []*User, init func(*User), assign func(*User, *UserPlatformQuota)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int64]*User) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(userplatformquota.FieldUserID) + } + query.Where(predicate.UserPlatformQuota(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(user.PlatformQuotasColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.UserID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "user_id" returned %v for node %v`, fk, n.ID) + } + assign(node, n) + } + return nil +} func (_q *UserQuery) loadUserAllowedGroups(ctx context.Context, query *UserAllowedGroupQuery, nodes []*User, init func(*User), assign func(*User, *UserAllowedGroup)) error { fks := make([]driver.Value, 0, len(nodes)) nodeids := make(map[int64]*User) diff --git a/backend/ent/user_update.go b/backend/ent/user_update.go index f1d759ce..67d3f8e6 100644 --- a/backend/ent/user_update.go +++ b/backend/ent/user_update.go @@ -23,6 +23,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/userattributevalue" + "github.com/Wei-Shaw/sub2api/ent/userplatformquota" "github.com/Wei-Shaw/sub2api/ent/usersubscription" ) @@ -590,6 +591,21 @@ func (_u *UserUpdate) AddPendingAuthSessions(v ...*PendingAuthSession) *UserUpda return _u.AddPendingAuthSessionIDs(ids...) } +// AddPlatformQuotaIDs adds the "platform_quotas" edge to the UserPlatformQuota entity by IDs. +func (_u *UserUpdate) AddPlatformQuotaIDs(ids ...int64) *UserUpdate { + _u.mutation.AddPlatformQuotaIDs(ids...) + return _u +} + +// AddPlatformQuotas adds the "platform_quotas" edges to the UserPlatformQuota entity. +func (_u *UserUpdate) AddPlatformQuotas(v ...*UserPlatformQuota) *UserUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddPlatformQuotaIDs(ids...) +} + // Mutation returns the UserMutation object of the builder. func (_u *UserUpdate) Mutation() *UserMutation { return _u.mutation @@ -847,6 +863,27 @@ func (_u *UserUpdate) RemovePendingAuthSessions(v ...*PendingAuthSession) *UserU return _u.RemovePendingAuthSessionIDs(ids...) } +// ClearPlatformQuotas clears all "platform_quotas" edges to the UserPlatformQuota entity. +func (_u *UserUpdate) ClearPlatformQuotas() *UserUpdate { + _u.mutation.ClearPlatformQuotas() + return _u +} + +// RemovePlatformQuotaIDs removes the "platform_quotas" edge to UserPlatformQuota entities by IDs. +func (_u *UserUpdate) RemovePlatformQuotaIDs(ids ...int64) *UserUpdate { + _u.mutation.RemovePlatformQuotaIDs(ids...) + return _u +} + +// RemovePlatformQuotas removes "platform_quotas" edges to UserPlatformQuota entities. +func (_u *UserUpdate) RemovePlatformQuotas(v ...*UserPlatformQuota) *UserUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemovePlatformQuotaIDs(ids...) +} + // Save executes the query and returns the number of nodes affected by the update operation. func (_u *UserUpdate) Save(ctx context.Context) (int, error) { if err := _u.defaults(); err != nil { @@ -1587,6 +1624,51 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) { } _spec.Edges.Add = append(_spec.Edges.Add, edge) } + if _u.mutation.PlatformQuotasCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.PlatformQuotasTable, + Columns: []string{user.PlatformQuotasColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedPlatformQuotasIDs(); len(nodes) > 0 && !_u.mutation.PlatformQuotasCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.PlatformQuotasTable, + Columns: []string{user.PlatformQuotasColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.PlatformQuotasIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.PlatformQuotasTable, + Columns: []string{user.PlatformQuotasColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{user.Label} @@ -2158,6 +2240,21 @@ func (_u *UserUpdateOne) AddPendingAuthSessions(v ...*PendingAuthSession) *UserU return _u.AddPendingAuthSessionIDs(ids...) } +// AddPlatformQuotaIDs adds the "platform_quotas" edge to the UserPlatformQuota entity by IDs. +func (_u *UserUpdateOne) AddPlatformQuotaIDs(ids ...int64) *UserUpdateOne { + _u.mutation.AddPlatformQuotaIDs(ids...) + return _u +} + +// AddPlatformQuotas adds the "platform_quotas" edges to the UserPlatformQuota entity. +func (_u *UserUpdateOne) AddPlatformQuotas(v ...*UserPlatformQuota) *UserUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddPlatformQuotaIDs(ids...) +} + // Mutation returns the UserMutation object of the builder. func (_u *UserUpdateOne) Mutation() *UserMutation { return _u.mutation @@ -2415,6 +2512,27 @@ func (_u *UserUpdateOne) RemovePendingAuthSessions(v ...*PendingAuthSession) *Us return _u.RemovePendingAuthSessionIDs(ids...) } +// ClearPlatformQuotas clears all "platform_quotas" edges to the UserPlatformQuota entity. +func (_u *UserUpdateOne) ClearPlatformQuotas() *UserUpdateOne { + _u.mutation.ClearPlatformQuotas() + return _u +} + +// RemovePlatformQuotaIDs removes the "platform_quotas" edge to UserPlatformQuota entities by IDs. +func (_u *UserUpdateOne) RemovePlatformQuotaIDs(ids ...int64) *UserUpdateOne { + _u.mutation.RemovePlatformQuotaIDs(ids...) + return _u +} + +// RemovePlatformQuotas removes "platform_quotas" edges to UserPlatformQuota entities. +func (_u *UserUpdateOne) RemovePlatformQuotas(v ...*UserPlatformQuota) *UserUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemovePlatformQuotaIDs(ids...) +} + // Where appends a list predicates to the UserUpdate builder. func (_u *UserUpdateOne) Where(ps ...predicate.User) *UserUpdateOne { _u.mutation.Where(ps...) @@ -3185,6 +3303,51 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) { } _spec.Edges.Add = append(_spec.Edges.Add, edge) } + if _u.mutation.PlatformQuotasCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.PlatformQuotasTable, + Columns: []string{user.PlatformQuotasColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedPlatformQuotasIDs(); len(nodes) > 0 && !_u.mutation.PlatformQuotasCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.PlatformQuotasTable, + Columns: []string{user.PlatformQuotasColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.PlatformQuotasIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.PlatformQuotasTable, + Columns: []string{user.PlatformQuotasColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } _node = &User{config: _u.config} _spec.Assign = _node.assignValues _spec.ScanValues = _node.scanValues diff --git a/backend/ent/userplatformquota.go b/backend/ent/userplatformquota.go new file mode 100644 index 00000000..d00a1f9a --- /dev/null +++ b/backend/ent/userplatformquota.go @@ -0,0 +1,301 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/user" + "github.com/Wei-Shaw/sub2api/ent/userplatformquota" +) + +// UserPlatformQuota is the model entity for the UserPlatformQuota schema. +type UserPlatformQuota struct { + config `json:"-"` + // ID of the ent. + ID int64 `json:"id,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // UpdatedAt holds the value of the "updated_at" field. + UpdatedAt time.Time `json:"updated_at,omitempty"` + // DeletedAt holds the value of the "deleted_at" field. + DeletedAt *time.Time `json:"deleted_at,omitempty"` + // UserID holds the value of the "user_id" field. + UserID int64 `json:"user_id,omitempty"` + // Platform holds the value of the "platform" field. + Platform string `json:"platform,omitempty"` + // DailyLimitUsd holds the value of the "daily_limit_usd" field. + DailyLimitUsd *float64 `json:"daily_limit_usd,omitempty"` + // WeeklyLimitUsd holds the value of the "weekly_limit_usd" field. + WeeklyLimitUsd *float64 `json:"weekly_limit_usd,omitempty"` + // MonthlyLimitUsd holds the value of the "monthly_limit_usd" field. + MonthlyLimitUsd *float64 `json:"monthly_limit_usd,omitempty"` + // DailyUsageUsd holds the value of the "daily_usage_usd" field. + DailyUsageUsd float64 `json:"daily_usage_usd,omitempty"` + // WeeklyUsageUsd holds the value of the "weekly_usage_usd" field. + WeeklyUsageUsd float64 `json:"weekly_usage_usd,omitempty"` + // MonthlyUsageUsd holds the value of the "monthly_usage_usd" field. + MonthlyUsageUsd float64 `json:"monthly_usage_usd,omitempty"` + // DailyWindowStart holds the value of the "daily_window_start" field. + DailyWindowStart *time.Time `json:"daily_window_start,omitempty"` + // WeeklyWindowStart holds the value of the "weekly_window_start" field. + WeeklyWindowStart *time.Time `json:"weekly_window_start,omitempty"` + // MonthlyWindowStart holds the value of the "monthly_window_start" field. + MonthlyWindowStart *time.Time `json:"monthly_window_start,omitempty"` + // Edges holds the relations/edges for other nodes in the graph. + // The values are being populated by the UserPlatformQuotaQuery when eager-loading is set. + Edges UserPlatformQuotaEdges `json:"edges"` + selectValues sql.SelectValues +} + +// UserPlatformQuotaEdges holds the relations/edges for other nodes in the graph. +type UserPlatformQuotaEdges struct { + // User holds the value of the user edge. + User *User `json:"user,omitempty"` + // loadedTypes holds the information for reporting if a + // type was loaded (or requested) in eager-loading or not. + loadedTypes [1]bool +} + +// UserOrErr returns the User value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e UserPlatformQuotaEdges) UserOrErr() (*User, error) { + if e.User != nil { + return e.User, nil + } else if e.loadedTypes[0] { + return nil, &NotFoundError{label: user.Label} + } + return nil, &NotLoadedError{edge: "user"} +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*UserPlatformQuota) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case userplatformquota.FieldDailyLimitUsd, userplatformquota.FieldWeeklyLimitUsd, userplatformquota.FieldMonthlyLimitUsd, userplatformquota.FieldDailyUsageUsd, userplatformquota.FieldWeeklyUsageUsd, userplatformquota.FieldMonthlyUsageUsd: + values[i] = new(sql.NullFloat64) + case userplatformquota.FieldID, userplatformquota.FieldUserID: + values[i] = new(sql.NullInt64) + case userplatformquota.FieldPlatform: + values[i] = new(sql.NullString) + case userplatformquota.FieldCreatedAt, userplatformquota.FieldUpdatedAt, userplatformquota.FieldDeletedAt, userplatformquota.FieldDailyWindowStart, userplatformquota.FieldWeeklyWindowStart, userplatformquota.FieldMonthlyWindowStart: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the UserPlatformQuota fields. +func (_m *UserPlatformQuota) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case userplatformquota.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + _m.ID = int64(value.Int64) + case userplatformquota.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + _m.CreatedAt = value.Time + } + case userplatformquota.FieldUpdatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field updated_at", values[i]) + } else if value.Valid { + _m.UpdatedAt = value.Time + } + case userplatformquota.FieldDeletedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field deleted_at", values[i]) + } else if value.Valid { + _m.DeletedAt = new(time.Time) + *_m.DeletedAt = value.Time + } + case userplatformquota.FieldUserID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field user_id", values[i]) + } else if value.Valid { + _m.UserID = value.Int64 + } + case userplatformquota.FieldPlatform: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field platform", values[i]) + } else if value.Valid { + _m.Platform = value.String + } + case userplatformquota.FieldDailyLimitUsd: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field daily_limit_usd", values[i]) + } else if value.Valid { + _m.DailyLimitUsd = new(float64) + *_m.DailyLimitUsd = value.Float64 + } + case userplatformquota.FieldWeeklyLimitUsd: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field weekly_limit_usd", values[i]) + } else if value.Valid { + _m.WeeklyLimitUsd = new(float64) + *_m.WeeklyLimitUsd = value.Float64 + } + case userplatformquota.FieldMonthlyLimitUsd: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field monthly_limit_usd", values[i]) + } else if value.Valid { + _m.MonthlyLimitUsd = new(float64) + *_m.MonthlyLimitUsd = value.Float64 + } + case userplatformquota.FieldDailyUsageUsd: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field daily_usage_usd", values[i]) + } else if value.Valid { + _m.DailyUsageUsd = value.Float64 + } + case userplatformquota.FieldWeeklyUsageUsd: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field weekly_usage_usd", values[i]) + } else if value.Valid { + _m.WeeklyUsageUsd = value.Float64 + } + case userplatformquota.FieldMonthlyUsageUsd: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field monthly_usage_usd", values[i]) + } else if value.Valid { + _m.MonthlyUsageUsd = value.Float64 + } + case userplatformquota.FieldDailyWindowStart: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field daily_window_start", values[i]) + } else if value.Valid { + _m.DailyWindowStart = new(time.Time) + *_m.DailyWindowStart = value.Time + } + case userplatformquota.FieldWeeklyWindowStart: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field weekly_window_start", values[i]) + } else if value.Valid { + _m.WeeklyWindowStart = new(time.Time) + *_m.WeeklyWindowStart = value.Time + } + case userplatformquota.FieldMonthlyWindowStart: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field monthly_window_start", values[i]) + } else if value.Valid { + _m.MonthlyWindowStart = new(time.Time) + *_m.MonthlyWindowStart = value.Time + } + default: + _m.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the UserPlatformQuota. +// This includes values selected through modifiers, order, etc. +func (_m *UserPlatformQuota) Value(name string) (ent.Value, error) { + return _m.selectValues.Get(name) +} + +// QueryUser queries the "user" edge of the UserPlatformQuota entity. +func (_m *UserPlatformQuota) QueryUser() *UserQuery { + return NewUserPlatformQuotaClient(_m.config).QueryUser(_m) +} + +// Update returns a builder for updating this UserPlatformQuota. +// Note that you need to call UserPlatformQuota.Unwrap() before calling this method if this UserPlatformQuota +// was returned from a transaction, and the transaction was committed or rolled back. +func (_m *UserPlatformQuota) Update() *UserPlatformQuotaUpdateOne { + return NewUserPlatformQuotaClient(_m.config).UpdateOne(_m) +} + +// Unwrap unwraps the UserPlatformQuota entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (_m *UserPlatformQuota) Unwrap() *UserPlatformQuota { + _tx, ok := _m.config.driver.(*txDriver) + if !ok { + panic("ent: UserPlatformQuota is not a transactional entity") + } + _m.config.driver = _tx.drv + return _m +} + +// String implements the fmt.Stringer. +func (_m *UserPlatformQuota) String() string { + var builder strings.Builder + builder.WriteString("UserPlatformQuota(") + builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID)) + builder.WriteString("created_at=") + builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("updated_at=") + builder.WriteString(_m.UpdatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + if v := _m.DeletedAt; v != nil { + builder.WriteString("deleted_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + builder.WriteString("user_id=") + builder.WriteString(fmt.Sprintf("%v", _m.UserID)) + builder.WriteString(", ") + builder.WriteString("platform=") + builder.WriteString(_m.Platform) + builder.WriteString(", ") + if v := _m.DailyLimitUsd; v != nil { + builder.WriteString("daily_limit_usd=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") + if v := _m.WeeklyLimitUsd; v != nil { + builder.WriteString("weekly_limit_usd=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") + if v := _m.MonthlyLimitUsd; v != nil { + builder.WriteString("monthly_limit_usd=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") + builder.WriteString("daily_usage_usd=") + builder.WriteString(fmt.Sprintf("%v", _m.DailyUsageUsd)) + builder.WriteString(", ") + builder.WriteString("weekly_usage_usd=") + builder.WriteString(fmt.Sprintf("%v", _m.WeeklyUsageUsd)) + builder.WriteString(", ") + builder.WriteString("monthly_usage_usd=") + builder.WriteString(fmt.Sprintf("%v", _m.MonthlyUsageUsd)) + builder.WriteString(", ") + if v := _m.DailyWindowStart; v != nil { + builder.WriteString("daily_window_start=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + if v := _m.WeeklyWindowStart; v != nil { + builder.WriteString("weekly_window_start=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + if v := _m.MonthlyWindowStart; v != nil { + builder.WriteString("monthly_window_start=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteByte(')') + return builder.String() +} + +// UserPlatformQuotaSlice is a parsable slice of UserPlatformQuota. +type UserPlatformQuotaSlice []*UserPlatformQuota diff --git a/backend/ent/userplatformquota/userplatformquota.go b/backend/ent/userplatformquota/userplatformquota.go new file mode 100644 index 00000000..01903853 --- /dev/null +++ b/backend/ent/userplatformquota/userplatformquota.go @@ -0,0 +1,202 @@ +// Code generated by ent, DO NOT EDIT. + +package userplatformquota + +import ( + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" +) + +const ( + // Label holds the string label denoting the userplatformquota type in the database. + Label = "user_platform_quota" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // FieldUpdatedAt holds the string denoting the updated_at field in the database. + FieldUpdatedAt = "updated_at" + // FieldDeletedAt holds the string denoting the deleted_at field in the database. + FieldDeletedAt = "deleted_at" + // FieldUserID holds the string denoting the user_id field in the database. + FieldUserID = "user_id" + // FieldPlatform holds the string denoting the platform field in the database. + FieldPlatform = "platform" + // FieldDailyLimitUsd holds the string denoting the daily_limit_usd field in the database. + FieldDailyLimitUsd = "daily_limit_usd" + // FieldWeeklyLimitUsd holds the string denoting the weekly_limit_usd field in the database. + FieldWeeklyLimitUsd = "weekly_limit_usd" + // FieldMonthlyLimitUsd holds the string denoting the monthly_limit_usd field in the database. + FieldMonthlyLimitUsd = "monthly_limit_usd" + // FieldDailyUsageUsd holds the string denoting the daily_usage_usd field in the database. + FieldDailyUsageUsd = "daily_usage_usd" + // FieldWeeklyUsageUsd holds the string denoting the weekly_usage_usd field in the database. + FieldWeeklyUsageUsd = "weekly_usage_usd" + // FieldMonthlyUsageUsd holds the string denoting the monthly_usage_usd field in the database. + FieldMonthlyUsageUsd = "monthly_usage_usd" + // FieldDailyWindowStart holds the string denoting the daily_window_start field in the database. + FieldDailyWindowStart = "daily_window_start" + // FieldWeeklyWindowStart holds the string denoting the weekly_window_start field in the database. + FieldWeeklyWindowStart = "weekly_window_start" + // FieldMonthlyWindowStart holds the string denoting the monthly_window_start field in the database. + FieldMonthlyWindowStart = "monthly_window_start" + // EdgeUser holds the string denoting the user edge name in mutations. + EdgeUser = "user" + // Table holds the table name of the userplatformquota in the database. + Table = "user_platform_quotas" + // UserTable is the table that holds the user relation/edge. + UserTable = "user_platform_quotas" + // UserInverseTable is the table name for the User entity. + // It exists in this package in order to avoid circular dependency with the "user" package. + UserInverseTable = "users" + // UserColumn is the table column denoting the user relation/edge. + UserColumn = "user_id" +) + +// Columns holds all SQL columns for userplatformquota fields. +var Columns = []string{ + FieldID, + FieldCreatedAt, + FieldUpdatedAt, + FieldDeletedAt, + FieldUserID, + FieldPlatform, + FieldDailyLimitUsd, + FieldWeeklyLimitUsd, + FieldMonthlyLimitUsd, + FieldDailyUsageUsd, + FieldWeeklyUsageUsd, + FieldMonthlyUsageUsd, + FieldDailyWindowStart, + FieldWeeklyWindowStart, + FieldMonthlyWindowStart, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +// Note that the variables below are initialized by the runtime +// package on the initialization of the application. Therefore, +// it should be imported in the main as follows: +// +// import _ "github.com/Wei-Shaw/sub2api/ent/runtime" +var ( + Hooks [1]ent.Hook + Interceptors [1]ent.Interceptor + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time + // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. + DefaultUpdatedAt func() time.Time + // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. + UpdateDefaultUpdatedAt func() time.Time + // PlatformValidator is a validator for the "platform" field. It is called by the builders before save. + PlatformValidator func(string) error + // DefaultDailyUsageUsd holds the default value on creation for the "daily_usage_usd" field. + DefaultDailyUsageUsd float64 + // DefaultWeeklyUsageUsd holds the default value on creation for the "weekly_usage_usd" field. + DefaultWeeklyUsageUsd float64 + // DefaultMonthlyUsageUsd holds the default value on creation for the "monthly_usage_usd" field. + DefaultMonthlyUsageUsd float64 +) + +// OrderOption defines the ordering options for the UserPlatformQuota queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByDeletedAt orders the results by the deleted_at field. +func ByDeletedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldDeletedAt, opts...).ToFunc() +} + +// ByUserID orders the results by the user_id field. +func ByUserID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUserID, opts...).ToFunc() +} + +// ByPlatform orders the results by the platform field. +func ByPlatform(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldPlatform, opts...).ToFunc() +} + +// ByDailyLimitUsd orders the results by the daily_limit_usd field. +func ByDailyLimitUsd(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldDailyLimitUsd, opts...).ToFunc() +} + +// ByWeeklyLimitUsd orders the results by the weekly_limit_usd field. +func ByWeeklyLimitUsd(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldWeeklyLimitUsd, opts...).ToFunc() +} + +// ByMonthlyLimitUsd orders the results by the monthly_limit_usd field. +func ByMonthlyLimitUsd(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldMonthlyLimitUsd, opts...).ToFunc() +} + +// ByDailyUsageUsd orders the results by the daily_usage_usd field. +func ByDailyUsageUsd(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldDailyUsageUsd, opts...).ToFunc() +} + +// ByWeeklyUsageUsd orders the results by the weekly_usage_usd field. +func ByWeeklyUsageUsd(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldWeeklyUsageUsd, opts...).ToFunc() +} + +// ByMonthlyUsageUsd orders the results by the monthly_usage_usd field. +func ByMonthlyUsageUsd(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldMonthlyUsageUsd, opts...).ToFunc() +} + +// ByDailyWindowStart orders the results by the daily_window_start field. +func ByDailyWindowStart(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldDailyWindowStart, opts...).ToFunc() +} + +// ByWeeklyWindowStart orders the results by the weekly_window_start field. +func ByWeeklyWindowStart(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldWeeklyWindowStart, opts...).ToFunc() +} + +// ByMonthlyWindowStart orders the results by the monthly_window_start field. +func ByMonthlyWindowStart(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldMonthlyWindowStart, opts...).ToFunc() +} + +// ByUserField orders the results by user field. +func ByUserField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newUserStep(), sql.OrderByField(field, opts...)) + } +} +func newUserStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(UserInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn), + ) +} diff --git a/backend/ent/userplatformquota/where.go b/backend/ent/userplatformquota/where.go new file mode 100644 index 00000000..37d371c6 --- /dev/null +++ b/backend/ent/userplatformquota/where.go @@ -0,0 +1,799 @@ +// Code generated by ent, DO NOT EDIT. + +package userplatformquota + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldLTE(FieldID, id)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldCreatedAt, v)) +} + +// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. +func UpdatedAt(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// DeletedAt applies equality check predicate on the "deleted_at" field. It's identical to DeletedAtEQ. +func DeletedAt(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldDeletedAt, v)) +} + +// UserID applies equality check predicate on the "user_id" field. It's identical to UserIDEQ. +func UserID(v int64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldUserID, v)) +} + +// Platform applies equality check predicate on the "platform" field. It's identical to PlatformEQ. +func Platform(v string) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldPlatform, v)) +} + +// DailyLimitUsd applies equality check predicate on the "daily_limit_usd" field. It's identical to DailyLimitUsdEQ. +func DailyLimitUsd(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldDailyLimitUsd, v)) +} + +// WeeklyLimitUsd applies equality check predicate on the "weekly_limit_usd" field. It's identical to WeeklyLimitUsdEQ. +func WeeklyLimitUsd(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldWeeklyLimitUsd, v)) +} + +// MonthlyLimitUsd applies equality check predicate on the "monthly_limit_usd" field. It's identical to MonthlyLimitUsdEQ. +func MonthlyLimitUsd(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldMonthlyLimitUsd, v)) +} + +// DailyUsageUsd applies equality check predicate on the "daily_usage_usd" field. It's identical to DailyUsageUsdEQ. +func DailyUsageUsd(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldDailyUsageUsd, v)) +} + +// WeeklyUsageUsd applies equality check predicate on the "weekly_usage_usd" field. It's identical to WeeklyUsageUsdEQ. +func WeeklyUsageUsd(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldWeeklyUsageUsd, v)) +} + +// MonthlyUsageUsd applies equality check predicate on the "monthly_usage_usd" field. It's identical to MonthlyUsageUsdEQ. +func MonthlyUsageUsd(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldMonthlyUsageUsd, v)) +} + +// DailyWindowStart applies equality check predicate on the "daily_window_start" field. It's identical to DailyWindowStartEQ. +func DailyWindowStart(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldDailyWindowStart, v)) +} + +// WeeklyWindowStart applies equality check predicate on the "weekly_window_start" field. It's identical to WeeklyWindowStartEQ. +func WeeklyWindowStart(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldWeeklyWindowStart, v)) +} + +// MonthlyWindowStart applies equality check predicate on the "monthly_window_start" field. It's identical to MonthlyWindowStartEQ. +func MonthlyWindowStart(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldMonthlyWindowStart, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldLTE(FieldCreatedAt, v)) +} + +// UpdatedAtEQ applies the EQ predicate on the "updated_at" field. +func UpdatedAtEQ(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. +func UpdatedAtNEQ(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtIn applies the In predicate on the "updated_at" field. +func UpdatedAtIn(vs ...time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. +func UpdatedAtNotIn(vs ...time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNotIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtGT applies the GT predicate on the "updated_at" field. +func UpdatedAtGT(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldGT(FieldUpdatedAt, v)) +} + +// UpdatedAtGTE applies the GTE predicate on the "updated_at" field. +func UpdatedAtGTE(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldGTE(FieldUpdatedAt, v)) +} + +// UpdatedAtLT applies the LT predicate on the "updated_at" field. +func UpdatedAtLT(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldLT(FieldUpdatedAt, v)) +} + +// UpdatedAtLTE applies the LTE predicate on the "updated_at" field. +func UpdatedAtLTE(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldLTE(FieldUpdatedAt, v)) +} + +// DeletedAtEQ applies the EQ predicate on the "deleted_at" field. +func DeletedAtEQ(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldDeletedAt, v)) +} + +// DeletedAtNEQ applies the NEQ predicate on the "deleted_at" field. +func DeletedAtNEQ(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNEQ(FieldDeletedAt, v)) +} + +// DeletedAtIn applies the In predicate on the "deleted_at" field. +func DeletedAtIn(vs ...time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldIn(FieldDeletedAt, vs...)) +} + +// DeletedAtNotIn applies the NotIn predicate on the "deleted_at" field. +func DeletedAtNotIn(vs ...time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNotIn(FieldDeletedAt, vs...)) +} + +// DeletedAtGT applies the GT predicate on the "deleted_at" field. +func DeletedAtGT(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldGT(FieldDeletedAt, v)) +} + +// DeletedAtGTE applies the GTE predicate on the "deleted_at" field. +func DeletedAtGTE(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldGTE(FieldDeletedAt, v)) +} + +// DeletedAtLT applies the LT predicate on the "deleted_at" field. +func DeletedAtLT(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldLT(FieldDeletedAt, v)) +} + +// DeletedAtLTE applies the LTE predicate on the "deleted_at" field. +func DeletedAtLTE(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldLTE(FieldDeletedAt, v)) +} + +// DeletedAtIsNil applies the IsNil predicate on the "deleted_at" field. +func DeletedAtIsNil() predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldIsNull(FieldDeletedAt)) +} + +// DeletedAtNotNil applies the NotNil predicate on the "deleted_at" field. +func DeletedAtNotNil() predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNotNull(FieldDeletedAt)) +} + +// UserIDEQ applies the EQ predicate on the "user_id" field. +func UserIDEQ(v int64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldUserID, v)) +} + +// UserIDNEQ applies the NEQ predicate on the "user_id" field. +func UserIDNEQ(v int64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNEQ(FieldUserID, v)) +} + +// UserIDIn applies the In predicate on the "user_id" field. +func UserIDIn(vs ...int64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldIn(FieldUserID, vs...)) +} + +// UserIDNotIn applies the NotIn predicate on the "user_id" field. +func UserIDNotIn(vs ...int64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNotIn(FieldUserID, vs...)) +} + +// PlatformEQ applies the EQ predicate on the "platform" field. +func PlatformEQ(v string) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldPlatform, v)) +} + +// PlatformNEQ applies the NEQ predicate on the "platform" field. +func PlatformNEQ(v string) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNEQ(FieldPlatform, v)) +} + +// PlatformIn applies the In predicate on the "platform" field. +func PlatformIn(vs ...string) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldIn(FieldPlatform, vs...)) +} + +// PlatformNotIn applies the NotIn predicate on the "platform" field. +func PlatformNotIn(vs ...string) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNotIn(FieldPlatform, vs...)) +} + +// PlatformGT applies the GT predicate on the "platform" field. +func PlatformGT(v string) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldGT(FieldPlatform, v)) +} + +// PlatformGTE applies the GTE predicate on the "platform" field. +func PlatformGTE(v string) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldGTE(FieldPlatform, v)) +} + +// PlatformLT applies the LT predicate on the "platform" field. +func PlatformLT(v string) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldLT(FieldPlatform, v)) +} + +// PlatformLTE applies the LTE predicate on the "platform" field. +func PlatformLTE(v string) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldLTE(FieldPlatform, v)) +} + +// PlatformContains applies the Contains predicate on the "platform" field. +func PlatformContains(v string) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldContains(FieldPlatform, v)) +} + +// PlatformHasPrefix applies the HasPrefix predicate on the "platform" field. +func PlatformHasPrefix(v string) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldHasPrefix(FieldPlatform, v)) +} + +// PlatformHasSuffix applies the HasSuffix predicate on the "platform" field. +func PlatformHasSuffix(v string) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldHasSuffix(FieldPlatform, v)) +} + +// PlatformEqualFold applies the EqualFold predicate on the "platform" field. +func PlatformEqualFold(v string) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEqualFold(FieldPlatform, v)) +} + +// PlatformContainsFold applies the ContainsFold predicate on the "platform" field. +func PlatformContainsFold(v string) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldContainsFold(FieldPlatform, v)) +} + +// DailyLimitUsdEQ applies the EQ predicate on the "daily_limit_usd" field. +func DailyLimitUsdEQ(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldDailyLimitUsd, v)) +} + +// DailyLimitUsdNEQ applies the NEQ predicate on the "daily_limit_usd" field. +func DailyLimitUsdNEQ(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNEQ(FieldDailyLimitUsd, v)) +} + +// DailyLimitUsdIn applies the In predicate on the "daily_limit_usd" field. +func DailyLimitUsdIn(vs ...float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldIn(FieldDailyLimitUsd, vs...)) +} + +// DailyLimitUsdNotIn applies the NotIn predicate on the "daily_limit_usd" field. +func DailyLimitUsdNotIn(vs ...float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNotIn(FieldDailyLimitUsd, vs...)) +} + +// DailyLimitUsdGT applies the GT predicate on the "daily_limit_usd" field. +func DailyLimitUsdGT(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldGT(FieldDailyLimitUsd, v)) +} + +// DailyLimitUsdGTE applies the GTE predicate on the "daily_limit_usd" field. +func DailyLimitUsdGTE(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldGTE(FieldDailyLimitUsd, v)) +} + +// DailyLimitUsdLT applies the LT predicate on the "daily_limit_usd" field. +func DailyLimitUsdLT(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldLT(FieldDailyLimitUsd, v)) +} + +// DailyLimitUsdLTE applies the LTE predicate on the "daily_limit_usd" field. +func DailyLimitUsdLTE(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldLTE(FieldDailyLimitUsd, v)) +} + +// DailyLimitUsdIsNil applies the IsNil predicate on the "daily_limit_usd" field. +func DailyLimitUsdIsNil() predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldIsNull(FieldDailyLimitUsd)) +} + +// DailyLimitUsdNotNil applies the NotNil predicate on the "daily_limit_usd" field. +func DailyLimitUsdNotNil() predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNotNull(FieldDailyLimitUsd)) +} + +// WeeklyLimitUsdEQ applies the EQ predicate on the "weekly_limit_usd" field. +func WeeklyLimitUsdEQ(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldWeeklyLimitUsd, v)) +} + +// WeeklyLimitUsdNEQ applies the NEQ predicate on the "weekly_limit_usd" field. +func WeeklyLimitUsdNEQ(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNEQ(FieldWeeklyLimitUsd, v)) +} + +// WeeklyLimitUsdIn applies the In predicate on the "weekly_limit_usd" field. +func WeeklyLimitUsdIn(vs ...float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldIn(FieldWeeklyLimitUsd, vs...)) +} + +// WeeklyLimitUsdNotIn applies the NotIn predicate on the "weekly_limit_usd" field. +func WeeklyLimitUsdNotIn(vs ...float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNotIn(FieldWeeklyLimitUsd, vs...)) +} + +// WeeklyLimitUsdGT applies the GT predicate on the "weekly_limit_usd" field. +func WeeklyLimitUsdGT(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldGT(FieldWeeklyLimitUsd, v)) +} + +// WeeklyLimitUsdGTE applies the GTE predicate on the "weekly_limit_usd" field. +func WeeklyLimitUsdGTE(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldGTE(FieldWeeklyLimitUsd, v)) +} + +// WeeklyLimitUsdLT applies the LT predicate on the "weekly_limit_usd" field. +func WeeklyLimitUsdLT(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldLT(FieldWeeklyLimitUsd, v)) +} + +// WeeklyLimitUsdLTE applies the LTE predicate on the "weekly_limit_usd" field. +func WeeklyLimitUsdLTE(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldLTE(FieldWeeklyLimitUsd, v)) +} + +// WeeklyLimitUsdIsNil applies the IsNil predicate on the "weekly_limit_usd" field. +func WeeklyLimitUsdIsNil() predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldIsNull(FieldWeeklyLimitUsd)) +} + +// WeeklyLimitUsdNotNil applies the NotNil predicate on the "weekly_limit_usd" field. +func WeeklyLimitUsdNotNil() predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNotNull(FieldWeeklyLimitUsd)) +} + +// MonthlyLimitUsdEQ applies the EQ predicate on the "monthly_limit_usd" field. +func MonthlyLimitUsdEQ(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldMonthlyLimitUsd, v)) +} + +// MonthlyLimitUsdNEQ applies the NEQ predicate on the "monthly_limit_usd" field. +func MonthlyLimitUsdNEQ(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNEQ(FieldMonthlyLimitUsd, v)) +} + +// MonthlyLimitUsdIn applies the In predicate on the "monthly_limit_usd" field. +func MonthlyLimitUsdIn(vs ...float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldIn(FieldMonthlyLimitUsd, vs...)) +} + +// MonthlyLimitUsdNotIn applies the NotIn predicate on the "monthly_limit_usd" field. +func MonthlyLimitUsdNotIn(vs ...float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNotIn(FieldMonthlyLimitUsd, vs...)) +} + +// MonthlyLimitUsdGT applies the GT predicate on the "monthly_limit_usd" field. +func MonthlyLimitUsdGT(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldGT(FieldMonthlyLimitUsd, v)) +} + +// MonthlyLimitUsdGTE applies the GTE predicate on the "monthly_limit_usd" field. +func MonthlyLimitUsdGTE(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldGTE(FieldMonthlyLimitUsd, v)) +} + +// MonthlyLimitUsdLT applies the LT predicate on the "monthly_limit_usd" field. +func MonthlyLimitUsdLT(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldLT(FieldMonthlyLimitUsd, v)) +} + +// MonthlyLimitUsdLTE applies the LTE predicate on the "monthly_limit_usd" field. +func MonthlyLimitUsdLTE(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldLTE(FieldMonthlyLimitUsd, v)) +} + +// MonthlyLimitUsdIsNil applies the IsNil predicate on the "monthly_limit_usd" field. +func MonthlyLimitUsdIsNil() predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldIsNull(FieldMonthlyLimitUsd)) +} + +// MonthlyLimitUsdNotNil applies the NotNil predicate on the "monthly_limit_usd" field. +func MonthlyLimitUsdNotNil() predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNotNull(FieldMonthlyLimitUsd)) +} + +// DailyUsageUsdEQ applies the EQ predicate on the "daily_usage_usd" field. +func DailyUsageUsdEQ(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldDailyUsageUsd, v)) +} + +// DailyUsageUsdNEQ applies the NEQ predicate on the "daily_usage_usd" field. +func DailyUsageUsdNEQ(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNEQ(FieldDailyUsageUsd, v)) +} + +// DailyUsageUsdIn applies the In predicate on the "daily_usage_usd" field. +func DailyUsageUsdIn(vs ...float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldIn(FieldDailyUsageUsd, vs...)) +} + +// DailyUsageUsdNotIn applies the NotIn predicate on the "daily_usage_usd" field. +func DailyUsageUsdNotIn(vs ...float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNotIn(FieldDailyUsageUsd, vs...)) +} + +// DailyUsageUsdGT applies the GT predicate on the "daily_usage_usd" field. +func DailyUsageUsdGT(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldGT(FieldDailyUsageUsd, v)) +} + +// DailyUsageUsdGTE applies the GTE predicate on the "daily_usage_usd" field. +func DailyUsageUsdGTE(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldGTE(FieldDailyUsageUsd, v)) +} + +// DailyUsageUsdLT applies the LT predicate on the "daily_usage_usd" field. +func DailyUsageUsdLT(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldLT(FieldDailyUsageUsd, v)) +} + +// DailyUsageUsdLTE applies the LTE predicate on the "daily_usage_usd" field. +func DailyUsageUsdLTE(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldLTE(FieldDailyUsageUsd, v)) +} + +// WeeklyUsageUsdEQ applies the EQ predicate on the "weekly_usage_usd" field. +func WeeklyUsageUsdEQ(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldWeeklyUsageUsd, v)) +} + +// WeeklyUsageUsdNEQ applies the NEQ predicate on the "weekly_usage_usd" field. +func WeeklyUsageUsdNEQ(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNEQ(FieldWeeklyUsageUsd, v)) +} + +// WeeklyUsageUsdIn applies the In predicate on the "weekly_usage_usd" field. +func WeeklyUsageUsdIn(vs ...float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldIn(FieldWeeklyUsageUsd, vs...)) +} + +// WeeklyUsageUsdNotIn applies the NotIn predicate on the "weekly_usage_usd" field. +func WeeklyUsageUsdNotIn(vs ...float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNotIn(FieldWeeklyUsageUsd, vs...)) +} + +// WeeklyUsageUsdGT applies the GT predicate on the "weekly_usage_usd" field. +func WeeklyUsageUsdGT(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldGT(FieldWeeklyUsageUsd, v)) +} + +// WeeklyUsageUsdGTE applies the GTE predicate on the "weekly_usage_usd" field. +func WeeklyUsageUsdGTE(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldGTE(FieldWeeklyUsageUsd, v)) +} + +// WeeklyUsageUsdLT applies the LT predicate on the "weekly_usage_usd" field. +func WeeklyUsageUsdLT(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldLT(FieldWeeklyUsageUsd, v)) +} + +// WeeklyUsageUsdLTE applies the LTE predicate on the "weekly_usage_usd" field. +func WeeklyUsageUsdLTE(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldLTE(FieldWeeklyUsageUsd, v)) +} + +// MonthlyUsageUsdEQ applies the EQ predicate on the "monthly_usage_usd" field. +func MonthlyUsageUsdEQ(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldMonthlyUsageUsd, v)) +} + +// MonthlyUsageUsdNEQ applies the NEQ predicate on the "monthly_usage_usd" field. +func MonthlyUsageUsdNEQ(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNEQ(FieldMonthlyUsageUsd, v)) +} + +// MonthlyUsageUsdIn applies the In predicate on the "monthly_usage_usd" field. +func MonthlyUsageUsdIn(vs ...float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldIn(FieldMonthlyUsageUsd, vs...)) +} + +// MonthlyUsageUsdNotIn applies the NotIn predicate on the "monthly_usage_usd" field. +func MonthlyUsageUsdNotIn(vs ...float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNotIn(FieldMonthlyUsageUsd, vs...)) +} + +// MonthlyUsageUsdGT applies the GT predicate on the "monthly_usage_usd" field. +func MonthlyUsageUsdGT(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldGT(FieldMonthlyUsageUsd, v)) +} + +// MonthlyUsageUsdGTE applies the GTE predicate on the "monthly_usage_usd" field. +func MonthlyUsageUsdGTE(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldGTE(FieldMonthlyUsageUsd, v)) +} + +// MonthlyUsageUsdLT applies the LT predicate on the "monthly_usage_usd" field. +func MonthlyUsageUsdLT(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldLT(FieldMonthlyUsageUsd, v)) +} + +// MonthlyUsageUsdLTE applies the LTE predicate on the "monthly_usage_usd" field. +func MonthlyUsageUsdLTE(v float64) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldLTE(FieldMonthlyUsageUsd, v)) +} + +// DailyWindowStartEQ applies the EQ predicate on the "daily_window_start" field. +func DailyWindowStartEQ(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldDailyWindowStart, v)) +} + +// DailyWindowStartNEQ applies the NEQ predicate on the "daily_window_start" field. +func DailyWindowStartNEQ(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNEQ(FieldDailyWindowStart, v)) +} + +// DailyWindowStartIn applies the In predicate on the "daily_window_start" field. +func DailyWindowStartIn(vs ...time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldIn(FieldDailyWindowStart, vs...)) +} + +// DailyWindowStartNotIn applies the NotIn predicate on the "daily_window_start" field. +func DailyWindowStartNotIn(vs ...time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNotIn(FieldDailyWindowStart, vs...)) +} + +// DailyWindowStartGT applies the GT predicate on the "daily_window_start" field. +func DailyWindowStartGT(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldGT(FieldDailyWindowStart, v)) +} + +// DailyWindowStartGTE applies the GTE predicate on the "daily_window_start" field. +func DailyWindowStartGTE(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldGTE(FieldDailyWindowStart, v)) +} + +// DailyWindowStartLT applies the LT predicate on the "daily_window_start" field. +func DailyWindowStartLT(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldLT(FieldDailyWindowStart, v)) +} + +// DailyWindowStartLTE applies the LTE predicate on the "daily_window_start" field. +func DailyWindowStartLTE(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldLTE(FieldDailyWindowStart, v)) +} + +// DailyWindowStartIsNil applies the IsNil predicate on the "daily_window_start" field. +func DailyWindowStartIsNil() predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldIsNull(FieldDailyWindowStart)) +} + +// DailyWindowStartNotNil applies the NotNil predicate on the "daily_window_start" field. +func DailyWindowStartNotNil() predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNotNull(FieldDailyWindowStart)) +} + +// WeeklyWindowStartEQ applies the EQ predicate on the "weekly_window_start" field. +func WeeklyWindowStartEQ(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldWeeklyWindowStart, v)) +} + +// WeeklyWindowStartNEQ applies the NEQ predicate on the "weekly_window_start" field. +func WeeklyWindowStartNEQ(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNEQ(FieldWeeklyWindowStart, v)) +} + +// WeeklyWindowStartIn applies the In predicate on the "weekly_window_start" field. +func WeeklyWindowStartIn(vs ...time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldIn(FieldWeeklyWindowStart, vs...)) +} + +// WeeklyWindowStartNotIn applies the NotIn predicate on the "weekly_window_start" field. +func WeeklyWindowStartNotIn(vs ...time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNotIn(FieldWeeklyWindowStart, vs...)) +} + +// WeeklyWindowStartGT applies the GT predicate on the "weekly_window_start" field. +func WeeklyWindowStartGT(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldGT(FieldWeeklyWindowStart, v)) +} + +// WeeklyWindowStartGTE applies the GTE predicate on the "weekly_window_start" field. +func WeeklyWindowStartGTE(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldGTE(FieldWeeklyWindowStart, v)) +} + +// WeeklyWindowStartLT applies the LT predicate on the "weekly_window_start" field. +func WeeklyWindowStartLT(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldLT(FieldWeeklyWindowStart, v)) +} + +// WeeklyWindowStartLTE applies the LTE predicate on the "weekly_window_start" field. +func WeeklyWindowStartLTE(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldLTE(FieldWeeklyWindowStart, v)) +} + +// WeeklyWindowStartIsNil applies the IsNil predicate on the "weekly_window_start" field. +func WeeklyWindowStartIsNil() predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldIsNull(FieldWeeklyWindowStart)) +} + +// WeeklyWindowStartNotNil applies the NotNil predicate on the "weekly_window_start" field. +func WeeklyWindowStartNotNil() predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNotNull(FieldWeeklyWindowStart)) +} + +// MonthlyWindowStartEQ applies the EQ predicate on the "monthly_window_start" field. +func MonthlyWindowStartEQ(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldEQ(FieldMonthlyWindowStart, v)) +} + +// MonthlyWindowStartNEQ applies the NEQ predicate on the "monthly_window_start" field. +func MonthlyWindowStartNEQ(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNEQ(FieldMonthlyWindowStart, v)) +} + +// MonthlyWindowStartIn applies the In predicate on the "monthly_window_start" field. +func MonthlyWindowStartIn(vs ...time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldIn(FieldMonthlyWindowStart, vs...)) +} + +// MonthlyWindowStartNotIn applies the NotIn predicate on the "monthly_window_start" field. +func MonthlyWindowStartNotIn(vs ...time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNotIn(FieldMonthlyWindowStart, vs...)) +} + +// MonthlyWindowStartGT applies the GT predicate on the "monthly_window_start" field. +func MonthlyWindowStartGT(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldGT(FieldMonthlyWindowStart, v)) +} + +// MonthlyWindowStartGTE applies the GTE predicate on the "monthly_window_start" field. +func MonthlyWindowStartGTE(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldGTE(FieldMonthlyWindowStart, v)) +} + +// MonthlyWindowStartLT applies the LT predicate on the "monthly_window_start" field. +func MonthlyWindowStartLT(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldLT(FieldMonthlyWindowStart, v)) +} + +// MonthlyWindowStartLTE applies the LTE predicate on the "monthly_window_start" field. +func MonthlyWindowStartLTE(v time.Time) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldLTE(FieldMonthlyWindowStart, v)) +} + +// MonthlyWindowStartIsNil applies the IsNil predicate on the "monthly_window_start" field. +func MonthlyWindowStartIsNil() predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldIsNull(FieldMonthlyWindowStart)) +} + +// MonthlyWindowStartNotNil applies the NotNil predicate on the "monthly_window_start" field. +func MonthlyWindowStartNotNil() predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.FieldNotNull(FieldMonthlyWindowStart)) +} + +// HasUser applies the HasEdge predicate on the "user" edge. +func HasUser() predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasUserWith applies the HasEdge predicate on the "user" edge with a given conditions (other predicates). +func HasUserWith(preds ...predicate.User) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(func(s *sql.Selector) { + step := newUserStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.UserPlatformQuota) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.UserPlatformQuota) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.UserPlatformQuota) predicate.UserPlatformQuota { + return predicate.UserPlatformQuota(sql.NotPredicates(p)) +} diff --git a/backend/ent/userplatformquota_create.go b/backend/ent/userplatformquota_create.go new file mode 100644 index 00000000..da6c3ce6 --- /dev/null +++ b/backend/ent/userplatformquota_create.go @@ -0,0 +1,1513 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/user" + "github.com/Wei-Shaw/sub2api/ent/userplatformquota" +) + +// UserPlatformQuotaCreate is the builder for creating a UserPlatformQuota entity. +type UserPlatformQuotaCreate struct { + config + mutation *UserPlatformQuotaMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetCreatedAt sets the "created_at" field. +func (_c *UserPlatformQuotaCreate) SetCreatedAt(v time.Time) *UserPlatformQuotaCreate { + _c.mutation.SetCreatedAt(v) + return _c +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (_c *UserPlatformQuotaCreate) SetNillableCreatedAt(v *time.Time) *UserPlatformQuotaCreate { + if v != nil { + _c.SetCreatedAt(*v) + } + return _c +} + +// SetUpdatedAt sets the "updated_at" field. +func (_c *UserPlatformQuotaCreate) SetUpdatedAt(v time.Time) *UserPlatformQuotaCreate { + _c.mutation.SetUpdatedAt(v) + return _c +} + +// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil. +func (_c *UserPlatformQuotaCreate) SetNillableUpdatedAt(v *time.Time) *UserPlatformQuotaCreate { + if v != nil { + _c.SetUpdatedAt(*v) + } + return _c +} + +// SetDeletedAt sets the "deleted_at" field. +func (_c *UserPlatformQuotaCreate) SetDeletedAt(v time.Time) *UserPlatformQuotaCreate { + _c.mutation.SetDeletedAt(v) + return _c +} + +// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil. +func (_c *UserPlatformQuotaCreate) SetNillableDeletedAt(v *time.Time) *UserPlatformQuotaCreate { + if v != nil { + _c.SetDeletedAt(*v) + } + return _c +} + +// SetUserID sets the "user_id" field. +func (_c *UserPlatformQuotaCreate) SetUserID(v int64) *UserPlatformQuotaCreate { + _c.mutation.SetUserID(v) + return _c +} + +// SetPlatform sets the "platform" field. +func (_c *UserPlatformQuotaCreate) SetPlatform(v string) *UserPlatformQuotaCreate { + _c.mutation.SetPlatform(v) + return _c +} + +// SetDailyLimitUsd sets the "daily_limit_usd" field. +func (_c *UserPlatformQuotaCreate) SetDailyLimitUsd(v float64) *UserPlatformQuotaCreate { + _c.mutation.SetDailyLimitUsd(v) + return _c +} + +// SetNillableDailyLimitUsd sets the "daily_limit_usd" field if the given value is not nil. +func (_c *UserPlatformQuotaCreate) SetNillableDailyLimitUsd(v *float64) *UserPlatformQuotaCreate { + if v != nil { + _c.SetDailyLimitUsd(*v) + } + return _c +} + +// SetWeeklyLimitUsd sets the "weekly_limit_usd" field. +func (_c *UserPlatformQuotaCreate) SetWeeklyLimitUsd(v float64) *UserPlatformQuotaCreate { + _c.mutation.SetWeeklyLimitUsd(v) + return _c +} + +// SetNillableWeeklyLimitUsd sets the "weekly_limit_usd" field if the given value is not nil. +func (_c *UserPlatformQuotaCreate) SetNillableWeeklyLimitUsd(v *float64) *UserPlatformQuotaCreate { + if v != nil { + _c.SetWeeklyLimitUsd(*v) + } + return _c +} + +// SetMonthlyLimitUsd sets the "monthly_limit_usd" field. +func (_c *UserPlatformQuotaCreate) SetMonthlyLimitUsd(v float64) *UserPlatformQuotaCreate { + _c.mutation.SetMonthlyLimitUsd(v) + return _c +} + +// SetNillableMonthlyLimitUsd sets the "monthly_limit_usd" field if the given value is not nil. +func (_c *UserPlatformQuotaCreate) SetNillableMonthlyLimitUsd(v *float64) *UserPlatformQuotaCreate { + if v != nil { + _c.SetMonthlyLimitUsd(*v) + } + return _c +} + +// SetDailyUsageUsd sets the "daily_usage_usd" field. +func (_c *UserPlatformQuotaCreate) SetDailyUsageUsd(v float64) *UserPlatformQuotaCreate { + _c.mutation.SetDailyUsageUsd(v) + return _c +} + +// SetNillableDailyUsageUsd sets the "daily_usage_usd" field if the given value is not nil. +func (_c *UserPlatformQuotaCreate) SetNillableDailyUsageUsd(v *float64) *UserPlatformQuotaCreate { + if v != nil { + _c.SetDailyUsageUsd(*v) + } + return _c +} + +// SetWeeklyUsageUsd sets the "weekly_usage_usd" field. +func (_c *UserPlatformQuotaCreate) SetWeeklyUsageUsd(v float64) *UserPlatformQuotaCreate { + _c.mutation.SetWeeklyUsageUsd(v) + return _c +} + +// SetNillableWeeklyUsageUsd sets the "weekly_usage_usd" field if the given value is not nil. +func (_c *UserPlatformQuotaCreate) SetNillableWeeklyUsageUsd(v *float64) *UserPlatformQuotaCreate { + if v != nil { + _c.SetWeeklyUsageUsd(*v) + } + return _c +} + +// SetMonthlyUsageUsd sets the "monthly_usage_usd" field. +func (_c *UserPlatformQuotaCreate) SetMonthlyUsageUsd(v float64) *UserPlatformQuotaCreate { + _c.mutation.SetMonthlyUsageUsd(v) + return _c +} + +// SetNillableMonthlyUsageUsd sets the "monthly_usage_usd" field if the given value is not nil. +func (_c *UserPlatformQuotaCreate) SetNillableMonthlyUsageUsd(v *float64) *UserPlatformQuotaCreate { + if v != nil { + _c.SetMonthlyUsageUsd(*v) + } + return _c +} + +// SetDailyWindowStart sets the "daily_window_start" field. +func (_c *UserPlatformQuotaCreate) SetDailyWindowStart(v time.Time) *UserPlatformQuotaCreate { + _c.mutation.SetDailyWindowStart(v) + return _c +} + +// SetNillableDailyWindowStart sets the "daily_window_start" field if the given value is not nil. +func (_c *UserPlatformQuotaCreate) SetNillableDailyWindowStart(v *time.Time) *UserPlatformQuotaCreate { + if v != nil { + _c.SetDailyWindowStart(*v) + } + return _c +} + +// SetWeeklyWindowStart sets the "weekly_window_start" field. +func (_c *UserPlatformQuotaCreate) SetWeeklyWindowStart(v time.Time) *UserPlatformQuotaCreate { + _c.mutation.SetWeeklyWindowStart(v) + return _c +} + +// SetNillableWeeklyWindowStart sets the "weekly_window_start" field if the given value is not nil. +func (_c *UserPlatformQuotaCreate) SetNillableWeeklyWindowStart(v *time.Time) *UserPlatformQuotaCreate { + if v != nil { + _c.SetWeeklyWindowStart(*v) + } + return _c +} + +// SetMonthlyWindowStart sets the "monthly_window_start" field. +func (_c *UserPlatformQuotaCreate) SetMonthlyWindowStart(v time.Time) *UserPlatformQuotaCreate { + _c.mutation.SetMonthlyWindowStart(v) + return _c +} + +// SetNillableMonthlyWindowStart sets the "monthly_window_start" field if the given value is not nil. +func (_c *UserPlatformQuotaCreate) SetNillableMonthlyWindowStart(v *time.Time) *UserPlatformQuotaCreate { + if v != nil { + _c.SetMonthlyWindowStart(*v) + } + return _c +} + +// SetUser sets the "user" edge to the User entity. +func (_c *UserPlatformQuotaCreate) SetUser(v *User) *UserPlatformQuotaCreate { + return _c.SetUserID(v.ID) +} + +// Mutation returns the UserPlatformQuotaMutation object of the builder. +func (_c *UserPlatformQuotaCreate) Mutation() *UserPlatformQuotaMutation { + return _c.mutation +} + +// Save creates the UserPlatformQuota in the database. +func (_c *UserPlatformQuotaCreate) Save(ctx context.Context) (*UserPlatformQuota, error) { + if err := _c.defaults(); err != nil { + return nil, err + } + return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (_c *UserPlatformQuotaCreate) SaveX(ctx context.Context) *UserPlatformQuota { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *UserPlatformQuotaCreate) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *UserPlatformQuotaCreate) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_c *UserPlatformQuotaCreate) defaults() error { + if _, ok := _c.mutation.CreatedAt(); !ok { + if userplatformquota.DefaultCreatedAt == nil { + return fmt.Errorf("ent: uninitialized userplatformquota.DefaultCreatedAt (forgotten import ent/runtime?)") + } + v := userplatformquota.DefaultCreatedAt() + _c.mutation.SetCreatedAt(v) + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + if userplatformquota.DefaultUpdatedAt == nil { + return fmt.Errorf("ent: uninitialized userplatformquota.DefaultUpdatedAt (forgotten import ent/runtime?)") + } + v := userplatformquota.DefaultUpdatedAt() + _c.mutation.SetUpdatedAt(v) + } + if _, ok := _c.mutation.DailyUsageUsd(); !ok { + v := userplatformquota.DefaultDailyUsageUsd + _c.mutation.SetDailyUsageUsd(v) + } + if _, ok := _c.mutation.WeeklyUsageUsd(); !ok { + v := userplatformquota.DefaultWeeklyUsageUsd + _c.mutation.SetWeeklyUsageUsd(v) + } + if _, ok := _c.mutation.MonthlyUsageUsd(); !ok { + v := userplatformquota.DefaultMonthlyUsageUsd + _c.mutation.SetMonthlyUsageUsd(v) + } + return nil +} + +// check runs all checks and user-defined validators on the builder. +func (_c *UserPlatformQuotaCreate) check() error { + if _, ok := _c.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "UserPlatformQuota.created_at"`)} + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "UserPlatformQuota.updated_at"`)} + } + if _, ok := _c.mutation.UserID(); !ok { + return &ValidationError{Name: "user_id", err: errors.New(`ent: missing required field "UserPlatformQuota.user_id"`)} + } + if _, ok := _c.mutation.Platform(); !ok { + return &ValidationError{Name: "platform", err: errors.New(`ent: missing required field "UserPlatformQuota.platform"`)} + } + if v, ok := _c.mutation.Platform(); ok { + if err := userplatformquota.PlatformValidator(v); err != nil { + return &ValidationError{Name: "platform", err: fmt.Errorf(`ent: validator failed for field "UserPlatformQuota.platform": %w`, err)} + } + } + if _, ok := _c.mutation.DailyUsageUsd(); !ok { + return &ValidationError{Name: "daily_usage_usd", err: errors.New(`ent: missing required field "UserPlatformQuota.daily_usage_usd"`)} + } + if _, ok := _c.mutation.WeeklyUsageUsd(); !ok { + return &ValidationError{Name: "weekly_usage_usd", err: errors.New(`ent: missing required field "UserPlatformQuota.weekly_usage_usd"`)} + } + if _, ok := _c.mutation.MonthlyUsageUsd(); !ok { + return &ValidationError{Name: "monthly_usage_usd", err: errors.New(`ent: missing required field "UserPlatformQuota.monthly_usage_usd"`)} + } + if len(_c.mutation.UserIDs()) == 0 { + return &ValidationError{Name: "user", err: errors.New(`ent: missing required edge "UserPlatformQuota.user"`)} + } + return nil +} + +func (_c *UserPlatformQuotaCreate) sqlSave(ctx context.Context) (*UserPlatformQuota, error) { + if err := _c.check(); err != nil { + return nil, err + } + _node, _spec := _c.createSpec() + if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int64(id) + _c.mutation.id = &_node.ID + _c.mutation.done = true + return _node, nil +} + +func (_c *UserPlatformQuotaCreate) createSpec() (*UserPlatformQuota, *sqlgraph.CreateSpec) { + var ( + _node = &UserPlatformQuota{config: _c.config} + _spec = sqlgraph.NewCreateSpec(userplatformquota.Table, sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64)) + ) + _spec.OnConflict = _c.conflict + if value, ok := _c.mutation.CreatedAt(); ok { + _spec.SetField(userplatformquota.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if value, ok := _c.mutation.UpdatedAt(); ok { + _spec.SetField(userplatformquota.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value + } + if value, ok := _c.mutation.DeletedAt(); ok { + _spec.SetField(userplatformquota.FieldDeletedAt, field.TypeTime, value) + _node.DeletedAt = &value + } + if value, ok := _c.mutation.Platform(); ok { + _spec.SetField(userplatformquota.FieldPlatform, field.TypeString, value) + _node.Platform = value + } + if value, ok := _c.mutation.DailyLimitUsd(); ok { + _spec.SetField(userplatformquota.FieldDailyLimitUsd, field.TypeFloat64, value) + _node.DailyLimitUsd = &value + } + if value, ok := _c.mutation.WeeklyLimitUsd(); ok { + _spec.SetField(userplatformquota.FieldWeeklyLimitUsd, field.TypeFloat64, value) + _node.WeeklyLimitUsd = &value + } + if value, ok := _c.mutation.MonthlyLimitUsd(); ok { + _spec.SetField(userplatformquota.FieldMonthlyLimitUsd, field.TypeFloat64, value) + _node.MonthlyLimitUsd = &value + } + if value, ok := _c.mutation.DailyUsageUsd(); ok { + _spec.SetField(userplatformquota.FieldDailyUsageUsd, field.TypeFloat64, value) + _node.DailyUsageUsd = value + } + if value, ok := _c.mutation.WeeklyUsageUsd(); ok { + _spec.SetField(userplatformquota.FieldWeeklyUsageUsd, field.TypeFloat64, value) + _node.WeeklyUsageUsd = value + } + if value, ok := _c.mutation.MonthlyUsageUsd(); ok { + _spec.SetField(userplatformquota.FieldMonthlyUsageUsd, field.TypeFloat64, value) + _node.MonthlyUsageUsd = value + } + if value, ok := _c.mutation.DailyWindowStart(); ok { + _spec.SetField(userplatformquota.FieldDailyWindowStart, field.TypeTime, value) + _node.DailyWindowStart = &value + } + if value, ok := _c.mutation.WeeklyWindowStart(); ok { + _spec.SetField(userplatformquota.FieldWeeklyWindowStart, field.TypeTime, value) + _node.WeeklyWindowStart = &value + } + if value, ok := _c.mutation.MonthlyWindowStart(); ok { + _spec.SetField(userplatformquota.FieldMonthlyWindowStart, field.TypeTime, value) + _node.MonthlyWindowStart = &value + } + if nodes := _c.mutation.UserIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: userplatformquota.UserTable, + Columns: []string{userplatformquota.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.UserID = nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + return _node, _spec +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.UserPlatformQuota.Create(). +// SetCreatedAt(v). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.UserPlatformQuotaUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *UserPlatformQuotaCreate) OnConflict(opts ...sql.ConflictOption) *UserPlatformQuotaUpsertOne { + _c.conflict = opts + return &UserPlatformQuotaUpsertOne{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.UserPlatformQuota.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *UserPlatformQuotaCreate) OnConflictColumns(columns ...string) *UserPlatformQuotaUpsertOne { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &UserPlatformQuotaUpsertOne{ + create: _c, + } +} + +type ( + // UserPlatformQuotaUpsertOne is the builder for "upsert"-ing + // one UserPlatformQuota node. + UserPlatformQuotaUpsertOne struct { + create *UserPlatformQuotaCreate + } + + // UserPlatformQuotaUpsert is the "OnConflict" setter. + UserPlatformQuotaUpsert struct { + *sql.UpdateSet + } +) + +// SetUpdatedAt sets the "updated_at" field. +func (u *UserPlatformQuotaUpsert) SetUpdatedAt(v time.Time) *UserPlatformQuotaUpsert { + u.Set(userplatformquota.FieldUpdatedAt, v) + return u +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsert) UpdateUpdatedAt() *UserPlatformQuotaUpsert { + u.SetExcluded(userplatformquota.FieldUpdatedAt) + return u +} + +// SetDeletedAt sets the "deleted_at" field. +func (u *UserPlatformQuotaUpsert) SetDeletedAt(v time.Time) *UserPlatformQuotaUpsert { + u.Set(userplatformquota.FieldDeletedAt, v) + return u +} + +// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsert) UpdateDeletedAt() *UserPlatformQuotaUpsert { + u.SetExcluded(userplatformquota.FieldDeletedAt) + return u +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (u *UserPlatformQuotaUpsert) ClearDeletedAt() *UserPlatformQuotaUpsert { + u.SetNull(userplatformquota.FieldDeletedAt) + return u +} + +// SetUserID sets the "user_id" field. +func (u *UserPlatformQuotaUpsert) SetUserID(v int64) *UserPlatformQuotaUpsert { + u.Set(userplatformquota.FieldUserID, v) + return u +} + +// UpdateUserID sets the "user_id" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsert) UpdateUserID() *UserPlatformQuotaUpsert { + u.SetExcluded(userplatformquota.FieldUserID) + return u +} + +// SetPlatform sets the "platform" field. +func (u *UserPlatformQuotaUpsert) SetPlatform(v string) *UserPlatformQuotaUpsert { + u.Set(userplatformquota.FieldPlatform, v) + return u +} + +// UpdatePlatform sets the "platform" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsert) UpdatePlatform() *UserPlatformQuotaUpsert { + u.SetExcluded(userplatformquota.FieldPlatform) + return u +} + +// SetDailyLimitUsd sets the "daily_limit_usd" field. +func (u *UserPlatformQuotaUpsert) SetDailyLimitUsd(v float64) *UserPlatformQuotaUpsert { + u.Set(userplatformquota.FieldDailyLimitUsd, v) + return u +} + +// UpdateDailyLimitUsd sets the "daily_limit_usd" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsert) UpdateDailyLimitUsd() *UserPlatformQuotaUpsert { + u.SetExcluded(userplatformquota.FieldDailyLimitUsd) + return u +} + +// AddDailyLimitUsd adds v to the "daily_limit_usd" field. +func (u *UserPlatformQuotaUpsert) AddDailyLimitUsd(v float64) *UserPlatformQuotaUpsert { + u.Add(userplatformquota.FieldDailyLimitUsd, v) + return u +} + +// ClearDailyLimitUsd clears the value of the "daily_limit_usd" field. +func (u *UserPlatformQuotaUpsert) ClearDailyLimitUsd() *UserPlatformQuotaUpsert { + u.SetNull(userplatformquota.FieldDailyLimitUsd) + return u +} + +// SetWeeklyLimitUsd sets the "weekly_limit_usd" field. +func (u *UserPlatformQuotaUpsert) SetWeeklyLimitUsd(v float64) *UserPlatformQuotaUpsert { + u.Set(userplatformquota.FieldWeeklyLimitUsd, v) + return u +} + +// UpdateWeeklyLimitUsd sets the "weekly_limit_usd" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsert) UpdateWeeklyLimitUsd() *UserPlatformQuotaUpsert { + u.SetExcluded(userplatformquota.FieldWeeklyLimitUsd) + return u +} + +// AddWeeklyLimitUsd adds v to the "weekly_limit_usd" field. +func (u *UserPlatformQuotaUpsert) AddWeeklyLimitUsd(v float64) *UserPlatformQuotaUpsert { + u.Add(userplatformquota.FieldWeeklyLimitUsd, v) + return u +} + +// ClearWeeklyLimitUsd clears the value of the "weekly_limit_usd" field. +func (u *UserPlatformQuotaUpsert) ClearWeeklyLimitUsd() *UserPlatformQuotaUpsert { + u.SetNull(userplatformquota.FieldWeeklyLimitUsd) + return u +} + +// SetMonthlyLimitUsd sets the "monthly_limit_usd" field. +func (u *UserPlatformQuotaUpsert) SetMonthlyLimitUsd(v float64) *UserPlatformQuotaUpsert { + u.Set(userplatformquota.FieldMonthlyLimitUsd, v) + return u +} + +// UpdateMonthlyLimitUsd sets the "monthly_limit_usd" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsert) UpdateMonthlyLimitUsd() *UserPlatformQuotaUpsert { + u.SetExcluded(userplatformquota.FieldMonthlyLimitUsd) + return u +} + +// AddMonthlyLimitUsd adds v to the "monthly_limit_usd" field. +func (u *UserPlatformQuotaUpsert) AddMonthlyLimitUsd(v float64) *UserPlatformQuotaUpsert { + u.Add(userplatformquota.FieldMonthlyLimitUsd, v) + return u +} + +// ClearMonthlyLimitUsd clears the value of the "monthly_limit_usd" field. +func (u *UserPlatformQuotaUpsert) ClearMonthlyLimitUsd() *UserPlatformQuotaUpsert { + u.SetNull(userplatformquota.FieldMonthlyLimitUsd) + return u +} + +// SetDailyUsageUsd sets the "daily_usage_usd" field. +func (u *UserPlatformQuotaUpsert) SetDailyUsageUsd(v float64) *UserPlatformQuotaUpsert { + u.Set(userplatformquota.FieldDailyUsageUsd, v) + return u +} + +// UpdateDailyUsageUsd sets the "daily_usage_usd" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsert) UpdateDailyUsageUsd() *UserPlatformQuotaUpsert { + u.SetExcluded(userplatformquota.FieldDailyUsageUsd) + return u +} + +// AddDailyUsageUsd adds v to the "daily_usage_usd" field. +func (u *UserPlatformQuotaUpsert) AddDailyUsageUsd(v float64) *UserPlatformQuotaUpsert { + u.Add(userplatformquota.FieldDailyUsageUsd, v) + return u +} + +// SetWeeklyUsageUsd sets the "weekly_usage_usd" field. +func (u *UserPlatformQuotaUpsert) SetWeeklyUsageUsd(v float64) *UserPlatformQuotaUpsert { + u.Set(userplatformquota.FieldWeeklyUsageUsd, v) + return u +} + +// UpdateWeeklyUsageUsd sets the "weekly_usage_usd" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsert) UpdateWeeklyUsageUsd() *UserPlatformQuotaUpsert { + u.SetExcluded(userplatformquota.FieldWeeklyUsageUsd) + return u +} + +// AddWeeklyUsageUsd adds v to the "weekly_usage_usd" field. +func (u *UserPlatformQuotaUpsert) AddWeeklyUsageUsd(v float64) *UserPlatformQuotaUpsert { + u.Add(userplatformquota.FieldWeeklyUsageUsd, v) + return u +} + +// SetMonthlyUsageUsd sets the "monthly_usage_usd" field. +func (u *UserPlatformQuotaUpsert) SetMonthlyUsageUsd(v float64) *UserPlatformQuotaUpsert { + u.Set(userplatformquota.FieldMonthlyUsageUsd, v) + return u +} + +// UpdateMonthlyUsageUsd sets the "monthly_usage_usd" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsert) UpdateMonthlyUsageUsd() *UserPlatformQuotaUpsert { + u.SetExcluded(userplatformquota.FieldMonthlyUsageUsd) + return u +} + +// AddMonthlyUsageUsd adds v to the "monthly_usage_usd" field. +func (u *UserPlatformQuotaUpsert) AddMonthlyUsageUsd(v float64) *UserPlatformQuotaUpsert { + u.Add(userplatformquota.FieldMonthlyUsageUsd, v) + return u +} + +// SetDailyWindowStart sets the "daily_window_start" field. +func (u *UserPlatformQuotaUpsert) SetDailyWindowStart(v time.Time) *UserPlatformQuotaUpsert { + u.Set(userplatformquota.FieldDailyWindowStart, v) + return u +} + +// UpdateDailyWindowStart sets the "daily_window_start" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsert) UpdateDailyWindowStart() *UserPlatformQuotaUpsert { + u.SetExcluded(userplatformquota.FieldDailyWindowStart) + return u +} + +// ClearDailyWindowStart clears the value of the "daily_window_start" field. +func (u *UserPlatformQuotaUpsert) ClearDailyWindowStart() *UserPlatformQuotaUpsert { + u.SetNull(userplatformquota.FieldDailyWindowStart) + return u +} + +// SetWeeklyWindowStart sets the "weekly_window_start" field. +func (u *UserPlatformQuotaUpsert) SetWeeklyWindowStart(v time.Time) *UserPlatformQuotaUpsert { + u.Set(userplatformquota.FieldWeeklyWindowStart, v) + return u +} + +// UpdateWeeklyWindowStart sets the "weekly_window_start" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsert) UpdateWeeklyWindowStart() *UserPlatformQuotaUpsert { + u.SetExcluded(userplatformquota.FieldWeeklyWindowStart) + return u +} + +// ClearWeeklyWindowStart clears the value of the "weekly_window_start" field. +func (u *UserPlatformQuotaUpsert) ClearWeeklyWindowStart() *UserPlatformQuotaUpsert { + u.SetNull(userplatformquota.FieldWeeklyWindowStart) + return u +} + +// SetMonthlyWindowStart sets the "monthly_window_start" field. +func (u *UserPlatformQuotaUpsert) SetMonthlyWindowStart(v time.Time) *UserPlatformQuotaUpsert { + u.Set(userplatformquota.FieldMonthlyWindowStart, v) + return u +} + +// UpdateMonthlyWindowStart sets the "monthly_window_start" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsert) UpdateMonthlyWindowStart() *UserPlatformQuotaUpsert { + u.SetExcluded(userplatformquota.FieldMonthlyWindowStart) + return u +} + +// ClearMonthlyWindowStart clears the value of the "monthly_window_start" field. +func (u *UserPlatformQuotaUpsert) ClearMonthlyWindowStart() *UserPlatformQuotaUpsert { + u.SetNull(userplatformquota.FieldMonthlyWindowStart) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create. +// Using this option is equivalent to using: +// +// client.UserPlatformQuota.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *UserPlatformQuotaUpsertOne) UpdateNewValues() *UserPlatformQuotaUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + if _, exists := u.create.mutation.CreatedAt(); exists { + s.SetIgnore(userplatformquota.FieldCreatedAt) + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.UserPlatformQuota.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *UserPlatformQuotaUpsertOne) Ignore() *UserPlatformQuotaUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *UserPlatformQuotaUpsertOne) DoNothing() *UserPlatformQuotaUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the UserPlatformQuotaCreate.OnConflict +// documentation for more info. +func (u *UserPlatformQuotaUpsertOne) Update(set func(*UserPlatformQuotaUpsert)) *UserPlatformQuotaUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&UserPlatformQuotaUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *UserPlatformQuotaUpsertOne) SetUpdatedAt(v time.Time) *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsertOne) UpdateUpdatedAt() *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetDeletedAt sets the "deleted_at" field. +func (u *UserPlatformQuotaUpsertOne) SetDeletedAt(v time.Time) *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.SetDeletedAt(v) + }) +} + +// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsertOne) UpdateDeletedAt() *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.UpdateDeletedAt() + }) +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (u *UserPlatformQuotaUpsertOne) ClearDeletedAt() *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.ClearDeletedAt() + }) +} + +// SetUserID sets the "user_id" field. +func (u *UserPlatformQuotaUpsertOne) SetUserID(v int64) *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.SetUserID(v) + }) +} + +// UpdateUserID sets the "user_id" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsertOne) UpdateUserID() *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.UpdateUserID() + }) +} + +// SetPlatform sets the "platform" field. +func (u *UserPlatformQuotaUpsertOne) SetPlatform(v string) *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.SetPlatform(v) + }) +} + +// UpdatePlatform sets the "platform" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsertOne) UpdatePlatform() *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.UpdatePlatform() + }) +} + +// SetDailyLimitUsd sets the "daily_limit_usd" field. +func (u *UserPlatformQuotaUpsertOne) SetDailyLimitUsd(v float64) *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.SetDailyLimitUsd(v) + }) +} + +// AddDailyLimitUsd adds v to the "daily_limit_usd" field. +func (u *UserPlatformQuotaUpsertOne) AddDailyLimitUsd(v float64) *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.AddDailyLimitUsd(v) + }) +} + +// UpdateDailyLimitUsd sets the "daily_limit_usd" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsertOne) UpdateDailyLimitUsd() *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.UpdateDailyLimitUsd() + }) +} + +// ClearDailyLimitUsd clears the value of the "daily_limit_usd" field. +func (u *UserPlatformQuotaUpsertOne) ClearDailyLimitUsd() *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.ClearDailyLimitUsd() + }) +} + +// SetWeeklyLimitUsd sets the "weekly_limit_usd" field. +func (u *UserPlatformQuotaUpsertOne) SetWeeklyLimitUsd(v float64) *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.SetWeeklyLimitUsd(v) + }) +} + +// AddWeeklyLimitUsd adds v to the "weekly_limit_usd" field. +func (u *UserPlatformQuotaUpsertOne) AddWeeklyLimitUsd(v float64) *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.AddWeeklyLimitUsd(v) + }) +} + +// UpdateWeeklyLimitUsd sets the "weekly_limit_usd" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsertOne) UpdateWeeklyLimitUsd() *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.UpdateWeeklyLimitUsd() + }) +} + +// ClearWeeklyLimitUsd clears the value of the "weekly_limit_usd" field. +func (u *UserPlatformQuotaUpsertOne) ClearWeeklyLimitUsd() *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.ClearWeeklyLimitUsd() + }) +} + +// SetMonthlyLimitUsd sets the "monthly_limit_usd" field. +func (u *UserPlatformQuotaUpsertOne) SetMonthlyLimitUsd(v float64) *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.SetMonthlyLimitUsd(v) + }) +} + +// AddMonthlyLimitUsd adds v to the "monthly_limit_usd" field. +func (u *UserPlatformQuotaUpsertOne) AddMonthlyLimitUsd(v float64) *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.AddMonthlyLimitUsd(v) + }) +} + +// UpdateMonthlyLimitUsd sets the "monthly_limit_usd" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsertOne) UpdateMonthlyLimitUsd() *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.UpdateMonthlyLimitUsd() + }) +} + +// ClearMonthlyLimitUsd clears the value of the "monthly_limit_usd" field. +func (u *UserPlatformQuotaUpsertOne) ClearMonthlyLimitUsd() *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.ClearMonthlyLimitUsd() + }) +} + +// SetDailyUsageUsd sets the "daily_usage_usd" field. +func (u *UserPlatformQuotaUpsertOne) SetDailyUsageUsd(v float64) *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.SetDailyUsageUsd(v) + }) +} + +// AddDailyUsageUsd adds v to the "daily_usage_usd" field. +func (u *UserPlatformQuotaUpsertOne) AddDailyUsageUsd(v float64) *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.AddDailyUsageUsd(v) + }) +} + +// UpdateDailyUsageUsd sets the "daily_usage_usd" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsertOne) UpdateDailyUsageUsd() *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.UpdateDailyUsageUsd() + }) +} + +// SetWeeklyUsageUsd sets the "weekly_usage_usd" field. +func (u *UserPlatformQuotaUpsertOne) SetWeeklyUsageUsd(v float64) *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.SetWeeklyUsageUsd(v) + }) +} + +// AddWeeklyUsageUsd adds v to the "weekly_usage_usd" field. +func (u *UserPlatformQuotaUpsertOne) AddWeeklyUsageUsd(v float64) *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.AddWeeklyUsageUsd(v) + }) +} + +// UpdateWeeklyUsageUsd sets the "weekly_usage_usd" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsertOne) UpdateWeeklyUsageUsd() *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.UpdateWeeklyUsageUsd() + }) +} + +// SetMonthlyUsageUsd sets the "monthly_usage_usd" field. +func (u *UserPlatformQuotaUpsertOne) SetMonthlyUsageUsd(v float64) *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.SetMonthlyUsageUsd(v) + }) +} + +// AddMonthlyUsageUsd adds v to the "monthly_usage_usd" field. +func (u *UserPlatformQuotaUpsertOne) AddMonthlyUsageUsd(v float64) *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.AddMonthlyUsageUsd(v) + }) +} + +// UpdateMonthlyUsageUsd sets the "monthly_usage_usd" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsertOne) UpdateMonthlyUsageUsd() *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.UpdateMonthlyUsageUsd() + }) +} + +// SetDailyWindowStart sets the "daily_window_start" field. +func (u *UserPlatformQuotaUpsertOne) SetDailyWindowStart(v time.Time) *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.SetDailyWindowStart(v) + }) +} + +// UpdateDailyWindowStart sets the "daily_window_start" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsertOne) UpdateDailyWindowStart() *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.UpdateDailyWindowStart() + }) +} + +// ClearDailyWindowStart clears the value of the "daily_window_start" field. +func (u *UserPlatformQuotaUpsertOne) ClearDailyWindowStart() *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.ClearDailyWindowStart() + }) +} + +// SetWeeklyWindowStart sets the "weekly_window_start" field. +func (u *UserPlatformQuotaUpsertOne) SetWeeklyWindowStart(v time.Time) *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.SetWeeklyWindowStart(v) + }) +} + +// UpdateWeeklyWindowStart sets the "weekly_window_start" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsertOne) UpdateWeeklyWindowStart() *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.UpdateWeeklyWindowStart() + }) +} + +// ClearWeeklyWindowStart clears the value of the "weekly_window_start" field. +func (u *UserPlatformQuotaUpsertOne) ClearWeeklyWindowStart() *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.ClearWeeklyWindowStart() + }) +} + +// SetMonthlyWindowStart sets the "monthly_window_start" field. +func (u *UserPlatformQuotaUpsertOne) SetMonthlyWindowStart(v time.Time) *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.SetMonthlyWindowStart(v) + }) +} + +// UpdateMonthlyWindowStart sets the "monthly_window_start" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsertOne) UpdateMonthlyWindowStart() *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.UpdateMonthlyWindowStart() + }) +} + +// ClearMonthlyWindowStart clears the value of the "monthly_window_start" field. +func (u *UserPlatformQuotaUpsertOne) ClearMonthlyWindowStart() *UserPlatformQuotaUpsertOne { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.ClearMonthlyWindowStart() + }) +} + +// Exec executes the query. +func (u *UserPlatformQuotaUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for UserPlatformQuotaCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *UserPlatformQuotaUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. +func (u *UserPlatformQuotaUpsertOne) ID(ctx context.Context) (id int64, err error) { + node, err := u.create.Save(ctx) + if err != nil { + return id, err + } + return node.ID, nil +} + +// IDX is like ID, but panics if an error occurs. +func (u *UserPlatformQuotaUpsertOne) IDX(ctx context.Context) int64 { + id, err := u.ID(ctx) + if err != nil { + panic(err) + } + return id +} + +// UserPlatformQuotaCreateBulk is the builder for creating many UserPlatformQuota entities in bulk. +type UserPlatformQuotaCreateBulk struct { + config + err error + builders []*UserPlatformQuotaCreate + conflict []sql.ConflictOption +} + +// Save creates the UserPlatformQuota entities in the database. +func (_c *UserPlatformQuotaCreateBulk) Save(ctx context.Context) ([]*UserPlatformQuota, error) { + if _c.err != nil { + return nil, _c.err + } + specs := make([]*sqlgraph.CreateSpec, len(_c.builders)) + nodes := make([]*UserPlatformQuota, len(_c.builders)) + mutators := make([]Mutator, len(_c.builders)) + for i := range _c.builders { + func(i int, root context.Context) { + builder := _c.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*UserPlatformQuotaMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + spec.OnConflict = _c.conflict + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int64(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (_c *UserPlatformQuotaCreateBulk) SaveX(ctx context.Context) []*UserPlatformQuota { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *UserPlatformQuotaCreateBulk) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *UserPlatformQuotaCreateBulk) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.UserPlatformQuota.CreateBulk(builders...). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.UserPlatformQuotaUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *UserPlatformQuotaCreateBulk) OnConflict(opts ...sql.ConflictOption) *UserPlatformQuotaUpsertBulk { + _c.conflict = opts + return &UserPlatformQuotaUpsertBulk{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.UserPlatformQuota.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *UserPlatformQuotaCreateBulk) OnConflictColumns(columns ...string) *UserPlatformQuotaUpsertBulk { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &UserPlatformQuotaUpsertBulk{ + create: _c, + } +} + +// UserPlatformQuotaUpsertBulk is the builder for "upsert"-ing +// a bulk of UserPlatformQuota nodes. +type UserPlatformQuotaUpsertBulk struct { + create *UserPlatformQuotaCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.UserPlatformQuota.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *UserPlatformQuotaUpsertBulk) UpdateNewValues() *UserPlatformQuotaUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + for _, b := range u.create.builders { + if _, exists := b.mutation.CreatedAt(); exists { + s.SetIgnore(userplatformquota.FieldCreatedAt) + } + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.UserPlatformQuota.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *UserPlatformQuotaUpsertBulk) Ignore() *UserPlatformQuotaUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *UserPlatformQuotaUpsertBulk) DoNothing() *UserPlatformQuotaUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the UserPlatformQuotaCreateBulk.OnConflict +// documentation for more info. +func (u *UserPlatformQuotaUpsertBulk) Update(set func(*UserPlatformQuotaUpsert)) *UserPlatformQuotaUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&UserPlatformQuotaUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *UserPlatformQuotaUpsertBulk) SetUpdatedAt(v time.Time) *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsertBulk) UpdateUpdatedAt() *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetDeletedAt sets the "deleted_at" field. +func (u *UserPlatformQuotaUpsertBulk) SetDeletedAt(v time.Time) *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.SetDeletedAt(v) + }) +} + +// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsertBulk) UpdateDeletedAt() *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.UpdateDeletedAt() + }) +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (u *UserPlatformQuotaUpsertBulk) ClearDeletedAt() *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.ClearDeletedAt() + }) +} + +// SetUserID sets the "user_id" field. +func (u *UserPlatformQuotaUpsertBulk) SetUserID(v int64) *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.SetUserID(v) + }) +} + +// UpdateUserID sets the "user_id" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsertBulk) UpdateUserID() *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.UpdateUserID() + }) +} + +// SetPlatform sets the "platform" field. +func (u *UserPlatformQuotaUpsertBulk) SetPlatform(v string) *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.SetPlatform(v) + }) +} + +// UpdatePlatform sets the "platform" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsertBulk) UpdatePlatform() *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.UpdatePlatform() + }) +} + +// SetDailyLimitUsd sets the "daily_limit_usd" field. +func (u *UserPlatformQuotaUpsertBulk) SetDailyLimitUsd(v float64) *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.SetDailyLimitUsd(v) + }) +} + +// AddDailyLimitUsd adds v to the "daily_limit_usd" field. +func (u *UserPlatformQuotaUpsertBulk) AddDailyLimitUsd(v float64) *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.AddDailyLimitUsd(v) + }) +} + +// UpdateDailyLimitUsd sets the "daily_limit_usd" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsertBulk) UpdateDailyLimitUsd() *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.UpdateDailyLimitUsd() + }) +} + +// ClearDailyLimitUsd clears the value of the "daily_limit_usd" field. +func (u *UserPlatformQuotaUpsertBulk) ClearDailyLimitUsd() *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.ClearDailyLimitUsd() + }) +} + +// SetWeeklyLimitUsd sets the "weekly_limit_usd" field. +func (u *UserPlatformQuotaUpsertBulk) SetWeeklyLimitUsd(v float64) *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.SetWeeklyLimitUsd(v) + }) +} + +// AddWeeklyLimitUsd adds v to the "weekly_limit_usd" field. +func (u *UserPlatformQuotaUpsertBulk) AddWeeklyLimitUsd(v float64) *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.AddWeeklyLimitUsd(v) + }) +} + +// UpdateWeeklyLimitUsd sets the "weekly_limit_usd" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsertBulk) UpdateWeeklyLimitUsd() *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.UpdateWeeklyLimitUsd() + }) +} + +// ClearWeeklyLimitUsd clears the value of the "weekly_limit_usd" field. +func (u *UserPlatformQuotaUpsertBulk) ClearWeeklyLimitUsd() *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.ClearWeeklyLimitUsd() + }) +} + +// SetMonthlyLimitUsd sets the "monthly_limit_usd" field. +func (u *UserPlatformQuotaUpsertBulk) SetMonthlyLimitUsd(v float64) *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.SetMonthlyLimitUsd(v) + }) +} + +// AddMonthlyLimitUsd adds v to the "monthly_limit_usd" field. +func (u *UserPlatformQuotaUpsertBulk) AddMonthlyLimitUsd(v float64) *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.AddMonthlyLimitUsd(v) + }) +} + +// UpdateMonthlyLimitUsd sets the "monthly_limit_usd" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsertBulk) UpdateMonthlyLimitUsd() *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.UpdateMonthlyLimitUsd() + }) +} + +// ClearMonthlyLimitUsd clears the value of the "monthly_limit_usd" field. +func (u *UserPlatformQuotaUpsertBulk) ClearMonthlyLimitUsd() *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.ClearMonthlyLimitUsd() + }) +} + +// SetDailyUsageUsd sets the "daily_usage_usd" field. +func (u *UserPlatformQuotaUpsertBulk) SetDailyUsageUsd(v float64) *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.SetDailyUsageUsd(v) + }) +} + +// AddDailyUsageUsd adds v to the "daily_usage_usd" field. +func (u *UserPlatformQuotaUpsertBulk) AddDailyUsageUsd(v float64) *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.AddDailyUsageUsd(v) + }) +} + +// UpdateDailyUsageUsd sets the "daily_usage_usd" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsertBulk) UpdateDailyUsageUsd() *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.UpdateDailyUsageUsd() + }) +} + +// SetWeeklyUsageUsd sets the "weekly_usage_usd" field. +func (u *UserPlatformQuotaUpsertBulk) SetWeeklyUsageUsd(v float64) *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.SetWeeklyUsageUsd(v) + }) +} + +// AddWeeklyUsageUsd adds v to the "weekly_usage_usd" field. +func (u *UserPlatformQuotaUpsertBulk) AddWeeklyUsageUsd(v float64) *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.AddWeeklyUsageUsd(v) + }) +} + +// UpdateWeeklyUsageUsd sets the "weekly_usage_usd" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsertBulk) UpdateWeeklyUsageUsd() *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.UpdateWeeklyUsageUsd() + }) +} + +// SetMonthlyUsageUsd sets the "monthly_usage_usd" field. +func (u *UserPlatformQuotaUpsertBulk) SetMonthlyUsageUsd(v float64) *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.SetMonthlyUsageUsd(v) + }) +} + +// AddMonthlyUsageUsd adds v to the "monthly_usage_usd" field. +func (u *UserPlatformQuotaUpsertBulk) AddMonthlyUsageUsd(v float64) *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.AddMonthlyUsageUsd(v) + }) +} + +// UpdateMonthlyUsageUsd sets the "monthly_usage_usd" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsertBulk) UpdateMonthlyUsageUsd() *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.UpdateMonthlyUsageUsd() + }) +} + +// SetDailyWindowStart sets the "daily_window_start" field. +func (u *UserPlatformQuotaUpsertBulk) SetDailyWindowStart(v time.Time) *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.SetDailyWindowStart(v) + }) +} + +// UpdateDailyWindowStart sets the "daily_window_start" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsertBulk) UpdateDailyWindowStart() *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.UpdateDailyWindowStart() + }) +} + +// ClearDailyWindowStart clears the value of the "daily_window_start" field. +func (u *UserPlatformQuotaUpsertBulk) ClearDailyWindowStart() *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.ClearDailyWindowStart() + }) +} + +// SetWeeklyWindowStart sets the "weekly_window_start" field. +func (u *UserPlatformQuotaUpsertBulk) SetWeeklyWindowStart(v time.Time) *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.SetWeeklyWindowStart(v) + }) +} + +// UpdateWeeklyWindowStart sets the "weekly_window_start" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsertBulk) UpdateWeeklyWindowStart() *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.UpdateWeeklyWindowStart() + }) +} + +// ClearWeeklyWindowStart clears the value of the "weekly_window_start" field. +func (u *UserPlatformQuotaUpsertBulk) ClearWeeklyWindowStart() *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.ClearWeeklyWindowStart() + }) +} + +// SetMonthlyWindowStart sets the "monthly_window_start" field. +func (u *UserPlatformQuotaUpsertBulk) SetMonthlyWindowStart(v time.Time) *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.SetMonthlyWindowStart(v) + }) +} + +// UpdateMonthlyWindowStart sets the "monthly_window_start" field to the value that was provided on create. +func (u *UserPlatformQuotaUpsertBulk) UpdateMonthlyWindowStart() *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.UpdateMonthlyWindowStart() + }) +} + +// ClearMonthlyWindowStart clears the value of the "monthly_window_start" field. +func (u *UserPlatformQuotaUpsertBulk) ClearMonthlyWindowStart() *UserPlatformQuotaUpsertBulk { + return u.Update(func(s *UserPlatformQuotaUpsert) { + s.ClearMonthlyWindowStart() + }) +} + +// Exec executes the query. +func (u *UserPlatformQuotaUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the UserPlatformQuotaCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for UserPlatformQuotaCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *UserPlatformQuotaUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/userplatformquota_delete.go b/backend/ent/userplatformquota_delete.go new file mode 100644 index 00000000..1d80e31b --- /dev/null +++ b/backend/ent/userplatformquota_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/userplatformquota" +) + +// UserPlatformQuotaDelete is the builder for deleting a UserPlatformQuota entity. +type UserPlatformQuotaDelete struct { + config + hooks []Hook + mutation *UserPlatformQuotaMutation +} + +// Where appends a list predicates to the UserPlatformQuotaDelete builder. +func (_d *UserPlatformQuotaDelete) Where(ps ...predicate.UserPlatformQuota) *UserPlatformQuotaDelete { + _d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (_d *UserPlatformQuotaDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *UserPlatformQuotaDelete) ExecX(ctx context.Context) int { + n, err := _d.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (_d *UserPlatformQuotaDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(userplatformquota.Table, sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64)) + if ps := _d.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + _d.mutation.done = true + return affected, err +} + +// UserPlatformQuotaDeleteOne is the builder for deleting a single UserPlatformQuota entity. +type UserPlatformQuotaDeleteOne struct { + _d *UserPlatformQuotaDelete +} + +// Where appends a list predicates to the UserPlatformQuotaDelete builder. +func (_d *UserPlatformQuotaDeleteOne) Where(ps ...predicate.UserPlatformQuota) *UserPlatformQuotaDeleteOne { + _d._d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query. +func (_d *UserPlatformQuotaDeleteOne) Exec(ctx context.Context) error { + n, err := _d._d.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{userplatformquota.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *UserPlatformQuotaDeleteOne) ExecX(ctx context.Context) { + if err := _d.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/userplatformquota_query.go b/backend/ent/userplatformquota_query.go new file mode 100644 index 00000000..c68fe904 --- /dev/null +++ b/backend/ent/userplatformquota_query.go @@ -0,0 +1,643 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "fmt" + "math" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/user" + "github.com/Wei-Shaw/sub2api/ent/userplatformquota" +) + +// UserPlatformQuotaQuery is the builder for querying UserPlatformQuota entities. +type UserPlatformQuotaQuery struct { + config + ctx *QueryContext + order []userplatformquota.OrderOption + inters []Interceptor + predicates []predicate.UserPlatformQuota + withUser *UserQuery + modifiers []func(*sql.Selector) + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the UserPlatformQuotaQuery builder. +func (_q *UserPlatformQuotaQuery) Where(ps ...predicate.UserPlatformQuota) *UserPlatformQuotaQuery { + _q.predicates = append(_q.predicates, ps...) + return _q +} + +// Limit the number of records to be returned by this query. +func (_q *UserPlatformQuotaQuery) Limit(limit int) *UserPlatformQuotaQuery { + _q.ctx.Limit = &limit + return _q +} + +// Offset to start from. +func (_q *UserPlatformQuotaQuery) Offset(offset int) *UserPlatformQuotaQuery { + _q.ctx.Offset = &offset + return _q +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (_q *UserPlatformQuotaQuery) Unique(unique bool) *UserPlatformQuotaQuery { + _q.ctx.Unique = &unique + return _q +} + +// Order specifies how the records should be ordered. +func (_q *UserPlatformQuotaQuery) Order(o ...userplatformquota.OrderOption) *UserPlatformQuotaQuery { + _q.order = append(_q.order, o...) + return _q +} + +// QueryUser chains the current query on the "user" edge. +func (_q *UserPlatformQuotaQuery) QueryUser() *UserQuery { + query := (&UserClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(userplatformquota.Table, userplatformquota.FieldID, selector), + sqlgraph.To(user.Table, user.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, userplatformquota.UserTable, userplatformquota.UserColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// First returns the first UserPlatformQuota entity from the query. +// Returns a *NotFoundError when no UserPlatformQuota was found. +func (_q *UserPlatformQuotaQuery) First(ctx context.Context) (*UserPlatformQuota, error) { + nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst)) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{userplatformquota.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (_q *UserPlatformQuotaQuery) FirstX(ctx context.Context) *UserPlatformQuota { + node, err := _q.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first UserPlatformQuota ID from the query. +// Returns a *NotFoundError when no UserPlatformQuota ID was found. +func (_q *UserPlatformQuotaQuery) FirstID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{userplatformquota.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (_q *UserPlatformQuotaQuery) FirstIDX(ctx context.Context) int64 { + id, err := _q.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single UserPlatformQuota entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one UserPlatformQuota entity is found. +// Returns a *NotFoundError when no UserPlatformQuota entities are found. +func (_q *UserPlatformQuotaQuery) Only(ctx context.Context) (*UserPlatformQuota, error) { + nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly)) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{userplatformquota.Label} + default: + return nil, &NotSingularError{userplatformquota.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (_q *UserPlatformQuotaQuery) OnlyX(ctx context.Context) *UserPlatformQuota { + node, err := _q.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only UserPlatformQuota ID in the query. +// Returns a *NotSingularError when more than one UserPlatformQuota ID is found. +// Returns a *NotFoundError when no entities are found. +func (_q *UserPlatformQuotaQuery) OnlyID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{userplatformquota.Label} + default: + err = &NotSingularError{userplatformquota.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (_q *UserPlatformQuotaQuery) OnlyIDX(ctx context.Context) int64 { + id, err := _q.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of UserPlatformQuotaSlice. +func (_q *UserPlatformQuotaQuery) All(ctx context.Context) ([]*UserPlatformQuota, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll) + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*UserPlatformQuota, *UserPlatformQuotaQuery]() + return withInterceptors[[]*UserPlatformQuota](ctx, _q, qr, _q.inters) +} + +// AllX is like All, but panics if an error occurs. +func (_q *UserPlatformQuotaQuery) AllX(ctx context.Context) []*UserPlatformQuota { + nodes, err := _q.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of UserPlatformQuota IDs. +func (_q *UserPlatformQuotaQuery) IDs(ctx context.Context) (ids []int64, err error) { + if _q.ctx.Unique == nil && _q.path != nil { + _q.Unique(true) + } + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs) + if err = _q.Select(userplatformquota.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (_q *UserPlatformQuotaQuery) IDsX(ctx context.Context) []int64 { + ids, err := _q.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (_q *UserPlatformQuotaQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount) + if err := _q.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, _q, querierCount[*UserPlatformQuotaQuery](), _q.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (_q *UserPlatformQuotaQuery) CountX(ctx context.Context) int { + count, err := _q.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (_q *UserPlatformQuotaQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist) + switch _, err := _q.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (_q *UserPlatformQuotaQuery) ExistX(ctx context.Context) bool { + exist, err := _q.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the UserPlatformQuotaQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (_q *UserPlatformQuotaQuery) Clone() *UserPlatformQuotaQuery { + if _q == nil { + return nil + } + return &UserPlatformQuotaQuery{ + config: _q.config, + ctx: _q.ctx.Clone(), + order: append([]userplatformquota.OrderOption{}, _q.order...), + inters: append([]Interceptor{}, _q.inters...), + predicates: append([]predicate.UserPlatformQuota{}, _q.predicates...), + withUser: _q.withUser.Clone(), + // clone intermediate query. + sql: _q.sql.Clone(), + path: _q.path, + } +} + +// WithUser tells the query-builder to eager-load the nodes that are connected to +// the "user" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *UserPlatformQuotaQuery) WithUser(opts ...func(*UserQuery)) *UserPlatformQuotaQuery { + query := (&UserClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withUser = query + return _q +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.UserPlatformQuota.Query(). +// GroupBy(userplatformquota.FieldCreatedAt). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (_q *UserPlatformQuotaQuery) GroupBy(field string, fields ...string) *UserPlatformQuotaGroupBy { + _q.ctx.Fields = append([]string{field}, fields...) + grbuild := &UserPlatformQuotaGroupBy{build: _q} + grbuild.flds = &_q.ctx.Fields + grbuild.label = userplatformquota.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// } +// +// client.UserPlatformQuota.Query(). +// Select(userplatformquota.FieldCreatedAt). +// Scan(ctx, &v) +func (_q *UserPlatformQuotaQuery) Select(fields ...string) *UserPlatformQuotaSelect { + _q.ctx.Fields = append(_q.ctx.Fields, fields...) + sbuild := &UserPlatformQuotaSelect{UserPlatformQuotaQuery: _q} + sbuild.label = userplatformquota.Label + sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a UserPlatformQuotaSelect configured with the given aggregations. +func (_q *UserPlatformQuotaQuery) Aggregate(fns ...AggregateFunc) *UserPlatformQuotaSelect { + return _q.Select().Aggregate(fns...) +} + +func (_q *UserPlatformQuotaQuery) prepareQuery(ctx context.Context) error { + for _, inter := range _q.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, _q); err != nil { + return err + } + } + } + for _, f := range _q.ctx.Fields { + if !userplatformquota.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if _q.path != nil { + prev, err := _q.path(ctx) + if err != nil { + return err + } + _q.sql = prev + } + return nil +} + +func (_q *UserPlatformQuotaQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*UserPlatformQuota, error) { + var ( + nodes = []*UserPlatformQuota{} + _spec = _q.querySpec() + loadedTypes = [1]bool{ + _q.withUser != nil, + } + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*UserPlatformQuota).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &UserPlatformQuota{config: _q.config} + nodes = append(nodes, node) + node.Edges.loadedTypes = loadedTypes + return node.assignValues(columns, values) + } + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + if query := _q.withUser; query != nil { + if err := _q.loadUser(ctx, query, nodes, nil, + func(n *UserPlatformQuota, e *User) { n.Edges.User = e }); err != nil { + return nil, err + } + } + return nodes, nil +} + +func (_q *UserPlatformQuotaQuery) loadUser(ctx context.Context, query *UserQuery, nodes []*UserPlatformQuota, init func(*UserPlatformQuota), assign func(*UserPlatformQuota, *User)) error { + ids := make([]int64, 0, len(nodes)) + nodeids := make(map[int64][]*UserPlatformQuota) + for i := range nodes { + fk := nodes[i].UserID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} + +func (_q *UserPlatformQuotaQuery) sqlCount(ctx context.Context) (int, error) { + _spec := _q.querySpec() + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + _spec.Node.Columns = _q.ctx.Fields + if len(_q.ctx.Fields) > 0 { + _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique + } + return sqlgraph.CountNodes(ctx, _q.driver, _spec) +} + +func (_q *UserPlatformQuotaQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(userplatformquota.Table, userplatformquota.Columns, sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64)) + _spec.From = _q.sql + if unique := _q.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if _q.path != nil { + _spec.Unique = true + } + if fields := _q.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, userplatformquota.FieldID) + for i := range fields { + if fields[i] != userplatformquota.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + if _q.withUser != nil { + _spec.Node.AddColumnOnce(userplatformquota.FieldUserID) + } + } + if ps := _q.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := _q.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := _q.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := _q.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (_q *UserPlatformQuotaQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(_q.driver.Dialect()) + t1 := builder.Table(userplatformquota.Table) + columns := _q.ctx.Fields + if len(columns) == 0 { + columns = userplatformquota.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if _q.sql != nil { + selector = _q.sql + selector.Select(selector.Columns(columns...)...) + } + if _q.ctx.Unique != nil && *_q.ctx.Unique { + selector.Distinct() + } + for _, m := range _q.modifiers { + m(selector) + } + for _, p := range _q.predicates { + p(selector) + } + for _, p := range _q.order { + p(selector) + } + if offset := _q.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := _q.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// ForUpdate locks the selected rows against concurrent updates, and prevent them from being +// updated, deleted or "selected ... for update" by other sessions, until the transaction is +// either committed or rolled-back. +func (_q *UserPlatformQuotaQuery) ForUpdate(opts ...sql.LockOption) *UserPlatformQuotaQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForUpdate(opts...) + }) + return _q +} + +// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock +// on any rows that are read. Other sessions can read the rows, but cannot modify them +// until your transaction commits. +func (_q *UserPlatformQuotaQuery) ForShare(opts ...sql.LockOption) *UserPlatformQuotaQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForShare(opts...) + }) + return _q +} + +// UserPlatformQuotaGroupBy is the group-by builder for UserPlatformQuota entities. +type UserPlatformQuotaGroupBy struct { + selector + build *UserPlatformQuotaQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (_g *UserPlatformQuotaGroupBy) Aggregate(fns ...AggregateFunc) *UserPlatformQuotaGroupBy { + _g.fns = append(_g.fns, fns...) + return _g +} + +// Scan applies the selector query and scans the result into the given value. +func (_g *UserPlatformQuotaGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy) + if err := _g.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*UserPlatformQuotaQuery, *UserPlatformQuotaGroupBy](ctx, _g.build, _g, _g.build.inters, v) +} + +func (_g *UserPlatformQuotaGroupBy) sqlScan(ctx context.Context, root *UserPlatformQuotaQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(_g.fns)) + for _, fn := range _g.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*_g.flds)+len(_g.fns)) + for _, f := range *_g.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*_g.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _g.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// UserPlatformQuotaSelect is the builder for selecting fields of UserPlatformQuota entities. +type UserPlatformQuotaSelect struct { + *UserPlatformQuotaQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (_s *UserPlatformQuotaSelect) Aggregate(fns ...AggregateFunc) *UserPlatformQuotaSelect { + _s.fns = append(_s.fns, fns...) + return _s +} + +// Scan applies the selector query and scans the result into the given value. +func (_s *UserPlatformQuotaSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect) + if err := _s.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*UserPlatformQuotaQuery, *UserPlatformQuotaSelect](ctx, _s.UserPlatformQuotaQuery, _s, _s.inters, v) +} + +func (_s *UserPlatformQuotaSelect) sqlScan(ctx context.Context, root *UserPlatformQuotaQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(_s.fns)) + for _, fn := range _s.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*_s.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _s.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/backend/ent/userplatformquota_update.go b/backend/ent/userplatformquota_update.go new file mode 100644 index 00000000..4e924cc8 --- /dev/null +++ b/backend/ent/userplatformquota_update.go @@ -0,0 +1,985 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/user" + "github.com/Wei-Shaw/sub2api/ent/userplatformquota" +) + +// UserPlatformQuotaUpdate is the builder for updating UserPlatformQuota entities. +type UserPlatformQuotaUpdate struct { + config + hooks []Hook + mutation *UserPlatformQuotaMutation +} + +// Where appends a list predicates to the UserPlatformQuotaUpdate builder. +func (_u *UserPlatformQuotaUpdate) Where(ps ...predicate.UserPlatformQuota) *UserPlatformQuotaUpdate { + _u.mutation.Where(ps...) + return _u +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *UserPlatformQuotaUpdate) SetUpdatedAt(v time.Time) *UserPlatformQuotaUpdate { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetDeletedAt sets the "deleted_at" field. +func (_u *UserPlatformQuotaUpdate) SetDeletedAt(v time.Time) *UserPlatformQuotaUpdate { + _u.mutation.SetDeletedAt(v) + return _u +} + +// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil. +func (_u *UserPlatformQuotaUpdate) SetNillableDeletedAt(v *time.Time) *UserPlatformQuotaUpdate { + if v != nil { + _u.SetDeletedAt(*v) + } + return _u +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (_u *UserPlatformQuotaUpdate) ClearDeletedAt() *UserPlatformQuotaUpdate { + _u.mutation.ClearDeletedAt() + return _u +} + +// SetUserID sets the "user_id" field. +func (_u *UserPlatformQuotaUpdate) SetUserID(v int64) *UserPlatformQuotaUpdate { + _u.mutation.SetUserID(v) + return _u +} + +// SetNillableUserID sets the "user_id" field if the given value is not nil. +func (_u *UserPlatformQuotaUpdate) SetNillableUserID(v *int64) *UserPlatformQuotaUpdate { + if v != nil { + _u.SetUserID(*v) + } + return _u +} + +// SetPlatform sets the "platform" field. +func (_u *UserPlatformQuotaUpdate) SetPlatform(v string) *UserPlatformQuotaUpdate { + _u.mutation.SetPlatform(v) + return _u +} + +// SetNillablePlatform sets the "platform" field if the given value is not nil. +func (_u *UserPlatformQuotaUpdate) SetNillablePlatform(v *string) *UserPlatformQuotaUpdate { + if v != nil { + _u.SetPlatform(*v) + } + return _u +} + +// SetDailyLimitUsd sets the "daily_limit_usd" field. +func (_u *UserPlatformQuotaUpdate) SetDailyLimitUsd(v float64) *UserPlatformQuotaUpdate { + _u.mutation.ResetDailyLimitUsd() + _u.mutation.SetDailyLimitUsd(v) + return _u +} + +// SetNillableDailyLimitUsd sets the "daily_limit_usd" field if the given value is not nil. +func (_u *UserPlatformQuotaUpdate) SetNillableDailyLimitUsd(v *float64) *UserPlatformQuotaUpdate { + if v != nil { + _u.SetDailyLimitUsd(*v) + } + return _u +} + +// AddDailyLimitUsd adds value to the "daily_limit_usd" field. +func (_u *UserPlatformQuotaUpdate) AddDailyLimitUsd(v float64) *UserPlatformQuotaUpdate { + _u.mutation.AddDailyLimitUsd(v) + return _u +} + +// ClearDailyLimitUsd clears the value of the "daily_limit_usd" field. +func (_u *UserPlatformQuotaUpdate) ClearDailyLimitUsd() *UserPlatformQuotaUpdate { + _u.mutation.ClearDailyLimitUsd() + return _u +} + +// SetWeeklyLimitUsd sets the "weekly_limit_usd" field. +func (_u *UserPlatformQuotaUpdate) SetWeeklyLimitUsd(v float64) *UserPlatformQuotaUpdate { + _u.mutation.ResetWeeklyLimitUsd() + _u.mutation.SetWeeklyLimitUsd(v) + return _u +} + +// SetNillableWeeklyLimitUsd sets the "weekly_limit_usd" field if the given value is not nil. +func (_u *UserPlatformQuotaUpdate) SetNillableWeeklyLimitUsd(v *float64) *UserPlatformQuotaUpdate { + if v != nil { + _u.SetWeeklyLimitUsd(*v) + } + return _u +} + +// AddWeeklyLimitUsd adds value to the "weekly_limit_usd" field. +func (_u *UserPlatformQuotaUpdate) AddWeeklyLimitUsd(v float64) *UserPlatformQuotaUpdate { + _u.mutation.AddWeeklyLimitUsd(v) + return _u +} + +// ClearWeeklyLimitUsd clears the value of the "weekly_limit_usd" field. +func (_u *UserPlatformQuotaUpdate) ClearWeeklyLimitUsd() *UserPlatformQuotaUpdate { + _u.mutation.ClearWeeklyLimitUsd() + return _u +} + +// SetMonthlyLimitUsd sets the "monthly_limit_usd" field. +func (_u *UserPlatformQuotaUpdate) SetMonthlyLimitUsd(v float64) *UserPlatformQuotaUpdate { + _u.mutation.ResetMonthlyLimitUsd() + _u.mutation.SetMonthlyLimitUsd(v) + return _u +} + +// SetNillableMonthlyLimitUsd sets the "monthly_limit_usd" field if the given value is not nil. +func (_u *UserPlatformQuotaUpdate) SetNillableMonthlyLimitUsd(v *float64) *UserPlatformQuotaUpdate { + if v != nil { + _u.SetMonthlyLimitUsd(*v) + } + return _u +} + +// AddMonthlyLimitUsd adds value to the "monthly_limit_usd" field. +func (_u *UserPlatformQuotaUpdate) AddMonthlyLimitUsd(v float64) *UserPlatformQuotaUpdate { + _u.mutation.AddMonthlyLimitUsd(v) + return _u +} + +// ClearMonthlyLimitUsd clears the value of the "monthly_limit_usd" field. +func (_u *UserPlatformQuotaUpdate) ClearMonthlyLimitUsd() *UserPlatformQuotaUpdate { + _u.mutation.ClearMonthlyLimitUsd() + return _u +} + +// SetDailyUsageUsd sets the "daily_usage_usd" field. +func (_u *UserPlatformQuotaUpdate) SetDailyUsageUsd(v float64) *UserPlatformQuotaUpdate { + _u.mutation.ResetDailyUsageUsd() + _u.mutation.SetDailyUsageUsd(v) + return _u +} + +// SetNillableDailyUsageUsd sets the "daily_usage_usd" field if the given value is not nil. +func (_u *UserPlatformQuotaUpdate) SetNillableDailyUsageUsd(v *float64) *UserPlatformQuotaUpdate { + if v != nil { + _u.SetDailyUsageUsd(*v) + } + return _u +} + +// AddDailyUsageUsd adds value to the "daily_usage_usd" field. +func (_u *UserPlatformQuotaUpdate) AddDailyUsageUsd(v float64) *UserPlatformQuotaUpdate { + _u.mutation.AddDailyUsageUsd(v) + return _u +} + +// SetWeeklyUsageUsd sets the "weekly_usage_usd" field. +func (_u *UserPlatformQuotaUpdate) SetWeeklyUsageUsd(v float64) *UserPlatformQuotaUpdate { + _u.mutation.ResetWeeklyUsageUsd() + _u.mutation.SetWeeklyUsageUsd(v) + return _u +} + +// SetNillableWeeklyUsageUsd sets the "weekly_usage_usd" field if the given value is not nil. +func (_u *UserPlatformQuotaUpdate) SetNillableWeeklyUsageUsd(v *float64) *UserPlatformQuotaUpdate { + if v != nil { + _u.SetWeeklyUsageUsd(*v) + } + return _u +} + +// AddWeeklyUsageUsd adds value to the "weekly_usage_usd" field. +func (_u *UserPlatformQuotaUpdate) AddWeeklyUsageUsd(v float64) *UserPlatformQuotaUpdate { + _u.mutation.AddWeeklyUsageUsd(v) + return _u +} + +// SetMonthlyUsageUsd sets the "monthly_usage_usd" field. +func (_u *UserPlatformQuotaUpdate) SetMonthlyUsageUsd(v float64) *UserPlatformQuotaUpdate { + _u.mutation.ResetMonthlyUsageUsd() + _u.mutation.SetMonthlyUsageUsd(v) + return _u +} + +// SetNillableMonthlyUsageUsd sets the "monthly_usage_usd" field if the given value is not nil. +func (_u *UserPlatformQuotaUpdate) SetNillableMonthlyUsageUsd(v *float64) *UserPlatformQuotaUpdate { + if v != nil { + _u.SetMonthlyUsageUsd(*v) + } + return _u +} + +// AddMonthlyUsageUsd adds value to the "monthly_usage_usd" field. +func (_u *UserPlatformQuotaUpdate) AddMonthlyUsageUsd(v float64) *UserPlatformQuotaUpdate { + _u.mutation.AddMonthlyUsageUsd(v) + return _u +} + +// SetDailyWindowStart sets the "daily_window_start" field. +func (_u *UserPlatformQuotaUpdate) SetDailyWindowStart(v time.Time) *UserPlatformQuotaUpdate { + _u.mutation.SetDailyWindowStart(v) + return _u +} + +// SetNillableDailyWindowStart sets the "daily_window_start" field if the given value is not nil. +func (_u *UserPlatformQuotaUpdate) SetNillableDailyWindowStart(v *time.Time) *UserPlatformQuotaUpdate { + if v != nil { + _u.SetDailyWindowStart(*v) + } + return _u +} + +// ClearDailyWindowStart clears the value of the "daily_window_start" field. +func (_u *UserPlatformQuotaUpdate) ClearDailyWindowStart() *UserPlatformQuotaUpdate { + _u.mutation.ClearDailyWindowStart() + return _u +} + +// SetWeeklyWindowStart sets the "weekly_window_start" field. +func (_u *UserPlatformQuotaUpdate) SetWeeklyWindowStart(v time.Time) *UserPlatformQuotaUpdate { + _u.mutation.SetWeeklyWindowStart(v) + return _u +} + +// SetNillableWeeklyWindowStart sets the "weekly_window_start" field if the given value is not nil. +func (_u *UserPlatformQuotaUpdate) SetNillableWeeklyWindowStart(v *time.Time) *UserPlatformQuotaUpdate { + if v != nil { + _u.SetWeeklyWindowStart(*v) + } + return _u +} + +// ClearWeeklyWindowStart clears the value of the "weekly_window_start" field. +func (_u *UserPlatformQuotaUpdate) ClearWeeklyWindowStart() *UserPlatformQuotaUpdate { + _u.mutation.ClearWeeklyWindowStart() + return _u +} + +// SetMonthlyWindowStart sets the "monthly_window_start" field. +func (_u *UserPlatformQuotaUpdate) SetMonthlyWindowStart(v time.Time) *UserPlatformQuotaUpdate { + _u.mutation.SetMonthlyWindowStart(v) + return _u +} + +// SetNillableMonthlyWindowStart sets the "monthly_window_start" field if the given value is not nil. +func (_u *UserPlatformQuotaUpdate) SetNillableMonthlyWindowStart(v *time.Time) *UserPlatformQuotaUpdate { + if v != nil { + _u.SetMonthlyWindowStart(*v) + } + return _u +} + +// ClearMonthlyWindowStart clears the value of the "monthly_window_start" field. +func (_u *UserPlatformQuotaUpdate) ClearMonthlyWindowStart() *UserPlatformQuotaUpdate { + _u.mutation.ClearMonthlyWindowStart() + return _u +} + +// SetUser sets the "user" edge to the User entity. +func (_u *UserPlatformQuotaUpdate) SetUser(v *User) *UserPlatformQuotaUpdate { + return _u.SetUserID(v.ID) +} + +// Mutation returns the UserPlatformQuotaMutation object of the builder. +func (_u *UserPlatformQuotaUpdate) Mutation() *UserPlatformQuotaMutation { + return _u.mutation +} + +// ClearUser clears the "user" edge to the User entity. +func (_u *UserPlatformQuotaUpdate) ClearUser() *UserPlatformQuotaUpdate { + _u.mutation.ClearUser() + return _u +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (_u *UserPlatformQuotaUpdate) Save(ctx context.Context) (int, error) { + if err := _u.defaults(); err != nil { + return 0, err + } + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *UserPlatformQuotaUpdate) SaveX(ctx context.Context) int { + affected, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (_u *UserPlatformQuotaUpdate) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *UserPlatformQuotaUpdate) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *UserPlatformQuotaUpdate) defaults() error { + if _, ok := _u.mutation.UpdatedAt(); !ok { + if userplatformquota.UpdateDefaultUpdatedAt == nil { + return fmt.Errorf("ent: uninitialized userplatformquota.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)") + } + v := userplatformquota.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } + return nil +} + +// check runs all checks and user-defined validators on the builder. +func (_u *UserPlatformQuotaUpdate) check() error { + if v, ok := _u.mutation.Platform(); ok { + if err := userplatformquota.PlatformValidator(v); err != nil { + return &ValidationError{Name: "platform", err: fmt.Errorf(`ent: validator failed for field "UserPlatformQuota.platform": %w`, err)} + } + } + if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 { + return errors.New(`ent: clearing a required unique edge "UserPlatformQuota.user"`) + } + return nil +} + +func (_u *UserPlatformQuotaUpdate) sqlSave(ctx context.Context) (_node int, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(userplatformquota.Table, userplatformquota.Columns, sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64)) + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(userplatformquota.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.DeletedAt(); ok { + _spec.SetField(userplatformquota.FieldDeletedAt, field.TypeTime, value) + } + if _u.mutation.DeletedAtCleared() { + _spec.ClearField(userplatformquota.FieldDeletedAt, field.TypeTime) + } + if value, ok := _u.mutation.Platform(); ok { + _spec.SetField(userplatformquota.FieldPlatform, field.TypeString, value) + } + if value, ok := _u.mutation.DailyLimitUsd(); ok { + _spec.SetField(userplatformquota.FieldDailyLimitUsd, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedDailyLimitUsd(); ok { + _spec.AddField(userplatformquota.FieldDailyLimitUsd, field.TypeFloat64, value) + } + if _u.mutation.DailyLimitUsdCleared() { + _spec.ClearField(userplatformquota.FieldDailyLimitUsd, field.TypeFloat64) + } + if value, ok := _u.mutation.WeeklyLimitUsd(); ok { + _spec.SetField(userplatformquota.FieldWeeklyLimitUsd, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedWeeklyLimitUsd(); ok { + _spec.AddField(userplatformquota.FieldWeeklyLimitUsd, field.TypeFloat64, value) + } + if _u.mutation.WeeklyLimitUsdCleared() { + _spec.ClearField(userplatformquota.FieldWeeklyLimitUsd, field.TypeFloat64) + } + if value, ok := _u.mutation.MonthlyLimitUsd(); ok { + _spec.SetField(userplatformquota.FieldMonthlyLimitUsd, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedMonthlyLimitUsd(); ok { + _spec.AddField(userplatformquota.FieldMonthlyLimitUsd, field.TypeFloat64, value) + } + if _u.mutation.MonthlyLimitUsdCleared() { + _spec.ClearField(userplatformquota.FieldMonthlyLimitUsd, field.TypeFloat64) + } + if value, ok := _u.mutation.DailyUsageUsd(); ok { + _spec.SetField(userplatformquota.FieldDailyUsageUsd, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedDailyUsageUsd(); ok { + _spec.AddField(userplatformquota.FieldDailyUsageUsd, field.TypeFloat64, value) + } + if value, ok := _u.mutation.WeeklyUsageUsd(); ok { + _spec.SetField(userplatformquota.FieldWeeklyUsageUsd, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedWeeklyUsageUsd(); ok { + _spec.AddField(userplatformquota.FieldWeeklyUsageUsd, field.TypeFloat64, value) + } + if value, ok := _u.mutation.MonthlyUsageUsd(); ok { + _spec.SetField(userplatformquota.FieldMonthlyUsageUsd, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedMonthlyUsageUsd(); ok { + _spec.AddField(userplatformquota.FieldMonthlyUsageUsd, field.TypeFloat64, value) + } + if value, ok := _u.mutation.DailyWindowStart(); ok { + _spec.SetField(userplatformquota.FieldDailyWindowStart, field.TypeTime, value) + } + if _u.mutation.DailyWindowStartCleared() { + _spec.ClearField(userplatformquota.FieldDailyWindowStart, field.TypeTime) + } + if value, ok := _u.mutation.WeeklyWindowStart(); ok { + _spec.SetField(userplatformquota.FieldWeeklyWindowStart, field.TypeTime, value) + } + if _u.mutation.WeeklyWindowStartCleared() { + _spec.ClearField(userplatformquota.FieldWeeklyWindowStart, field.TypeTime) + } + if value, ok := _u.mutation.MonthlyWindowStart(); ok { + _spec.SetField(userplatformquota.FieldMonthlyWindowStart, field.TypeTime, value) + } + if _u.mutation.MonthlyWindowStartCleared() { + _spec.ClearField(userplatformquota.FieldMonthlyWindowStart, field.TypeTime) + } + if _u.mutation.UserCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: userplatformquota.UserTable, + Columns: []string{userplatformquota.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.UserIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: userplatformquota.UserTable, + Columns: []string{userplatformquota.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{userplatformquota.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + _u.mutation.done = true + return _node, nil +} + +// UserPlatformQuotaUpdateOne is the builder for updating a single UserPlatformQuota entity. +type UserPlatformQuotaUpdateOne struct { + config + fields []string + hooks []Hook + mutation *UserPlatformQuotaMutation +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *UserPlatformQuotaUpdateOne) SetUpdatedAt(v time.Time) *UserPlatformQuotaUpdateOne { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetDeletedAt sets the "deleted_at" field. +func (_u *UserPlatformQuotaUpdateOne) SetDeletedAt(v time.Time) *UserPlatformQuotaUpdateOne { + _u.mutation.SetDeletedAt(v) + return _u +} + +// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil. +func (_u *UserPlatformQuotaUpdateOne) SetNillableDeletedAt(v *time.Time) *UserPlatformQuotaUpdateOne { + if v != nil { + _u.SetDeletedAt(*v) + } + return _u +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (_u *UserPlatformQuotaUpdateOne) ClearDeletedAt() *UserPlatformQuotaUpdateOne { + _u.mutation.ClearDeletedAt() + return _u +} + +// SetUserID sets the "user_id" field. +func (_u *UserPlatformQuotaUpdateOne) SetUserID(v int64) *UserPlatformQuotaUpdateOne { + _u.mutation.SetUserID(v) + return _u +} + +// SetNillableUserID sets the "user_id" field if the given value is not nil. +func (_u *UserPlatformQuotaUpdateOne) SetNillableUserID(v *int64) *UserPlatformQuotaUpdateOne { + if v != nil { + _u.SetUserID(*v) + } + return _u +} + +// SetPlatform sets the "platform" field. +func (_u *UserPlatformQuotaUpdateOne) SetPlatform(v string) *UserPlatformQuotaUpdateOne { + _u.mutation.SetPlatform(v) + return _u +} + +// SetNillablePlatform sets the "platform" field if the given value is not nil. +func (_u *UserPlatformQuotaUpdateOne) SetNillablePlatform(v *string) *UserPlatformQuotaUpdateOne { + if v != nil { + _u.SetPlatform(*v) + } + return _u +} + +// SetDailyLimitUsd sets the "daily_limit_usd" field. +func (_u *UserPlatformQuotaUpdateOne) SetDailyLimitUsd(v float64) *UserPlatformQuotaUpdateOne { + _u.mutation.ResetDailyLimitUsd() + _u.mutation.SetDailyLimitUsd(v) + return _u +} + +// SetNillableDailyLimitUsd sets the "daily_limit_usd" field if the given value is not nil. +func (_u *UserPlatformQuotaUpdateOne) SetNillableDailyLimitUsd(v *float64) *UserPlatformQuotaUpdateOne { + if v != nil { + _u.SetDailyLimitUsd(*v) + } + return _u +} + +// AddDailyLimitUsd adds value to the "daily_limit_usd" field. +func (_u *UserPlatformQuotaUpdateOne) AddDailyLimitUsd(v float64) *UserPlatformQuotaUpdateOne { + _u.mutation.AddDailyLimitUsd(v) + return _u +} + +// ClearDailyLimitUsd clears the value of the "daily_limit_usd" field. +func (_u *UserPlatformQuotaUpdateOne) ClearDailyLimitUsd() *UserPlatformQuotaUpdateOne { + _u.mutation.ClearDailyLimitUsd() + return _u +} + +// SetWeeklyLimitUsd sets the "weekly_limit_usd" field. +func (_u *UserPlatformQuotaUpdateOne) SetWeeklyLimitUsd(v float64) *UserPlatformQuotaUpdateOne { + _u.mutation.ResetWeeklyLimitUsd() + _u.mutation.SetWeeklyLimitUsd(v) + return _u +} + +// SetNillableWeeklyLimitUsd sets the "weekly_limit_usd" field if the given value is not nil. +func (_u *UserPlatformQuotaUpdateOne) SetNillableWeeklyLimitUsd(v *float64) *UserPlatformQuotaUpdateOne { + if v != nil { + _u.SetWeeklyLimitUsd(*v) + } + return _u +} + +// AddWeeklyLimitUsd adds value to the "weekly_limit_usd" field. +func (_u *UserPlatformQuotaUpdateOne) AddWeeklyLimitUsd(v float64) *UserPlatformQuotaUpdateOne { + _u.mutation.AddWeeklyLimitUsd(v) + return _u +} + +// ClearWeeklyLimitUsd clears the value of the "weekly_limit_usd" field. +func (_u *UserPlatformQuotaUpdateOne) ClearWeeklyLimitUsd() *UserPlatformQuotaUpdateOne { + _u.mutation.ClearWeeklyLimitUsd() + return _u +} + +// SetMonthlyLimitUsd sets the "monthly_limit_usd" field. +func (_u *UserPlatformQuotaUpdateOne) SetMonthlyLimitUsd(v float64) *UserPlatformQuotaUpdateOne { + _u.mutation.ResetMonthlyLimitUsd() + _u.mutation.SetMonthlyLimitUsd(v) + return _u +} + +// SetNillableMonthlyLimitUsd sets the "monthly_limit_usd" field if the given value is not nil. +func (_u *UserPlatformQuotaUpdateOne) SetNillableMonthlyLimitUsd(v *float64) *UserPlatformQuotaUpdateOne { + if v != nil { + _u.SetMonthlyLimitUsd(*v) + } + return _u +} + +// AddMonthlyLimitUsd adds value to the "monthly_limit_usd" field. +func (_u *UserPlatformQuotaUpdateOne) AddMonthlyLimitUsd(v float64) *UserPlatformQuotaUpdateOne { + _u.mutation.AddMonthlyLimitUsd(v) + return _u +} + +// ClearMonthlyLimitUsd clears the value of the "monthly_limit_usd" field. +func (_u *UserPlatformQuotaUpdateOne) ClearMonthlyLimitUsd() *UserPlatformQuotaUpdateOne { + _u.mutation.ClearMonthlyLimitUsd() + return _u +} + +// SetDailyUsageUsd sets the "daily_usage_usd" field. +func (_u *UserPlatformQuotaUpdateOne) SetDailyUsageUsd(v float64) *UserPlatformQuotaUpdateOne { + _u.mutation.ResetDailyUsageUsd() + _u.mutation.SetDailyUsageUsd(v) + return _u +} + +// SetNillableDailyUsageUsd sets the "daily_usage_usd" field if the given value is not nil. +func (_u *UserPlatformQuotaUpdateOne) SetNillableDailyUsageUsd(v *float64) *UserPlatformQuotaUpdateOne { + if v != nil { + _u.SetDailyUsageUsd(*v) + } + return _u +} + +// AddDailyUsageUsd adds value to the "daily_usage_usd" field. +func (_u *UserPlatformQuotaUpdateOne) AddDailyUsageUsd(v float64) *UserPlatformQuotaUpdateOne { + _u.mutation.AddDailyUsageUsd(v) + return _u +} + +// SetWeeklyUsageUsd sets the "weekly_usage_usd" field. +func (_u *UserPlatformQuotaUpdateOne) SetWeeklyUsageUsd(v float64) *UserPlatformQuotaUpdateOne { + _u.mutation.ResetWeeklyUsageUsd() + _u.mutation.SetWeeklyUsageUsd(v) + return _u +} + +// SetNillableWeeklyUsageUsd sets the "weekly_usage_usd" field if the given value is not nil. +func (_u *UserPlatformQuotaUpdateOne) SetNillableWeeklyUsageUsd(v *float64) *UserPlatformQuotaUpdateOne { + if v != nil { + _u.SetWeeklyUsageUsd(*v) + } + return _u +} + +// AddWeeklyUsageUsd adds value to the "weekly_usage_usd" field. +func (_u *UserPlatformQuotaUpdateOne) AddWeeklyUsageUsd(v float64) *UserPlatformQuotaUpdateOne { + _u.mutation.AddWeeklyUsageUsd(v) + return _u +} + +// SetMonthlyUsageUsd sets the "monthly_usage_usd" field. +func (_u *UserPlatformQuotaUpdateOne) SetMonthlyUsageUsd(v float64) *UserPlatformQuotaUpdateOne { + _u.mutation.ResetMonthlyUsageUsd() + _u.mutation.SetMonthlyUsageUsd(v) + return _u +} + +// SetNillableMonthlyUsageUsd sets the "monthly_usage_usd" field if the given value is not nil. +func (_u *UserPlatformQuotaUpdateOne) SetNillableMonthlyUsageUsd(v *float64) *UserPlatformQuotaUpdateOne { + if v != nil { + _u.SetMonthlyUsageUsd(*v) + } + return _u +} + +// AddMonthlyUsageUsd adds value to the "monthly_usage_usd" field. +func (_u *UserPlatformQuotaUpdateOne) AddMonthlyUsageUsd(v float64) *UserPlatformQuotaUpdateOne { + _u.mutation.AddMonthlyUsageUsd(v) + return _u +} + +// SetDailyWindowStart sets the "daily_window_start" field. +func (_u *UserPlatformQuotaUpdateOne) SetDailyWindowStart(v time.Time) *UserPlatformQuotaUpdateOne { + _u.mutation.SetDailyWindowStart(v) + return _u +} + +// SetNillableDailyWindowStart sets the "daily_window_start" field if the given value is not nil. +func (_u *UserPlatformQuotaUpdateOne) SetNillableDailyWindowStart(v *time.Time) *UserPlatformQuotaUpdateOne { + if v != nil { + _u.SetDailyWindowStart(*v) + } + return _u +} + +// ClearDailyWindowStart clears the value of the "daily_window_start" field. +func (_u *UserPlatformQuotaUpdateOne) ClearDailyWindowStart() *UserPlatformQuotaUpdateOne { + _u.mutation.ClearDailyWindowStart() + return _u +} + +// SetWeeklyWindowStart sets the "weekly_window_start" field. +func (_u *UserPlatformQuotaUpdateOne) SetWeeklyWindowStart(v time.Time) *UserPlatformQuotaUpdateOne { + _u.mutation.SetWeeklyWindowStart(v) + return _u +} + +// SetNillableWeeklyWindowStart sets the "weekly_window_start" field if the given value is not nil. +func (_u *UserPlatformQuotaUpdateOne) SetNillableWeeklyWindowStart(v *time.Time) *UserPlatformQuotaUpdateOne { + if v != nil { + _u.SetWeeklyWindowStart(*v) + } + return _u +} + +// ClearWeeklyWindowStart clears the value of the "weekly_window_start" field. +func (_u *UserPlatformQuotaUpdateOne) ClearWeeklyWindowStart() *UserPlatformQuotaUpdateOne { + _u.mutation.ClearWeeklyWindowStart() + return _u +} + +// SetMonthlyWindowStart sets the "monthly_window_start" field. +func (_u *UserPlatformQuotaUpdateOne) SetMonthlyWindowStart(v time.Time) *UserPlatformQuotaUpdateOne { + _u.mutation.SetMonthlyWindowStart(v) + return _u +} + +// SetNillableMonthlyWindowStart sets the "monthly_window_start" field if the given value is not nil. +func (_u *UserPlatformQuotaUpdateOne) SetNillableMonthlyWindowStart(v *time.Time) *UserPlatformQuotaUpdateOne { + if v != nil { + _u.SetMonthlyWindowStart(*v) + } + return _u +} + +// ClearMonthlyWindowStart clears the value of the "monthly_window_start" field. +func (_u *UserPlatformQuotaUpdateOne) ClearMonthlyWindowStart() *UserPlatformQuotaUpdateOne { + _u.mutation.ClearMonthlyWindowStart() + return _u +} + +// SetUser sets the "user" edge to the User entity. +func (_u *UserPlatformQuotaUpdateOne) SetUser(v *User) *UserPlatformQuotaUpdateOne { + return _u.SetUserID(v.ID) +} + +// Mutation returns the UserPlatformQuotaMutation object of the builder. +func (_u *UserPlatformQuotaUpdateOne) Mutation() *UserPlatformQuotaMutation { + return _u.mutation +} + +// ClearUser clears the "user" edge to the User entity. +func (_u *UserPlatformQuotaUpdateOne) ClearUser() *UserPlatformQuotaUpdateOne { + _u.mutation.ClearUser() + return _u +} + +// Where appends a list predicates to the UserPlatformQuotaUpdate builder. +func (_u *UserPlatformQuotaUpdateOne) Where(ps ...predicate.UserPlatformQuota) *UserPlatformQuotaUpdateOne { + _u.mutation.Where(ps...) + return _u +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (_u *UserPlatformQuotaUpdateOne) Select(field string, fields ...string) *UserPlatformQuotaUpdateOne { + _u.fields = append([]string{field}, fields...) + return _u +} + +// Save executes the query and returns the updated UserPlatformQuota entity. +func (_u *UserPlatformQuotaUpdateOne) Save(ctx context.Context) (*UserPlatformQuota, error) { + if err := _u.defaults(); err != nil { + return nil, err + } + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *UserPlatformQuotaUpdateOne) SaveX(ctx context.Context) *UserPlatformQuota { + node, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (_u *UserPlatformQuotaUpdateOne) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *UserPlatformQuotaUpdateOne) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *UserPlatformQuotaUpdateOne) defaults() error { + if _, ok := _u.mutation.UpdatedAt(); !ok { + if userplatformquota.UpdateDefaultUpdatedAt == nil { + return fmt.Errorf("ent: uninitialized userplatformquota.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)") + } + v := userplatformquota.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } + return nil +} + +// check runs all checks and user-defined validators on the builder. +func (_u *UserPlatformQuotaUpdateOne) check() error { + if v, ok := _u.mutation.Platform(); ok { + if err := userplatformquota.PlatformValidator(v); err != nil { + return &ValidationError{Name: "platform", err: fmt.Errorf(`ent: validator failed for field "UserPlatformQuota.platform": %w`, err)} + } + } + if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 { + return errors.New(`ent: clearing a required unique edge "UserPlatformQuota.user"`) + } + return nil +} + +func (_u *UserPlatformQuotaUpdateOne) sqlSave(ctx context.Context) (_node *UserPlatformQuota, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(userplatformquota.Table, userplatformquota.Columns, sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64)) + id, ok := _u.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "UserPlatformQuota.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := _u.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, userplatformquota.FieldID) + for _, f := range fields { + if !userplatformquota.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != userplatformquota.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(userplatformquota.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.DeletedAt(); ok { + _spec.SetField(userplatformquota.FieldDeletedAt, field.TypeTime, value) + } + if _u.mutation.DeletedAtCleared() { + _spec.ClearField(userplatformquota.FieldDeletedAt, field.TypeTime) + } + if value, ok := _u.mutation.Platform(); ok { + _spec.SetField(userplatformquota.FieldPlatform, field.TypeString, value) + } + if value, ok := _u.mutation.DailyLimitUsd(); ok { + _spec.SetField(userplatformquota.FieldDailyLimitUsd, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedDailyLimitUsd(); ok { + _spec.AddField(userplatformquota.FieldDailyLimitUsd, field.TypeFloat64, value) + } + if _u.mutation.DailyLimitUsdCleared() { + _spec.ClearField(userplatformquota.FieldDailyLimitUsd, field.TypeFloat64) + } + if value, ok := _u.mutation.WeeklyLimitUsd(); ok { + _spec.SetField(userplatformquota.FieldWeeklyLimitUsd, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedWeeklyLimitUsd(); ok { + _spec.AddField(userplatformquota.FieldWeeklyLimitUsd, field.TypeFloat64, value) + } + if _u.mutation.WeeklyLimitUsdCleared() { + _spec.ClearField(userplatformquota.FieldWeeklyLimitUsd, field.TypeFloat64) + } + if value, ok := _u.mutation.MonthlyLimitUsd(); ok { + _spec.SetField(userplatformquota.FieldMonthlyLimitUsd, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedMonthlyLimitUsd(); ok { + _spec.AddField(userplatformquota.FieldMonthlyLimitUsd, field.TypeFloat64, value) + } + if _u.mutation.MonthlyLimitUsdCleared() { + _spec.ClearField(userplatformquota.FieldMonthlyLimitUsd, field.TypeFloat64) + } + if value, ok := _u.mutation.DailyUsageUsd(); ok { + _spec.SetField(userplatformquota.FieldDailyUsageUsd, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedDailyUsageUsd(); ok { + _spec.AddField(userplatformquota.FieldDailyUsageUsd, field.TypeFloat64, value) + } + if value, ok := _u.mutation.WeeklyUsageUsd(); ok { + _spec.SetField(userplatformquota.FieldWeeklyUsageUsd, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedWeeklyUsageUsd(); ok { + _spec.AddField(userplatformquota.FieldWeeklyUsageUsd, field.TypeFloat64, value) + } + if value, ok := _u.mutation.MonthlyUsageUsd(); ok { + _spec.SetField(userplatformquota.FieldMonthlyUsageUsd, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedMonthlyUsageUsd(); ok { + _spec.AddField(userplatformquota.FieldMonthlyUsageUsd, field.TypeFloat64, value) + } + if value, ok := _u.mutation.DailyWindowStart(); ok { + _spec.SetField(userplatformquota.FieldDailyWindowStart, field.TypeTime, value) + } + if _u.mutation.DailyWindowStartCleared() { + _spec.ClearField(userplatformquota.FieldDailyWindowStart, field.TypeTime) + } + if value, ok := _u.mutation.WeeklyWindowStart(); ok { + _spec.SetField(userplatformquota.FieldWeeklyWindowStart, field.TypeTime, value) + } + if _u.mutation.WeeklyWindowStartCleared() { + _spec.ClearField(userplatformquota.FieldWeeklyWindowStart, field.TypeTime) + } + if value, ok := _u.mutation.MonthlyWindowStart(); ok { + _spec.SetField(userplatformquota.FieldMonthlyWindowStart, field.TypeTime, value) + } + if _u.mutation.MonthlyWindowStartCleared() { + _spec.ClearField(userplatformquota.FieldMonthlyWindowStart, field.TypeTime) + } + if _u.mutation.UserCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: userplatformquota.UserTable, + Columns: []string{userplatformquota.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.UserIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: userplatformquota.UserTable, + Columns: []string{userplatformquota.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + _node = &UserPlatformQuota{config: _u.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{userplatformquota.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + _u.mutation.done = true + return _node, nil +} diff --git a/backend/go.mod b/backend/go.mod index caf35ea4..587d5370 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -5,12 +5,14 @@ go 1.26.3 require ( entgo.io/ent v0.14.5 github.com/DATA-DOG/go-sqlmock v1.5.2 + github.com/alicebob/miniredis/v2 v2.38.0 github.com/alitto/pond/v2 v2.6.2 github.com/andybalholm/brotli v1.2.0 github.com/aws/aws-sdk-go-v2 v1.41.3 github.com/aws/aws-sdk-go-v2/config v1.32.10 github.com/aws/aws-sdk-go-v2/credentials v1.19.10 github.com/aws/aws-sdk-go-v2/service/s3 v1.96.2 + github.com/aws/smithy-go v1.24.2 github.com/cespare/xxhash/v2 v2.3.0 github.com/coder/websocket v1.8.14 github.com/dgraph-io/ristretto v0.2.0 @@ -71,7 +73,6 @@ require ( github.com/aws/aws-sdk-go-v2/service/sso v1.30.11 // indirect github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 // indirect github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 // indirect - github.com/aws/smithy-go v1.24.2 // indirect github.com/bmatcuk/doublestar v1.3.4 // indirect github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect github.com/bytedance/sonic v1.9.1 // indirect @@ -158,6 +159,7 @@ require ( github.com/tklauser/numcpus v0.6.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.11 // indirect + github.com/yuin/gopher-lua v1.1.1 // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect github.com/zclconf/go-cty v1.14.4 // indirect github.com/zclconf/go-cty-yaml v1.1.0 // indirect diff --git a/backend/go.sum b/backend/go.sum index 35cfdb03..7735fda2 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -16,6 +16,8 @@ github.com/agext/levenshtein v1.2.3 h1:YB2fHEn0UJagG8T1rrWknE3ZQzWM06O8AMAatNn7l github.com/agext/levenshtein v1.2.3/go.mod h1:JEDfjyjHDjOF/1e4FlBE/PkbqA9OfWu2ki2W0IB5558= github.com/agiledragon/gomonkey v2.0.2+incompatible h1:eXKi9/piiC3cjJD1658mEE2o3NjkJ5vDLgYjCQu0Xlw= github.com/agiledragon/gomonkey v2.0.2+incompatible/go.mod h1:2NGfXu1a80LLr2cmWXGBDaHEjb1idR6+FVlX5T3D9hw= +github.com/alicebob/miniredis/v2 v2.38.0 h1:nZAzCR+Lj+Vxk4ZXzm2NuKq2O33RXj1XxJ2e2uP9jiw= +github.com/alicebob/miniredis/v2 v2.38.0/go.mod h1:TcL7YfarKPGDAthEtl5NBeHZfeUQj6OXMm/+iu5cLMM= github.com/alitto/pond/v2 v2.6.2 h1:Sphe40g0ILeM1pA2c2K+Th0DGU+pt0A/Kprr+WB24Pw= github.com/alitto/pond/v2 v2.6.2/go.mod h1:xkjYEgQ05RSpWdfSd1nM3OVv7TBhLdy7rMp3+2Nq+yE= github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= @@ -162,8 +164,6 @@ github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17 github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= -github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE= -github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4= @@ -370,6 +370,8 @@ github.com/wechatpay-apiv3/wechatpay-go v0.2.21 h1:uIyMpzvcaHA33W/QPtHstccw+X52H github.com/wechatpay-apiv3/wechatpay-go v0.2.21/go.mod h1:A254AUBVB6R+EqQFo3yTgeh7HtyqRRtN2w9hQSOrd4Q= github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= +github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M= +github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw= github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= github.com/zclconf/go-cty v1.14.4 h1:uXXczd9QDGsgu0i/QFR/hzI5NYCHLf6NQw/atrbnhq8= diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 661e3296..b2764f87 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -643,6 +643,12 @@ type ProxyProbeConfig struct { type BillingConfig struct { CircuitBreaker CircuitBreakerConfig `mapstructure:"circuit_breaker"` + // UserPlatformQuotaCacheTTLSeconds 用户 × 平台 quota 缓存 TTL(秒),默认 86400=1天,覆盖典型 daily 窗口。 + // 消费点: + // - billing_cache_service.cacheWriteWorker 异步累加 + // - billing_cache_service.checkUserPlatformQuotaEligibility 首次缓存装载 + // 读写两端必须共用同一 TTL,避免缓存生命周期不一致导致 quota 计数漂移。 + UserPlatformQuotaCacheTTLSeconds int `mapstructure:"user_platform_quota_cache_ttl_seconds"` } type CircuitBreakerConfig struct { @@ -1544,6 +1550,7 @@ func setDefaults() { viper.SetDefault("billing.circuit_breaker.failure_threshold", 5) viper.SetDefault("billing.circuit_breaker.reset_timeout_seconds", 30) viper.SetDefault("billing.circuit_breaker.half_open_requests", 3) + viper.SetDefault("billing.user_platform_quota_cache_ttl_seconds", 86400) // Turnstile viper.SetDefault("turnstile.required", false) diff --git a/backend/internal/handler/admin/admin_basic_handlers_test.go b/backend/internal/handler/admin/admin_basic_handlers_test.go index ddeaab02..7b74bafc 100644 --- a/backend/internal/handler/admin/admin_basic_handlers_test.go +++ b/backend/internal/handler/admin/admin_basic_handlers_test.go @@ -16,7 +16,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) { router := gin.New() adminSvc := newStubAdminService() - userHandler := NewUserHandler(adminSvc, nil) + userHandler := NewUserHandler(adminSvc, nil, nil, nil) groupHandler := NewGroupHandler(adminSvc, nil, nil) proxyHandler := NewProxyHandler(adminSvc) redeemHandler := NewRedeemHandler(adminSvc, nil) diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go index 2fef94f1..d00c2259 100644 --- a/backend/internal/handler/admin/admin_service_stub_test.go +++ b/backend/internal/handler/admin/admin_service_stub_test.go @@ -24,6 +24,7 @@ type stubAdminService struct { updatedProxyIDs []int64 updatedProxies []*service.UpdateProxyInput testedProxyIDs []int64 + getUserErr error createAccountErr error updateAccountErr error bulkUpdateAccountErr error @@ -147,6 +148,9 @@ func (s *stubAdminService) ListUsers(ctx context.Context, page, pageSize int, fi } func (s *stubAdminService) GetUser(ctx context.Context, id int64) (*service.User, error) { + if s.getUserErr != nil { + return nil, s.getUserErr + } for i := range s.users { if s.users[i].ID == id { return &s.users[i], nil diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index 36a60c23..3c7fe581 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -305,6 +305,13 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { payload.OpenAIFastPolicySettings = openaiFastPolicySettingsToDTO(fastPolicy) } + // Default platform quotas(JSON map) + if platformQuotas, err := h.settingService.GetDefaultPlatformQuotas(c.Request.Context()); err != nil { + slog.Error("default_platform_quotas_get_failed", "error", err) + } else { + payload.DefaultPlatformQuotas = platformQuotas + } + response.Success(c, systemSettingsResponseData(payload, authSourceDefaults)) } @@ -637,6 +644,18 @@ type UpdateSettingsRequest struct { // OpenAI fast/flex policy (optional, only updated when provided) OpenAIFastPolicySettings *dto.OpenAIFastPolicySettings `json:"openai_fast_policy_settings,omitempty"` + + // 系统全局 platform quota 默认值(整体替换语义:nil = 不修改,non-nil = 整体覆盖)。 + DefaultPlatformQuotas map[string]*service.DefaultPlatformQuotaSetting `json:"default_platform_quotas"` + + // auth-source 层 platform quota 覆盖(override 语义:nil = 不修改,non-nil = 整体覆盖该 source 的 quota 配置)。 + AuthSourceEmailPlatformQuotas map[string]*service.DefaultPlatformQuotaSetting `json:"auth_source_default_email_platform_quotas"` + AuthSourceLinuxDoPlatformQuotas map[string]*service.DefaultPlatformQuotaSetting `json:"auth_source_default_linuxdo_platform_quotas"` + AuthSourceOIDCPlatformQuotas map[string]*service.DefaultPlatformQuotaSetting `json:"auth_source_default_oidc_platform_quotas"` + AuthSourceWeChatPlatformQuotas map[string]*service.DefaultPlatformQuotaSetting `json:"auth_source_default_wechat_platform_quotas"` + AuthSourceGitHubPlatformQuotas map[string]*service.DefaultPlatformQuotaSetting `json:"auth_source_default_github_platform_quotas"` + AuthSourceGooglePlatformQuotas map[string]*service.DefaultPlatformQuotaSetting `json:"auth_source_default_google_platform_quotas"` + AuthSourceDingTalkPlatformQuotas map[string]*service.DefaultPlatformQuotaSetting `json:"auth_source_default_dingtalk_platform_quotas"` } // UpdateSettings 更新系统设置 @@ -1438,6 +1457,9 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } settings := &service.SystemSettings{ + // 系统全局 platform quota 默认值(整体替换语义) + DefaultPlatformQuotas: req.DefaultPlatformQuotas, + RegistrationEnabled: req.RegistrationEnabled, EmailVerifyEnabled: req.EmailVerifyEnabled, RegistrationEmailSuffixWhitelist: req.RegistrationEmailSuffixWhitelist, @@ -1731,6 +1753,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { }(), } + // req.AuthSourceXxxPlatformQuotas 为 nil 表示本次请求未包含该 source 的 quota 配置(保留 previousAuthSourceDefaults 中的值); + // non-nil(含 empty map)表示整体覆盖:empty map = 清空该 source 的所有 quota 配置。 authSourceDefaults := &service.AuthSourceDefaultSettings{ Email: service.ProviderDefaultGrantSettings{ Balance: float64ValueOrDefault(req.AuthSourceDefaultEmailBalance, previousAuthSourceDefaults.Email.Balance), @@ -1738,6 +1762,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultEmailSubscriptions, previousAuthSourceDefaults.Email.Subscriptions), GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultEmailGrantOnSignup, previousAuthSourceDefaults.Email.GrantOnSignup), GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultEmailGrantOnFirstBind, previousAuthSourceDefaults.Email.GrantOnFirstBind), + PlatformQuotas: platformQuotasValueOrDefault(req.AuthSourceEmailPlatformQuotas, previousAuthSourceDefaults.Email.PlatformQuotas), }, LinuxDo: service.ProviderDefaultGrantSettings{ Balance: float64ValueOrDefault(req.AuthSourceDefaultLinuxDoBalance, previousAuthSourceDefaults.LinuxDo.Balance), @@ -1745,6 +1770,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultLinuxDoSubscriptions, previousAuthSourceDefaults.LinuxDo.Subscriptions), GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultLinuxDoGrantOnSignup, previousAuthSourceDefaults.LinuxDo.GrantOnSignup), GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultLinuxDoGrantOnFirstBind, previousAuthSourceDefaults.LinuxDo.GrantOnFirstBind), + PlatformQuotas: platformQuotasValueOrDefault(req.AuthSourceLinuxDoPlatformQuotas, previousAuthSourceDefaults.LinuxDo.PlatformQuotas), }, OIDC: service.ProviderDefaultGrantSettings{ Balance: float64ValueOrDefault(req.AuthSourceDefaultOIDCBalance, previousAuthSourceDefaults.OIDC.Balance), @@ -1752,6 +1778,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultOIDCSubscriptions, previousAuthSourceDefaults.OIDC.Subscriptions), GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultOIDCGrantOnSignup, previousAuthSourceDefaults.OIDC.GrantOnSignup), GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultOIDCGrantOnFirstBind, previousAuthSourceDefaults.OIDC.GrantOnFirstBind), + PlatformQuotas: platformQuotasValueOrDefault(req.AuthSourceOIDCPlatformQuotas, previousAuthSourceDefaults.OIDC.PlatformQuotas), }, WeChat: service.ProviderDefaultGrantSettings{ Balance: float64ValueOrDefault(req.AuthSourceDefaultWeChatBalance, previousAuthSourceDefaults.WeChat.Balance), @@ -1759,6 +1786,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultWeChatSubscriptions, previousAuthSourceDefaults.WeChat.Subscriptions), GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultWeChatGrantOnSignup, previousAuthSourceDefaults.WeChat.GrantOnSignup), GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultWeChatGrantOnFirstBind, previousAuthSourceDefaults.WeChat.GrantOnFirstBind), + PlatformQuotas: platformQuotasValueOrDefault(req.AuthSourceWeChatPlatformQuotas, previousAuthSourceDefaults.WeChat.PlatformQuotas), }, GitHub: service.ProviderDefaultGrantSettings{ Balance: float64ValueOrDefault(req.AuthSourceDefaultGitHubBalance, previousAuthSourceDefaults.GitHub.Balance), @@ -1766,6 +1794,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultGitHubSubscriptions, previousAuthSourceDefaults.GitHub.Subscriptions), GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultGitHubGrantOnSignup, previousAuthSourceDefaults.GitHub.GrantOnSignup), GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultGitHubGrantOnFirstBind, previousAuthSourceDefaults.GitHub.GrantOnFirstBind), + PlatformQuotas: platformQuotasValueOrDefault(req.AuthSourceGitHubPlatformQuotas, previousAuthSourceDefaults.GitHub.PlatformQuotas), }, Google: service.ProviderDefaultGrantSettings{ Balance: float64ValueOrDefault(req.AuthSourceDefaultGoogleBalance, previousAuthSourceDefaults.Google.Balance), @@ -1773,6 +1802,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultGoogleSubscriptions, previousAuthSourceDefaults.Google.Subscriptions), GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultGoogleGrantOnSignup, previousAuthSourceDefaults.Google.GrantOnSignup), GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultGoogleGrantOnFirstBind, previousAuthSourceDefaults.Google.GrantOnFirstBind), + PlatformQuotas: platformQuotasValueOrDefault(req.AuthSourceGooglePlatformQuotas, previousAuthSourceDefaults.Google.PlatformQuotas), }, DingTalk: service.ProviderDefaultGrantSettings{ Balance: float64ValueOrDefault(req.AuthSourceDefaultDingTalkBalance, previousAuthSourceDefaults.DingTalk.Balance), @@ -1780,6 +1810,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultDingTalkSubscriptions, previousAuthSourceDefaults.DingTalk.Subscriptions), GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultDingTalkGrantOnSignup, previousAuthSourceDefaults.DingTalk.GrantOnSignup), GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultDingTalkGrantOnFirstBind, previousAuthSourceDefaults.DingTalk.GrantOnFirstBind), + PlatformQuotas: platformQuotasValueOrDefault(req.AuthSourceDingTalkPlatformQuotas, previousAuthSourceDefaults.DingTalk.PlatformQuotas), }, ForceEmailOnThirdPartySignup: boolValueOrDefault(req.ForceEmailOnThirdPartySignup, previousAuthSourceDefaults.ForceEmailOnThirdPartySignup), } @@ -2047,6 +2078,13 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } else if fastPolicy != nil { payload.OpenAIFastPolicySettings = openaiFastPolicySettingsToDTO(fastPolicy) } + + // Default platform quotas(JSON map)—— 与 GetSettings 一致,避免保存后响应缺失该字段 + if platformQuotas, err := h.settingService.GetDefaultPlatformQuotas(c.Request.Context()); err != nil { + slog.Error("default_platform_quotas_get_failed", "error", err) + } else { + payload.DefaultPlatformQuotas = platformQuotas + } response.Success(c, systemSettingsResponseData(payload, updatedAuthSourceDefaults)) } @@ -2511,6 +2549,10 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if before.RiskControlEnabled != after.RiskControlEnabled { changed = append(changed, "risk_control_enabled") } + // Default platform quotas(JSON map,整体比较) + if !equalPlatformQuotaSettings(before.DefaultPlatformQuotas, after.DefaultPlatformQuotas) { + changed = append(changed, service.SettingKeyDefaultPlatformQuotas) + } changed = appendAuthSourceDefaultChanges(changed, beforeAuthSourceDefaults, afterAuthSourceDefaults) return changed } @@ -2554,6 +2596,10 @@ func appendAuthSourceDefaultChanges(changed []string, before *service.AuthSource if field.before.GrantOnFirstBind != field.after.GrantOnFirstBind { changed = append(changed, "auth_source_default_"+field.name+"_grant_on_first_bind") } + // Platform quotas diff:整体替换语义,发单个 JSON key。 + if !equalPlatformQuotaSettings(field.before.PlatformQuotas, field.after.PlatformQuotas) { + changed = append(changed, service.SettingKeyAuthSourcePlatformQuotas(field.name)) + } } if before.ForceEmailOnThirdPartySignup != after.ForceEmailOnThirdPartySignup { changed = append(changed, "force_email_on_third_party_signup") @@ -2621,6 +2667,17 @@ func defaultSubscriptionsValueOrDefault(input *[]dto.DefaultSubscriptionSetting, return result } +// platformQuotasValueOrDefault 处理 auth-source platform quota 的 nil 语义: +// nil = 请求未包含该字段(保留 fallback),non-nil(含 empty map)= 整体覆盖。 +// 注意:JSON null 与字段省略等价——两者均反序列化为 nil map,因此都保留旧值; +// 若要清空某 source 的所有 quota 配置,须显式发空对象 {}。 +func platformQuotasValueOrDefault(value, fallback map[string]*service.DefaultPlatformQuotaSetting) map[string]*service.DefaultPlatformQuotaSetting { + if value == nil { + return fallback + } + return value +} + func systemSettingsResponseData(settings dto.SystemSettings, authSourceDefaults *service.AuthSourceDefaultSettings) map[string]any { data := make(map[string]any) raw, err := json.Marshal(settings) @@ -2666,6 +2723,13 @@ func systemSettingsResponseData(settings dto.SystemSettings, authSourceDefaults data["auth_source_default_google_subscriptions"] = authSourceDefaults.Google.Subscriptions data["auth_source_default_google_grant_on_signup"] = authSourceDefaults.Google.GrantOnSignup data["auth_source_default_google_grant_on_first_bind"] = authSourceDefaults.Google.GrantOnFirstBind + data["auth_source_default_email_platform_quotas"] = authSourceDefaults.Email.PlatformQuotas + data["auth_source_default_linuxdo_platform_quotas"] = authSourceDefaults.LinuxDo.PlatformQuotas + data["auth_source_default_oidc_platform_quotas"] = authSourceDefaults.OIDC.PlatformQuotas + data["auth_source_default_wechat_platform_quotas"] = authSourceDefaults.WeChat.PlatformQuotas + data["auth_source_default_github_platform_quotas"] = authSourceDefaults.GitHub.PlatformQuotas + data["auth_source_default_google_platform_quotas"] = authSourceDefaults.Google.PlatformQuotas + data["auth_source_default_dingtalk_platform_quotas"] = authSourceDefaults.DingTalk.PlatformQuotas data["force_email_on_third_party_signup"] = authSourceDefaults.ForceEmailOnThirdPartySignup return data @@ -3552,3 +3616,48 @@ func emailTemplatePlaceholderUnion(events []service.NotificationEmailEventInfo) } return placeholders } + +// equalNullableFloat compares two *float64 values treating nil as a distinct case. +func equalNullableFloat(a, b *float64) bool { + if a == nil && b == nil { + return true + } + if a == nil || b == nil { + return false + } + return *a == *b +} + +// slotOf returns the *float64 for the given window from a DefaultPlatformQuotaSetting. +func slotOf(s *service.DefaultPlatformQuotaSetting, win string) *float64 { + if s == nil { + return nil + } + switch win { + case "daily": + return s.DailyLimitUSD + case "weekly": + return s.WeeklyLimitUSD + case "monthly": + return s.MonthlyLimitUSD + } + return nil +} + +// equalPlatformQuotaSettings reports whether two platform-quota maps are identical across all 12 slots. +func equalPlatformQuotaSettings(before, after map[string]*service.DefaultPlatformQuotaSetting) bool { + for _, platform := range service.AllowedQuotaPlatforms { + b := before[platform] + a := after[platform] + if !equalNullableFloat(slotOf(b, "daily"), slotOf(a, "daily")) { + return false + } + if !equalNullableFloat(slotOf(b, "weekly"), slotOf(a, "weekly")) { + return false + } + if !equalNullableFloat(slotOf(b, "monthly"), slotOf(a, "monthly")) { + return false + } + } + return true +} diff --git a/backend/internal/handler/admin/setting_handler_platform_quota_test.go b/backend/internal/handler/admin/setting_handler_platform_quota_test.go new file mode 100644 index 00000000..273b3442 --- /dev/null +++ b/backend/internal/handler/admin/setting_handler_platform_quota_test.go @@ -0,0 +1,188 @@ +//go:build unit + +package admin + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestDiffSettings_DetectsGlobalPlatformQuotaChange(t *testing.T) { + five := 5.0 + ten := 10.0 + before := &service.SystemSettings{ + DefaultPlatformQuotas: map[string]*service.DefaultPlatformQuotaSetting{ + "anthropic": {DailyLimitUSD: &five}, + }, + } + after := &service.SystemSettings{ + DefaultPlatformQuotas: map[string]*service.DefaultPlatformQuotaSetting{ + "anthropic": {DailyLimitUSD: &ten}, + }, + } + + changed := diffSettings(before, after, nil, nil, UpdateSettingsRequest{}) + found := false + for _, key := range changed { + if key == service.SettingKeyDefaultPlatformQuotas { + found = true + break + } + } + if !found { + t.Errorf("expected change detection for default platform quotas, got %v", changed) + } +} + +func TestDiffSettings_NoChangeWhenEqual(t *testing.T) { + five := 5.0 + before := &service.SystemSettings{ + DefaultPlatformQuotas: map[string]*service.DefaultPlatformQuotaSetting{ + "anthropic": {DailyLimitUSD: &five}, + }, + } + after := &service.SystemSettings{ + DefaultPlatformQuotas: map[string]*service.DefaultPlatformQuotaSetting{ + "anthropic": {DailyLimitUSD: &five}, + }, + } + + changed := diffSettings(before, after, nil, nil, UpdateSettingsRequest{}) + for _, key := range changed { + if key == service.SettingKeyDefaultPlatformQuotas { + t.Error("equal values should not be detected as changed") + } + } +} + +func TestEqualNullableFloat(t *testing.T) { + five := 5.0 + five2 := 5.0 + ten := 10.0 + cases := []struct { + a, b *float64 + want bool + }{ + {nil, nil, true}, + {&five, nil, false}, + {nil, &five, false}, + {&five, &five2, true}, + {&five, &ten, false}, + } + for _, c := range cases { + if got := equalNullableFloat(c.a, c.b); got != c.want { + t.Errorf("equalNullableFloat(%v, %v) = %v, want %v", c.a, c.b, got, c.want) + } + } +} + +func TestEqualPlatformQuotaSettings_DetectsPerWindowChange(t *testing.T) { + five := 5.0 + ten := 10.0 + before := map[string]*service.DefaultPlatformQuotaSetting{ + "anthropic": {DailyLimitUSD: &five}, + } + after := map[string]*service.DefaultPlatformQuotaSetting{ + "anthropic": {DailyLimitUSD: &ten}, + } + if equalPlatformQuotaSettings(before, after) { + t.Error("expected unequal") + } +} + +func TestAppendAuthSourceDefaultChanges_DetectsPerWindow(t *testing.T) { + five := 5.0 + ten := 10.0 + before := &service.AuthSourceDefaultSettings{ + LinuxDo: service.ProviderDefaultGrantSettings{ + PlatformQuotas: map[string]*service.DefaultPlatformQuotaSetting{ + "anthropic": {DailyLimitUSD: &five}, + }, + }, + } + after := &service.AuthSourceDefaultSettings{ + LinuxDo: service.ProviderDefaultGrantSettings{ + PlatformQuotas: map[string]*service.DefaultPlatformQuotaSetting{ + "anthropic": {DailyLimitUSD: &ten}, + }, + }, + } + + changed := appendAuthSourceDefaultChanges([]string{}, before, after) + // 改动 B5:整体替换语义,审计 log 发单个 JSON key,而非展开 84 个扁平 key。 + key := service.SettingKeyAuthSourcePlatformQuotas("linuxdo") + found := false + for _, k := range changed { + if k == key { + found = true + break + } + } + if !found { + t.Errorf("expected %q in changed, got %v", key, changed) + } +} + +// TestSettingHandler_AuthSourcePlatformQuotas_PutGetRoundTrip 验证 Bug A 修复: +// PUT 发 auth_source_default_email_platform_quotas,GET 能读回相同值(端到端往返)。 +func TestSettingHandler_AuthSourcePlatformQuotas_PutGetRoundTrip(t *testing.T) { + gin.SetMode(gin.TestMode) + repo := &settingHandlerRepoStub{ + values: map[string]string{ + service.SettingKeyPromoCodeEnabled: "true", + }, + } + svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}}) + handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil) + + // PUT:发 email platform quota(openai monthly=20) + putBody := map[string]any{ + "auth_source_default_email_platform_quotas": map[string]any{ + "openai": map[string]any{ + "monthly": 20, + }, + }, + } + rawBody, err := json.Marshal(putBody) + require.NoError(t, err) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody)) + c.Request.Header.Set("Content-Type", "application/json") + handler.UpdateSettings(c) + require.Equal(t, http.StatusOK, rec.Code) + + // 验证 DB 中写入了 JSON key + jsonKey := service.SettingKeyAuthSourcePlatformQuotas("email") + require.NotEmpty(t, repo.values[jsonKey], "expected JSON key to be written to DB") + + // GET:验证响应中 auth_source_default_email_platform_quotas.openai.monthly = 20 + rec2 := httptest.NewRecorder() + c2, _ := gin.CreateTestContext(rec2) + c2.Request = httptest.NewRequest(http.MethodGet, "/api/v1/admin/settings", nil) + handler.GetSettings(c2) + require.Equal(t, http.StatusOK, rec2.Code) + + var resp response.Response + require.NoError(t, json.Unmarshal(rec2.Body.Bytes(), &resp)) + data, ok := resp.Data.(map[string]any) + require.True(t, ok) + + emailPQ, ok := data["auth_source_default_email_platform_quotas"].(map[string]any) + require.True(t, ok, "expected auth_source_default_email_platform_quotas to be a map") + openaiPQ, ok := emailPQ["openai"].(map[string]any) + require.True(t, ok, "expected openai entry in email platform quotas") + monthly, ok := openaiPQ["monthly"].(float64) + require.True(t, ok, "expected monthly to be float64") + require.Equal(t, float64(20), monthly, "expected openai monthly=20") +} diff --git a/backend/internal/handler/admin/user_handler.go b/backend/internal/handler/admin/user_handler.go index db35472e..32a21692 100644 --- a/backend/internal/handler/admin/user_handler.go +++ b/backend/internal/handler/admin/user_handler.go @@ -2,10 +2,15 @@ package admin import ( "context" + "errors" + "log/slog" + "math" "strconv" "strings" + "time" "github.com/Wei-Shaw/sub2api/internal/handler/dto" + "github.com/Wei-Shaw/sub2api/internal/handler/quotaview" "github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/service" @@ -20,15 +25,24 @@ type UserWithConcurrency struct { // UserHandler handles admin user management type UserHandler struct { - adminService service.AdminService - concurrencyService *service.ConcurrencyService + adminService service.AdminService + concurrencyService *service.ConcurrencyService + userPlatformQuotaRepo service.UserPlatformQuotaRepository // T13 admin quota view + billingCache service.BillingCache // T17/T18 缓存失效(PUT/POST 路径) } // NewUserHandler creates a new admin user handler -func NewUserHandler(adminService service.AdminService, concurrencyService *service.ConcurrencyService) *UserHandler { +func NewUserHandler( + adminService service.AdminService, + concurrencyService *service.ConcurrencyService, + userPlatformQuotaRepo service.UserPlatformQuotaRepository, + billingCache service.BillingCache, +) *UserHandler { return &UserHandler{ - adminService: adminService, - concurrencyService: concurrencyService, + adminService: adminService, + concurrencyService: concurrencyService, + userPlatformQuotaRepo: userPlatformQuotaRepo, + billingCache: billingCache, } } @@ -537,3 +551,294 @@ func (h *UserHandler) BatchUpdateConcurrency(c *gin.Context) { } response.Success(c, gin.H{"affected": affected}) } + +// GetUserPlatformQuotas GET /admin/users/:id/platform-quotas +// admin 视角:D14 lazy 归零 + 暴露 *_window_start 调试字段 +func (h *UserHandler) GetUserPlatformQuotas(c *gin.Context) { + idStr := c.Param("id") + userID, err := strconv.ParseInt(idStr, 10, 64) + if err != nil { + response.BadRequest(c, "invalid user id") + return + } + if h.userPlatformQuotaRepo == nil { + response.Success(c, map[string]any{"platform_quotas": []any{}}) + return + } + // 校验用户存在:与 PUT/POST 路径一致,不存在返回 404 而非空数组(避免 admin 界面误判用户存在)。 + if _, err := h.adminService.GetUser(c.Request.Context(), userID); err != nil { + response.ErrorFrom(c, err) + return + } + records, err := h.userPlatformQuotaRepo.ListByUser(c.Request.Context(), userID) + if err != nil { + response.ErrorFrom(c, err) + return + } + now := time.Now().UTC() + out := make([]map[string]any, 0, len(records)) + for _, r := range records { + out = append(out, quotaview.LazyZeroQuotaForResponse(r, now, true)) // true = 暴露 window_start + } + response.Success(c, map[string]any{"platform_quotas": out}) +} + +// UpdateUserPlatformQuotasRequest is the body for PUT /admin/users/:id/platform-quotas. +type UpdateUserPlatformQuotasRequest struct { + Quotas []PlatformQuotaInput `json:"quotas" binding:"required"` +} + +// PlatformQuotaInput 单平台限额输入;limit 字段为 nil 表示不限制。 +type PlatformQuotaInput struct { + Platform string `json:"platform" binding:"required"` + DailyLimitUSD *float64 `json:"daily_limit_usd"` + WeeklyLimitUSD *float64 `json:"weekly_limit_usd"` + MonthlyLimitUSD *float64 `json:"monthly_limit_usd"` +} + +// platform 合法性由 service.IsAllowedQuotaPlatform / service.AllowedQuotaPlatforms 统一判断(单一源)。 + +// UpdateUserPlatformQuotas PUT /admin/users/:id/platform-quotas +// 全量替换该用户所有平台限额。 +func (h *UserHandler) UpdateUserPlatformQuotas(c *gin.Context) { + if h.userPlatformQuotaRepo == nil { + response.Error(c, 503, "platform quota service not available") + return + } + + userID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid user ID") + return + } + + var req UpdateUserPlatformQuotasRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + if len(req.Quotas) > 4 { + response.BadRequest(c, "quotas length must be <= 4") + return + } + seen := make(map[string]struct{}, len(req.Quotas)) + for _, q := range req.Quotas { + if !service.IsAllowedQuotaPlatform(q.Platform) { + response.BadRequest(c, "invalid platform: "+q.Platform) + return + } + if _, dup := seen[q.Platform]; dup { + response.BadRequest(c, "duplicate platform: "+q.Platform) + return + } + seen[q.Platform] = struct{}{} + // daily_limit_usd / weekly_limit_usd / monthly_limit_usd 的语义: + // nil / not set → 无限额(完全放行) + // 0 → 完全禁用(任何请求都会被拒绝,因为 usage >= 0 恒成立) + // > 0 → USD 限额上限 + // 拦截 NaN / ±Inf:客户端可发送超大数(如 1e308 × 2)使 JSON 反序列化得到 +Inf, + // 进入 DB 后 cache check 中 usage >= limit 永不成立,limit 等同失效。 + for _, f := range []struct { + name string + val *float64 + }{ + {"daily_limit_usd", q.DailyLimitUSD}, + {"weekly_limit_usd", q.WeeklyLimitUSD}, + {"monthly_limit_usd", q.MonthlyLimitUSD}, + } { + if f.val == nil { + continue + } + v := *f.val + if v < 0 { + response.BadRequest(c, f.name+" must be >= 0") + return + } + if math.IsNaN(v) || math.IsInf(v, 0) { + response.BadRequest(c, f.name+" must be a finite number") + return + } + } + } + + records := make([]service.UserPlatformQuotaRecord, 0, len(req.Quotas)) + for _, q := range req.Quotas { + records = append(records, service.UserPlatformQuotaRecord{ + UserID: userID, + Platform: q.Platform, + DailyLimitUSD: q.DailyLimitUSD, + WeeklyLimitUSD: q.WeeklyLimitUSD, + MonthlyLimitUSD: q.MonthlyLimitUSD, + }) + } + + ctx := c.Request.Context() + // 校验用户是否存在,避免 FK 违反导致 500;用户不存在时返回 404。 + if _, err := h.adminService.GetUser(ctx, userID); err != nil { + response.ErrorFrom(c, err) + return + } + // 在 UpsertForUser 之前抓取 before snapshot 用于审计 before/after 对比。 + // ListByUser 失败不阻断主操作(best-effort),仅记录降级 warn。 + beforeRecords, beforeErr := h.userPlatformQuotaRepo.ListByUser(ctx, userID) + if beforeErr != nil { + slog.Warn("quota audit before snapshot failed", "user_id", userID, "err", beforeErr) + } + if err := h.userPlatformQuotaRepo.UpsertForUser(ctx, userID, records); err != nil { + response.ErrorFrom(c, err) + return + } + + beforeByPlatform := make(map[string]service.UserPlatformQuotaRecord, len(beforeRecords)) + for _, r := range beforeRecords { + beforeByPlatform[r.Platform] = r + } + afterPlatforms := make(map[string]struct{}, len(records)) + for _, r := range records { + afterPlatforms[r.Platform] = struct{}{} + } + changes := make([]map[string]any, 0, len(records)) + for _, r := range records { + entry := map[string]any{ + "platform": r.Platform, + "daily_limit_usd": r.DailyLimitUSD, + "weekly_limit_usd": r.WeeklyLimitUSD, + "monthly_limit_usd": r.MonthlyLimitUSD, + } + if prev, ok := beforeByPlatform[r.Platform]; ok { + entry["before_daily_limit_usd"] = prev.DailyLimitUSD + entry["before_weekly_limit_usd"] = prev.WeeklyLimitUSD + entry["before_monthly_limit_usd"] = prev.MonthlyLimitUSD + } + changes = append(changes, entry) + } + // 补 removed 条目:before 存在但 after 缺失 = 该平台被软删除。 + // 缺少这条记录,审计消费方无法察觉"管理员把某平台从配额列表移除"的操作(合规盲区)。 + for _, prev := range beforeRecords { + if _, kept := afterPlatforms[prev.Platform]; kept { + continue + } + changes = append(changes, map[string]any{ + "platform": prev.Platform, + "removed": true, + "before_daily_limit_usd": prev.DailyLimitUSD, + "before_weekly_limit_usd": prev.WeeklyLimitUSD, + "before_monthly_limit_usd": prev.MonthlyLimitUSD, + }) + } + // before_snapshot_available 让审计消费方能识别 changes 中是否带 before_* 字段; + // false 时所有 entry 都会缺失 before_*_limit_usd,仅有 after 视图。 + slog.Info("admin.quota_updated", + "actor_admin_id", getAdminIDFromContext(c), + "target_user_id", userID, + "platform_count", len(records), + "before_snapshot_available", beforeErr == nil, + "changes", changes) + + // 失效 cache:对全部允许的 platform 统一 invalidate。 + // Trade-off:精确失效(仅 req 涉及平台 + 被软删平台)需 upsert 前额外 ListByUser, + // 增加一次 DB 查询和逻辑复杂度。由于 AllowedQuotaPlatforms 只有 4 个元素, + // 全量 invalidate 的额外开销可接受,且能可靠覆盖软删除场景。 + if h.billingCache != nil { + for _, p := range service.AllowedQuotaPlatforms { + if err := h.billingCache.DeleteUserPlatformQuotaCache(ctx, userID, p); err != nil { + slog.Warn("quota cache invalidation failed", "user_id", userID, "platform", p, "err", err) + } + } + } + + // 返回最新状态 + now := time.Now().UTC() + records2, err := h.userPlatformQuotaRepo.ListByUser(ctx, userID) + if err != nil { + response.ErrorFrom(c, err) + return + } + out := make([]map[string]any, 0, len(records2)) + for i := range records2 { + out = append(out, quotaview.LazyZeroQuotaForResponse(records2[i], now, true)) + } + response.Success(c, map[string]any{"platform_quotas": out}) +} + +// ResetUserPlatformQuotaWindowRequest is the body for POST /admin/users/:id/platform-quotas/reset. +type ResetUserPlatformQuotaWindowRequest struct { + Platform string `json:"platform" binding:"required"` + Window string `json:"window" binding:"required"` +} + +var allowedWindowsForQuotaReset = map[string]struct{}{ + "daily": {}, + "weekly": {}, + "monthly": {}, +} + +// ResetUserPlatformQuotaWindow POST /admin/users/:id/platform-quotas/reset +// 立即归零指定 (platform, window) 的用量并更新 window_start。 +func (h *UserHandler) ResetUserPlatformQuotaWindow(c *gin.Context) { + if h.userPlatformQuotaRepo == nil { + response.Error(c, 503, "platform quota service not available") + return + } + + userID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid user ID") + return + } + + var req ResetUserPlatformQuotaWindowRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + if !service.IsAllowedQuotaPlatform(req.Platform) { + response.BadRequest(c, "invalid platform: "+req.Platform) + return + } + if _, ok := allowedWindowsForQuotaReset[req.Window]; !ok { + response.BadRequest(c, "invalid window: "+req.Window) + return + } + + ctx := c.Request.Context() + // 校验用户是否存在,避免对不存在的用户执行操作返回误导性的 500。 + if _, err := h.adminService.GetUser(ctx, userID); err != nil { + response.ErrorFrom(c, err) + return + } + now := time.Now().UTC() + if err := h.userPlatformQuotaRepo.ResetExpiredWindow(ctx, userID, req.Platform, req.Window, now); err != nil { + if errors.Is(err, service.ErrUserPlatformQuotaNotFound) { + response.NotFound(c, "user platform quota not found") + return + } + response.ErrorFrom(c, err) + return + } + + slog.Info("admin.quota_window_reset", + "actor_admin_id", getAdminIDFromContext(c), + "target_user_id", userID, + "platform", req.Platform, + "window", req.Window) + + if h.billingCache != nil { + if err := h.billingCache.DeleteUserPlatformQuotaCache(ctx, userID, req.Platform); err != nil { + slog.Warn("quota cache invalidation failed", "user_id", userID, "platform", req.Platform, "err", err) + } + } + + records, err := h.userPlatformQuotaRepo.ListByUser(ctx, userID) + if err != nil { + response.ErrorFrom(c, err) + return + } + out := make([]map[string]any, 0, len(records)) + for i := range records { + out = append(out, quotaview.LazyZeroQuotaForResponse(records[i], now, true)) + } + response.Success(c, map[string]any{"platform_quotas": out}) +} diff --git a/backend/internal/handler/admin/user_handler_activity_test.go b/backend/internal/handler/admin/user_handler_activity_test.go index bfba2408..a9fdc48a 100644 --- a/backend/internal/handler/admin/user_handler_activity_test.go +++ b/backend/internal/handler/admin/user_handler_activity_test.go @@ -35,7 +35,7 @@ func TestUserHandlerListIncludesActivityFieldsAndSortParams(t *testing.T) { UpdatedAt: lastLoginAt, }, } - handler := NewUserHandler(adminSvc, nil) + handler := NewUserHandler(adminSvc, nil, nil, nil) recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) @@ -89,7 +89,7 @@ func TestUserHandlerGetByIDIncludesActivityFields(t *testing.T) { UpdatedAt: lastLoginAt, }, } - handler := NewUserHandler(adminSvc, nil) + handler := NewUserHandler(adminSvc, nil, nil, nil) recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) diff --git a/backend/internal/handler/admin/user_platform_quota_admin_test.go b/backend/internal/handler/admin/user_platform_quota_admin_test.go new file mode 100644 index 00000000..fe33d36c --- /dev/null +++ b/backend/internal/handler/admin/user_platform_quota_admin_test.go @@ -0,0 +1,301 @@ +//go:build unit + +package admin + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/gin-gonic/gin" + + "github.com/Wei-Shaw/sub2api/internal/service" +) + +// upsertCapturingQuotaRepo 实现 service.UserPlatformQuotaRepository,捕获 UpsertForUser 调用。 +type upsertCapturingQuotaRepo struct { + service.UserPlatformQuotaRepository + listRecords []service.UserPlatformQuotaRecord + listErr error + upsertCalls []upsertCall + upsertErr error + resetCalls []resetCall + resetErr error +} + +type upsertCall struct { + userID int64 + records []service.UserPlatformQuotaRecord +} +type resetCall struct { + userID int64 + platform string + window string + newStart time.Time +} + +func (r *upsertCapturingQuotaRepo) ListByUser(_ context.Context, _ int64) ([]service.UserPlatformQuotaRecord, error) { + return r.listRecords, r.listErr +} +func (r *upsertCapturingQuotaRepo) UpsertForUser(_ context.Context, userID int64, records []service.UserPlatformQuotaRecord) error { + cloned := make([]service.UserPlatformQuotaRecord, len(records)) + copy(cloned, records) + r.upsertCalls = append(r.upsertCalls, upsertCall{userID: userID, records: cloned}) + return r.upsertErr +} +func (r *upsertCapturingQuotaRepo) ResetExpiredWindow(_ context.Context, userID int64, platform string, window string, newStart time.Time) error { + r.resetCalls = append(r.resetCalls, resetCall{userID, platform, window, newStart}) + return r.resetErr +} + +// billingCacheStub 实现 service.BillingCache 中本测试关心的 Delete 方法;其他方法 panic。 +type billingCacheStub struct { + service.BillingCache + deleteCalls []deleteCall + deleteErr error +} + +type deleteCall struct { + userID int64 + platform string +} + +func (b *billingCacheStub) DeleteUserPlatformQuotaCache(_ context.Context, userID int64, platform string) error { + b.deleteCalls = append(b.deleteCalls, deleteCall{userID, platform}) + return b.deleteErr +} + +func buildTestHandler(repo service.UserPlatformQuotaRepository, cache service.BillingCache) *UserHandler { + return &UserHandler{ + userPlatformQuotaRepo: repo, + billingCache: cache, + adminService: newStubAdminService(), + } +} + +func putReq(t *testing.T, body string) (*gin.Context, *httptest.ResponseRecorder) { + t.Helper() + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + req, _ := http.NewRequest(http.MethodPut, "/", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + c.Request = req + c.Params = []gin.Param{{Key: "id", Value: "42"}} + return c, w +} + +func TestUpdateUserPlatformQuotas_Success(t *testing.T) { + repo := &upsertCapturingQuotaRepo{} + cache := &billingCacheStub{} + h := buildTestHandler(repo, cache) + + body := `{"quotas":[ + {"platform":"anthropic","daily_limit_usd":10.0,"weekly_limit_usd":null,"monthly_limit_usd":100.0}, + {"platform":"openai","daily_limit_usd":null,"weekly_limit_usd":null,"monthly_limit_usd":null} + ]}` + c, w := putReq(t, body) + h.UpdateUserPlatformQuotas(c) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + if len(repo.upsertCalls) != 1 { + t.Fatalf("UpsertForUser should be called once, got %d", len(repo.upsertCalls)) + } + if repo.upsertCalls[0].userID != 42 || len(repo.upsertCalls[0].records) != 2 { + t.Errorf("unexpected upsert call: %+v", repo.upsertCalls[0]) + } + // 缓存失效:请求中 2 个 platform + 软删除的 2 个 platform(gemini, antigravity)= 4 次 + if len(cache.deleteCalls) != 4 { + t.Errorf("expected 4 cache delete calls, got %d: %+v", len(cache.deleteCalls), cache.deleteCalls) + } +} + +func TestUpdateUserPlatformQuotas_RejectsDuplicatePlatform(t *testing.T) { + h := buildTestHandler(&upsertCapturingQuotaRepo{}, &billingCacheStub{}) + body := `{"quotas":[ + {"platform":"anthropic","daily_limit_usd":1}, + {"platform":"anthropic","daily_limit_usd":2} + ]}` + c, w := putReq(t, body) + h.UpdateUserPlatformQuotas(c) + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestUpdateUserPlatformQuotas_RejectsInvalidPlatform(t *testing.T) { + h := buildTestHandler(&upsertCapturingQuotaRepo{}, &billingCacheStub{}) + body := `{"quotas":[{"platform":"unknown","daily_limit_usd":1}]}` + c, w := putReq(t, body) + h.UpdateUserPlatformQuotas(c) + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestUpdateUserPlatformQuotas_RejectsNegativeLimit(t *testing.T) { + h := buildTestHandler(&upsertCapturingQuotaRepo{}, &billingCacheStub{}) + body := `{"quotas":[{"platform":"anthropic","daily_limit_usd":-1}]}` + c, w := putReq(t, body) + h.UpdateUserPlatformQuotas(c) + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestUpdateUserPlatformQuotas_RejectsTooManyEntries(t *testing.T) { + h := buildTestHandler(&upsertCapturingQuotaRepo{}, &billingCacheStub{}) + body := `{"quotas":[ + {"platform":"anthropic"},{"platform":"openai"},{"platform":"gemini"},{"platform":"antigravity"},{"platform":"anthropic"} + ]}` + c, w := putReq(t, body) + h.UpdateUserPlatformQuotas(c) + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestUpdateUserPlatformQuotas_ReturnsLatestState(t *testing.T) { + repo := &upsertCapturingQuotaRepo{ + listRecords: []service.UserPlatformQuotaRecord{ + {UserID: 42, Platform: "anthropic"}, + }, + } + cache := &billingCacheStub{} + h := buildTestHandler(repo, cache) + + body := `{"quotas":[{"platform":"anthropic","daily_limit_usd":10}]}` + c, w := putReq(t, body) + h.UpdateUserPlatformQuotas(c) + if !strings.Contains(w.Body.String(), `"platform_quotas"`) { + t.Errorf("response should contain platform_quotas array: %s", w.Body.String()) + } +} + +// ───────── T4: Reset 测试 ───────── + +func postReq(t *testing.T, body string) (*gin.Context, *httptest.ResponseRecorder) { + t.Helper() + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + req, _ := http.NewRequest(http.MethodPost, "/", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + c.Request = req + c.Params = []gin.Param{{Key: "id", Value: "42"}} + return c, w +} + +func TestResetUserPlatformQuotaWindow_Success(t *testing.T) { + repo := &upsertCapturingQuotaRepo{} + cache := &billingCacheStub{} + h := buildTestHandler(repo, cache) + body := `{"platform":"anthropic","window":"daily"}` + c, w := postReq(t, body) + h.ResetUserPlatformQuotaWindow(c) + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + if len(repo.resetCalls) != 1 { + t.Fatalf("ResetExpiredWindow should be called once, got %d", len(repo.resetCalls)) + } + if repo.resetCalls[0].userID != 42 || + repo.resetCalls[0].platform != "anthropic" || + repo.resetCalls[0].window != "daily" { + t.Errorf("unexpected reset call: %+v", repo.resetCalls[0]) + } + if len(cache.deleteCalls) != 1 || + cache.deleteCalls[0].userID != 42 || + cache.deleteCalls[0].platform != "anthropic" { + t.Errorf("expected 1 cache delete for anthropic, got %+v", cache.deleteCalls) + } +} + +func TestResetUserPlatformQuotaWindow_RejectsInvalidWindow(t *testing.T) { + h := buildTestHandler(&upsertCapturingQuotaRepo{}, &billingCacheStub{}) + c, w := postReq(t, `{"platform":"anthropic","window":"yearly"}`) + h.ResetUserPlatformQuotaWindow(c) + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400, got %d", w.Code) + } +} + +func TestResetUserPlatformQuotaWindow_RejectsInvalidPlatform(t *testing.T) { + h := buildTestHandler(&upsertCapturingQuotaRepo{}, &billingCacheStub{}) + c, w := postReq(t, `{"platform":"unknown","window":"daily"}`) + h.ResetUserPlatformQuotaWindow(c) + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400, got %d", w.Code) + } +} + +func TestResetUserPlatformQuotaWindow_NotFound(t *testing.T) { + // handler 检查 service.ErrUserPlatformQuotaNotFound(由 adapter 包装而来) + repo := &upsertCapturingQuotaRepo{resetErr: service.ErrUserPlatformQuotaNotFound} + h := buildTestHandler(repo, &billingCacheStub{}) + c, w := postReq(t, `{"platform":"anthropic","window":"daily"}`) + h.ResetUserPlatformQuotaWindow(c) + if w.Code != http.StatusNotFound { + t.Errorf("expected 404, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestUpdateUserPlatformQuotas_JSONErrorOnRepoFailure(t *testing.T) { + repo := &upsertCapturingQuotaRepo{upsertErr: errors.New("db down")} + cache := &billingCacheStub{} + h := buildTestHandler(repo, cache) + body := `{"quotas":[{"platform":"anthropic","daily_limit_usd":10}]}` + c, w := putReq(t, body) + h.UpdateUserPlatformQuotas(c) + if w.Code < 500 { + t.Errorf("expected 5xx, got %d", w.Code) + } + // 返回 JSON 错误响应 + var body2 map[string]any + if err := json.Unmarshal(w.Body.Bytes(), &body2); err != nil { + t.Errorf("expected JSON error body, got: %s", w.Body.String()) + } +} + +func TestUpdateUserPlatformQuotas_UserNotFound(t *testing.T) { + repo := &upsertCapturingQuotaRepo{} + cache := &billingCacheStub{} + adminSvc := newStubAdminService() + adminSvc.getUserErr = service.ErrUserNotFound + h := &UserHandler{ + userPlatformQuotaRepo: repo, + billingCache: cache, + adminService: adminSvc, + } + body := `{"quotas":[{"platform":"anthropic","daily_limit_usd":10}]}` + c, w := putReq(t, body) + h.UpdateUserPlatformQuotas(c) + if w.Code != http.StatusNotFound { + t.Errorf("expected 404 when user not found, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestResetUserPlatformQuotaWindow_UserNotFound(t *testing.T) { + repo := &upsertCapturingQuotaRepo{} + cache := &billingCacheStub{} + adminSvc := newStubAdminService() + adminSvc.getUserErr = service.ErrUserNotFound + h := &UserHandler{ + userPlatformQuotaRepo: repo, + billingCache: cache, + adminService: adminSvc, + } + c, w := postReq(t, `{"platform":"anthropic","window":"daily"}`) + h.ResetUserPlatformQuotaWindow(c) + if w.Code != http.StatusNotFound { + t.Errorf("expected 404 when user not found, got %d: %s", w.Code, w.Body.String()) + } +} diff --git a/backend/internal/handler/admin/user_platform_quotas_handler_test.go b/backend/internal/handler/admin/user_platform_quotas_handler_test.go new file mode 100644 index 00000000..1927a6e8 --- /dev/null +++ b/backend/internal/handler/admin/user_platform_quotas_handler_test.go @@ -0,0 +1,124 @@ +//go:build unit + +package admin + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" +) + +type fakeQuotaRepoForAdmin struct { + service.UserPlatformQuotaRepository + records []service.UserPlatformQuotaRecord + err error +} + +func (f *fakeQuotaRepoForAdmin) ListByUser(_ context.Context, _ int64) ([]service.UserPlatformQuotaRecord, error) { + return f.records, f.err +} + +func newAdminQuotaTestContext(w *httptest.ResponseRecorder) *gin.Context { + c, _ := gin.CreateTestContext(w) + req, _ := http.NewRequest(http.MethodGet, "/", nil) + c.Request = req + return c +} + +func TestAdminGetUserPlatformQuotas_IncludesWindowStart(t *testing.T) { + start := time.Now().Add(-1 * time.Hour) + repo := &fakeQuotaRepoForAdmin{records: []service.UserPlatformQuotaRecord{{ + UserID: 99, Platform: "anthropic", + DailyUsageUSD: 1.0, DailyWindowStart: &start, + }}} + h := &UserHandler{userPlatformQuotaRepo: repo, adminService: newStubAdminService()} + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c := newAdminQuotaTestContext(w) + c.Params = []gin.Param{{Key: "id", Value: "99"}} + h.GetUserPlatformQuotas(c) + + if w.Code != 200 { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + if !strings.Contains(w.Body.String(), `"daily_window_start"`) { + t.Errorf("admin response missing daily_window_start, got: %s", w.Body.String()) + } +} + +func TestAdminGetUserPlatformQuotas_InvalidIDReturns400(t *testing.T) { + h := &UserHandler{userPlatformQuotaRepo: &fakeQuotaRepoForAdmin{}} + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c := newAdminQuotaTestContext(w) + c.Params = []gin.Param{{Key: "id", Value: "abc"}} + h.GetUserPlatformQuotas(c) + if w.Code < 400 || w.Code >= 500 { + t.Errorf("invalid id should yield 4xx, got %d", w.Code) + } +} + +func TestAdminGetUserPlatformQuotas_EmptyReturnsEmptyArray(t *testing.T) { + repo := &fakeQuotaRepoForAdmin{records: nil} + h := &UserHandler{userPlatformQuotaRepo: repo, adminService: newStubAdminService()} + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c := newAdminQuotaTestContext(w) + c.Params = []gin.Param{{Key: "id", Value: "99"}} + h.GetUserPlatformQuotas(c) + if w.Code != 200 { + t.Errorf("empty list should be 200, got %d", w.Code) + } + var body map[string]any + if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil { + t.Fatalf("response is not valid JSON: %v", err) + } + data, ok := body["data"].(map[string]any) + if !ok { + t.Fatalf("response missing data object: %v", body) + } + quotas, ok := data["platform_quotas"].([]any) + if !ok { + t.Fatalf("data.platform_quotas missing or wrong type: %v", data) + } + if len(quotas) != 0 { + t.Errorf("expected empty platform_quotas, got %d entries: %v", len(quotas), quotas) + } +} + +func TestAdminGetUserPlatformQuotas_NilRepoReturnsEmpty(t *testing.T) { + h := &UserHandler{userPlatformQuotaRepo: nil} + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c := newAdminQuotaTestContext(w) + c.Params = []gin.Param{{Key: "id", Value: "1"}} + h.GetUserPlatformQuotas(c) + if w.Code != 200 { + t.Errorf("nil repo should return 200 empty, got %d", w.Code) + } +} + +// TestAdminGetUserPlatformQuotas_UserNotFoundReturns404 验证 GET 在用户不存在时返回 404 +// (与 PUT / POST reset 端点行为一致;review fix:原实现返回空数组会让 admin 界面误判用户存在) +func TestAdminGetUserPlatformQuotas_UserNotFoundReturns404(t *testing.T) { + adminSvc := newStubAdminService() + adminSvc.getUserErr = infraerrors.NotFound("USER_NOT_FOUND", "user not found") + repo := &fakeQuotaRepoForAdmin{records: nil} + h := &UserHandler{userPlatformQuotaRepo: repo, adminService: adminSvc} + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c := newAdminQuotaTestContext(w) + c.Params = []gin.Param{{Key: "id", Value: "999"}} + h.GetUserPlatformQuotas(c) + if w.Code != http.StatusNotFound { + t.Errorf("expected 404 for non-existent user, got %d: %s", w.Code, w.Body.String()) + } +} diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go index 4081b9e4..70fb160a 100644 --- a/backend/internal/handler/auth_oauth_pending_flow_test.go +++ b/backend/internal/handler/auth_oauth_pending_flow_test.go @@ -2233,6 +2233,7 @@ CREATE TABLE IF NOT EXISTS user_affiliates ( nil, options.defaultSubAssigner, affiliateService, + nil, ) userSvc := service.NewUserService(userRepo, nil, nil, nil) var totpSvc *service.TotpService diff --git a/backend/internal/handler/auth_session_revocation_test.go b/backend/internal/handler/auth_session_revocation_test.go index f1c6d87d..922bf9c3 100644 --- a/backend/internal/handler/auth_session_revocation_test.go +++ b/backend/internal/handler/auth_session_revocation_test.go @@ -35,7 +35,7 @@ func TestAuthHandlerRevokeAllSessionsInvalidatesAccessTokens(t *testing.T) { ExpireHour: 1, }, } - authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil) + authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil, nil) handler := &AuthHandler{authService: authService} recorder := httptest.NewRecorder() diff --git a/backend/internal/handler/auth_wechat_oauth_test.go b/backend/internal/handler/auth_wechat_oauth_test.go index b3c7786d..e291b7e7 100644 --- a/backend/internal/handler/auth_wechat_oauth_test.go +++ b/backend/internal/handler/auth_wechat_oauth_test.go @@ -1400,6 +1400,7 @@ func newWeChatOAuthTestHandlerWithSettings(t *testing.T, invitationEnabled bool, nil, nil, nil, + nil, ) return &AuthHandler{ diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index fac60573..eecf98ac 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -3,6 +3,8 @@ package dto import ( "encoding/json" "strings" + + "github.com/Wei-Shaw/sub2api/internal/service" ) // CustomMenuItem represents a user-configured custom menu entry. @@ -246,6 +248,9 @@ type SystemSettings struct { // OpenAI fast/flex policy OpenAIFastPolicySettings *OpenAIFastPolicySettings `json:"openai_fast_policy_settings,omitempty"` + + // 系统全局默认平台配额(key = platform,nil/缺省 = 不限制) + DefaultPlatformQuotas map[string]*service.DefaultPlatformQuotaSetting `json:"default_platform_quotas,omitempty"` } type DefaultSubscriptionSetting struct { diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 51c2d94e..11915aa7 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -6,6 +6,7 @@ import ( "encoding/json" "errors" "fmt" + "math" "net/http" "strconv" "strings" @@ -250,7 +251,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } // 2. 【新增】Wait后二次检查余额/订阅 - if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { + if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil { reqLog.Info("gateway.billing_eligibility_check_failed", zap.Error(err)) status, code, message, retryAfter := billingErrorDetails(err) if retryAfter > 0 { @@ -508,10 +509,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 + quotaPlatform := service.QuotaPlatform(c.Request.Context(), apiKey) h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ Result: result, ParsedRequest: parsedReq, + QuotaPlatform: quotaPlatform, APIKey: apiKey, User: apiKey.User, Account: account, @@ -808,7 +811,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { return } fallbackAPIKey := cloneAPIKeyWithGroup(apiKey, fallbackGroup) - if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), fallbackAPIKey.User, fallbackAPIKey, fallbackGroup, nil); err != nil { + if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), fallbackAPIKey.User, fallbackAPIKey, fallbackGroup, nil, service.PlatformFromAPIKey(fallbackAPIKey)); err != nil { status, code, message, retryAfter := billingErrorDetails(err) if retryAfter > 0 { c.Header("Retry-After", strconv.Itoa(retryAfter)) @@ -900,10 +903,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 + quotaPlatform := service.QuotaPlatform(c.Request.Context(), currentAPIKey) h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ Result: result, ParsedRequest: parsedReq, + QuotaPlatform: quotaPlatform, APIKey: currentAPIKey, User: currentAPIKey.User, Account: account, @@ -1582,7 +1587,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { // 校验 billing eligibility(订阅/余额) // 【注意】不计算并发,但需要校验订阅/余额 - if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { + if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil { status, code, message, retryAfter := billingErrorDetails(err) if retryAfter > 0 { c.Header("Retry-After", strconv.Itoa(retryAfter)) @@ -1830,6 +1835,36 @@ func sendMockInterceptResponse(c *gin.Context, model string, interceptType Inter c.JSON(http.StatusOK, response) } +// extractQuotaResetSeconds 从 quota 错误的 metadata 中提取 window_resets_at 并计算 +// 距重置剩余秒数。fallback 路径必须返回 ≥1 秒,避免客户端立即重试无限循环。 +func extractQuotaResetSeconds(err error) int { + const fallback = 60 + appErr := pkgerrors.FromError(err) + if appErr == nil { + return fallback + } + raw, ok := appErr.Metadata["window_resets_at"] + if !ok || raw == "" { + return fallback + } + resetAt, parseErr := time.Parse(time.RFC3339, raw) + if parseErr != nil { + logger.L().With( + zap.String("component", "handler.gateway.billing"), + zap.String("raw", raw), + zap.Error(parseErr), + ).Warn("quota.invalid_window_resets_at_format") + return fallback + } + secs := time.Until(resetAt).Seconds() + if secs <= 0 { + // reset 时间已过:cache 与 DB 应该正在自愈,返回 fallback 让客户端按常规节奏退避, + // 避免返回 1 秒导致客户端立即重试仍触发限额的退避循环。 + return fallback + } + return int(math.Ceil(secs)) +} + func billingErrorDetails(err error) (status int, code, message string, retryAfter int) { if errors.Is(err, service.ErrBillingServiceUnavailable) { msg := pkgerrors.Message(err) @@ -1857,6 +1892,14 @@ func billingErrorDetails(err error) (status int, code, message string, retryAfte retrySeconds := 60 - int(time.Now().Unix()%60) return http.StatusTooManyRequests, "rate_limit_exceeded", msg, retrySeconds } + if errors.Is(err, service.ErrUserPlatformDailyQuotaExhausted) || + errors.Is(err, service.ErrUserPlatformWeeklyQuotaExhausted) || + errors.Is(err, service.ErrUserPlatformMonthlyQuotaExhausted) { + // 与 RPM 超限一致映射 429 + Retry-After,让 SDK 自动退避(而非 403 直接失败)。 + // 错误码用 rate_limit_exceeded 与 OpenAI 兼容客户端一致;细分类型由 ErrCode + window_resets_at metadata 区分。 + msg := pkgerrors.Message(err) + return http.StatusTooManyRequests, "rate_limit_exceeded", msg, extractQuotaResetSeconds(err) + } msg := pkgerrors.Message(err) if msg == "" { logger.L().With( diff --git a/backend/internal/handler/gateway_handler_billing_error_test.go b/backend/internal/handler/gateway_handler_billing_error_test.go index e8a88802..eaf48d1e 100644 --- a/backend/internal/handler/gateway_handler_billing_error_test.go +++ b/backend/internal/handler/gateway_handler_billing_error_test.go @@ -1,8 +1,10 @@ package handler import ( + "errors" "net/http" "testing" + "time" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/stretchr/testify/require" @@ -52,3 +54,75 @@ func TestBillingErrorDetails_UnknownErrorFallsBackTo403(t *testing.T) { require.Equal(t, "billing_error", code) require.NotEmpty(t, msg) } + +func TestExtractQuotaResetSeconds_T19_HappyPath(t *testing.T) { + err := service.ErrUserPlatformDailyQuotaExhausted.WithMetadata(map[string]string{ + "window_resets_at": time.Now().Add(10 * time.Second).UTC().Format(time.RFC3339), + }) + got := extractQuotaResetSeconds(err) + if got < 10 || got > 11 { + t.Errorf("T19: got %d, want 10 or 11 (math.Ceil boundary)", got) + } +} + +func TestExtractQuotaResetSeconds_T20_NoMetadataFallback(t *testing.T) { + if got := extractQuotaResetSeconds(errors.New("naked error")); got != 60 { + t.Errorf("T20: got %d, want 60 fallback", got) + } +} + +func TestExtractQuotaResetSeconds_T21_BadFormatFallback(t *testing.T) { + err := service.ErrUserPlatformDailyQuotaExhausted.WithMetadata(map[string]string{ + "window_resets_at": "not-a-time", + }) + if got := extractQuotaResetSeconds(err); got != 60 { + t.Errorf("T21: got %d, want 60 fallback", got) + } +} + +func TestExtractQuotaResetSeconds_T22_PastResetFallsBackToDefault(t *testing.T) { + // 当 window_resets_at 已过去时返回 fallback (60s) 而非 1s: + // 1 秒会导致客户端立即重试仍触发限额的退避循环; + // 60s 让客户端按常规节奏退避,cache/DB 自愈期间不会反复打抖。 + err := service.ErrUserPlatformDailyQuotaExhausted.WithMetadata(map[string]string{ + "window_resets_at": time.Now().Add(-5 * time.Second).UTC().Format(time.RFC3339), + }) + if got := extractQuotaResetSeconds(err); got != 60 { + t.Errorf("T22: got %d, want 60 (fallback on past reset)", got) + } +} + +func TestBillingErrorDetails_T10_QuotaExhaustedReturns429WithRetryAfter(t *testing.T) { + // quota 超限映射 429 + Retry-After(RFC 6585 / 与 RPM 一致), + // 让 SDK(OpenAI 兼容客户端等)能按 Retry-After 自动退避。 + // 旧实现用 403 导致客户端不退避直接报错。 + // 三个窗口共用同一映射分支,循环覆盖避免漏测某个窗口的 status/code。 + cases := []struct { + name string + err error + }{ + {"daily", service.ErrUserPlatformDailyQuotaExhausted.WithMetadata(map[string]string{ + "window_resets_at": time.Now().Add(60 * time.Minute).UTC().Format(time.RFC3339), + })}, + {"weekly", service.ErrUserPlatformWeeklyQuotaExhausted.WithMetadata(map[string]string{ + "window_resets_at": time.Now().Add(60 * time.Minute).UTC().Format(time.RFC3339), + })}, + {"monthly", service.ErrUserPlatformMonthlyQuotaExhausted.WithMetadata(map[string]string{ + "window_resets_at": time.Now().Add(60 * time.Minute).UTC().Format(time.RFC3339), + })}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + status, code, _, retryAfter := billingErrorDetails(tc.err) + if status != http.StatusTooManyRequests { + t.Errorf("status = %d, want 429", status) + } + if code != "rate_limit_exceeded" { + t.Errorf("code = %q, want rate_limit_exceeded", code) + } + if retryAfter < 3599 || retryAfter > 3601 { + t.Errorf("retryAfter = %d, want ~3600", retryAfter) + } + }) + } +} diff --git a/backend/internal/handler/gateway_handler_chat_completions.go b/backend/internal/handler/gateway_handler_chat_completions.go index 9a091fcd..acbdc261 100644 --- a/backend/internal/handler/gateway_handler_chat_completions.go +++ b/backend/internal/handler/gateway_handler_chat_completions.go @@ -140,7 +140,7 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) { } // 2. Re-check billing - if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { + if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil { reqLog.Info("gateway.cc.billing_check_failed", zap.Error(err)) status, code, message, retryAfter := billingErrorDetails(err) if retryAfter > 0 { @@ -291,9 +291,11 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) { inboundEndpoint := GetInboundEndpoint(c) upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform) + quotaPlatform := service.QuotaPlatform(c.Request.Context(), apiKey) h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ Result: result, + QuotaPlatform: quotaPlatform, APIKey: apiKey, User: apiKey.User, Account: account, diff --git a/backend/internal/handler/gateway_handler_responses.go b/backend/internal/handler/gateway_handler_responses.go index e1a5b723..6a083f31 100644 --- a/backend/internal/handler/gateway_handler_responses.go +++ b/backend/internal/handler/gateway_handler_responses.go @@ -145,7 +145,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) { } // 2. Re-check billing - if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { + if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil { reqLog.Info("gateway.responses.billing_check_failed", zap.Error(err)) status, code, message, retryAfter := billingErrorDetails(err) if retryAfter > 0 { @@ -266,9 +266,11 @@ func (h *GatewayHandler) Responses(c *gin.Context) { inboundEndpoint := GetInboundEndpoint(c) upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform) + quotaPlatform := service.QuotaPlatform(c.Request.Context(), apiKey) h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ Result: result, + QuotaPlatform: quotaPlatform, APIKey: apiKey, User: apiKey.User, Account: account, diff --git a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go index 57554cf9..09b20722 100644 --- a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go +++ b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go @@ -172,11 +172,12 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi nil, // channelService nil, // resolver nil, // balanceNotifyService + nil, // userPlatformQuotaRepo ) // RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。 cfg := &config.Config{RunMode: config.RunModeSimple} - billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg) + billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg, nil) concurrencySvc := service.NewConcurrencyService(&fakeConcurrencyCache{}) concurrencyHelper := NewConcurrencyHelper(concurrencySvc, SSEPingFormatClaude, 0) diff --git a/backend/internal/handler/gateway_models_test.go b/backend/internal/handler/gateway_models_test.go index af52ae23..78b07a1a 100644 --- a/backend/internal/handler/gateway_models_test.go +++ b/backend/internal/handler/gateway_models_test.go @@ -43,7 +43,7 @@ func newGatewayModelsHandlerForTest(repo service.AccountRepository) *GatewayHand gatewayService: service.NewGatewayService( repo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, - nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, + nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, ), } } diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index 665c0677..27ea4404 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -247,7 +247,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { } // 2) billing eligibility check (after wait) - if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { + if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil { reqLog.Info("gemini.billing_eligibility_check_failed", zap.Error(err)) status, _, message, retryAfter := billingErrorDetails(err) if retryAfter > 0 { @@ -527,9 +527,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { requestPayloadHash := service.HashUsageRequestPayload(body) inboundEndpoint := GetInboundEndpoint(c) upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform) + quotaPlatform := service.QuotaPlatform(c.Request.Context(), apiKey) h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{ Result: result, + QuotaPlatform: quotaPlatform, APIKey: apiKey, User: apiKey.User, Account: account, diff --git a/backend/internal/handler/image_concurrency_limiter_test.go b/backend/internal/handler/image_concurrency_limiter_test.go index 20147f16..723c26d0 100644 --- a/backend/internal/handler/image_concurrency_limiter_test.go +++ b/backend/internal/handler/image_concurrency_limiter_test.go @@ -206,7 +206,7 @@ func TestOpenAIGatewayHandlerResponses_TextOnlyNotRejectedByImageConcurrency(t * h := &OpenAIGatewayHandler{ gatewayService: &service.OpenAIGatewayService{}, - billingCacheService: service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, &config.Config{RunMode: config.RunModeSimple}), + billingCacheService: service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, &config.Config{RunMode: config.RunModeSimple}, nil), apiKeyService: &service.APIKeyService{}, concurrencyHelper: &ConcurrencyHelper{concurrencyService: service.NewConcurrencyService(&helperConcurrencyCacheStub{userSeq: []bool{true}})}, cfg: &config.Config{Gateway: config.GatewayConfig{ImageConcurrency: config.ImageConcurrencyConfig{ diff --git a/backend/internal/handler/openai_chat_completions.go b/backend/internal/handler/openai_chat_completions.go index 4d523dba..17f0d47e 100644 --- a/backend/internal/handler/openai_chat_completions.go +++ b/backend/internal/handler/openai_chat_completions.go @@ -106,7 +106,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { defer userReleaseFunc() } - if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { + if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil { reqLog.Info("openai_chat_completions.billing_eligibility_check_failed", zap.Error(err)) status, code, message, retryAfter := billingErrorDetails(err) if retryAfter > 0 { diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 5464d654..88ece8e7 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -243,7 +243,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { } // 2. Re-check billing eligibility after wait - if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { + if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil { reqLog.Info("openai.billing_eligibility_check_failed", zap.Error(err)) status, code, message, retryAfter := billingErrorDetails(err) if retryAfter > 0 { @@ -648,7 +648,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { defer userReleaseFunc() } - if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { + if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil { reqLog.Info("openai_messages.billing_eligibility_check_failed", zap.Error(err)) status, code, message, retryAfter := billingErrorDetails(err) if retryAfter > 0 { @@ -1235,7 +1235,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc) subscription, _ := middleware2.GetSubscriptionFromContext(c) - if err := h.billingCacheService.CheckBillingEligibility(ctx, apiKey.User, apiKey, apiKey.Group, subscription); err != nil { + if err := h.billingCacheService.CheckBillingEligibility(ctx, apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil { reqLog.Info("openai.websocket_billing_eligibility_check_failed", zap.Error(err)) closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "billing check failed") return diff --git a/backend/internal/handler/openai_gateway_handler_test.go b/backend/internal/handler/openai_gateway_handler_test.go index 49b7b9ec..b304640e 100644 --- a/backend/internal/handler/openai_gateway_handler_test.go +++ b/backend/internal/handler/openai_gateway_handler_test.go @@ -1201,7 +1201,7 @@ func runOpenAIResponsesWebSocketUsageLogCase(t *testing.T, tc openAIResponsesWSU }, nil, nil, nil) } - billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg) + billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg, nil) gatewaySvc := service.NewOpenAIGatewayService( accountRepo, usageRepo, @@ -1223,6 +1223,7 @@ func runOpenAIResponsesWebSocketUsageLogCase(t *testing.T, tc openAIResponsesWSU channelSvc, nil, nil, + nil, // userPlatformQuotaRepo ) cache := &concurrencyCacheMock{ diff --git a/backend/internal/handler/openai_images.go b/backend/internal/handler/openai_images.go index e6c37272..bbb08014 100644 --- a/backend/internal/handler/openai_images.go +++ b/backend/internal/handler/openai_images.go @@ -123,7 +123,7 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) { defer userReleaseFunc() } - if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { + if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil { reqLog.Info("openai.images.billing_eligibility_check_failed", zap.Error(err)) status, code, message, retryAfter := billingErrorDetails(err) if retryAfter > 0 { diff --git a/backend/internal/handler/quotaview/helpers.go b/backend/internal/handler/quotaview/helpers.go new file mode 100644 index 00000000..ff4e44a1 --- /dev/null +++ b/backend/internal/handler/quotaview/helpers.go @@ -0,0 +1,104 @@ +// Package quotaview provides shared quota response helpers for user and admin handlers. +// Extracted to avoid import cycles between handler and handler/admin packages. +package quotaview + +import ( + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" + "github.com/Wei-Shaw/sub2api/internal/service" +) + +// LazyZeroQuotaForResponse 按 D14 规则把过期档位归零(不写 DB)。 +// includeWindowStart=true 时输出 *_window_start 字段(admin 视角调试用) +func LazyZeroQuotaForResponse(r service.UserPlatformQuotaRecord, now time.Time, includeWindowStart bool) map[string]any { + daily := buildWindowSlice(r.DailyUsageUSD, r.DailyLimitUSD, r.DailyWindowStart, NeedsDailyReset(r.DailyWindowStart, now), nextDailyResetTime(now), includeWindowStart) + weekly := buildWindowSlice(r.WeeklyUsageUSD, r.WeeklyLimitUSD, r.WeeklyWindowStart, NeedsWeeklyReset(r.WeeklyWindowStart, now), nextWeeklyResetTime(now), includeWindowStart) + monthly := buildWindowSlice(r.MonthlyUsageUSD, r.MonthlyLimitUSD, r.MonthlyWindowStart, NeedsMonthlyReset(r.MonthlyWindowStart, now), NextMonthlyResetTimeFrom(r.MonthlyWindowStart, now), includeWindowStart) + out := map[string]any{ + "platform": r.Platform, + "daily_usage_usd": daily.usage, + "daily_limit_usd": daily.limit, + "daily_window_resets_at": daily.resetsAt, + "weekly_usage_usd": weekly.usage, + "weekly_limit_usd": weekly.limit, + "weekly_window_resets_at": weekly.resetsAt, + "monthly_usage_usd": monthly.usage, + "monthly_limit_usd": monthly.limit, + "monthly_window_resets_at": monthly.resetsAt, + } + if includeWindowStart { + out["daily_window_start"] = daily.windowStart + out["weekly_window_start"] = weekly.windowStart + out["monthly_window_start"] = monthly.windowStart + } + return out +} + +type windowSlice struct { + usage float64 + limit *float64 + resetsAt *string + windowStart *string +} + +func buildWindowSlice(usage float64, limit *float64, start *time.Time, expired bool, nextReset time.Time, includeStart bool) windowSlice { + out := windowSlice{usage: usage, limit: limit} + if expired { + out.usage = 0 + out.resetsAt = nil + } else if start != nil { + s := nextReset.Format(time.RFC3339) + out.resetsAt = &s + } + if includeStart && start != nil { + s := start.Format(time.RFC3339) + out.windowStart = &s + } + return out +} + +// NeedsDailyReset 判断日窗口是否已过期:start 早于「全局时区当天 0 点」即过期。 +// 时区跟随 timezone.Location()(全局服务器时区),与 billing / repo 写入的 window_start 同口径。 +func NeedsDailyReset(start *time.Time, now time.Time) bool { + if start == nil { + return false + } + return start.Before(timezone.StartOfDay(now)) +} + +func NeedsWeeklyReset(start *time.Time, now time.Time) bool { + if start == nil { + return false + } + return start.Before(timezone.StartOfWeek(now)) +} + +// NeedsMonthlyReset 30 天滚动窗口语义(与订阅模式 NeedsMonthlyReset 一致)。 +func NeedsMonthlyReset(start *time.Time, now time.Time) bool { + if start == nil { + return false + } + return now.Sub(*start) >= 30*24*time.Hour +} + +func nextDailyResetTime(now time.Time) time.Time { + return timezone.StartOfDay(now).AddDate(0, 0, 1) +} + +func nextWeeklyResetTime(now time.Time) time.Time { + return timezone.StartOfWeek(now).AddDate(0, 0, 7) +} + +// NextMonthlyResetTimeFrom 计算 30 天滚动月度窗口的下次重置时间。 +// 语义: +// - start != nil → 返回 start + 30d(与 billing_cache_service.nextMonthlyResetFrom 一致) +// - start == nil → 退化为 now + 30d(保留旧行为,避免 nil 崩溃) +// +// 导出(首字母大写)以允许测试直接调用。 +func NextMonthlyResetTimeFrom(start *time.Time, now time.Time) time.Time { + if start == nil { + return now.Add(30 * 24 * time.Hour) + } + return start.Add(30 * 24 * time.Hour) +} diff --git a/backend/internal/handler/quotaview/helpers_test.go b/backend/internal/handler/quotaview/helpers_test.go new file mode 100644 index 00000000..ca2fcc4e --- /dev/null +++ b/backend/internal/handler/quotaview/helpers_test.go @@ -0,0 +1,133 @@ +package quotaview + +import ( + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" + "github.com/Wei-Shaw/sub2api/internal/service" +) + +// TestNextMonthlyResetTimeFrom_FromStart 验证:start 已知时返回 start+30d,不随 now 漂移。 +func TestNextMonthlyResetTimeFrom_FromStart(t *testing.T) { + t0 := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) + now := t0.Add(15 * 24 * time.Hour) // t0 + 15d + want := t0.Add(30 * 24 * time.Hour) // t0 + 30d + + got := NextMonthlyResetTimeFrom(&t0, now) + if !got.Equal(want) { + t.Errorf("NextMonthlyResetTimeFrom: want %v, got %v", want, got) + } +} + +// TestNextMonthlyResetTimeFrom_NilStart 验证:start=nil 时退化为 now+30d(不 panic)。 +func TestNextMonthlyResetTimeFrom_NilStart(t *testing.T) { + now := time.Date(2024, 3, 15, 12, 0, 0, 0, time.UTC) + want := now.Add(30 * 24 * time.Hour) + + got := NextMonthlyResetTimeFrom(nil, now) + if !got.Equal(want) { + t.Errorf("NextMonthlyResetTimeFrom(nil): want %v, got %v", want, got) + } +} + +// TestLazyZeroQuotaForResponse_MonthlyResetsAt_NotDrifting 验证: +// 连续两次以不同 now 调用、但 MonthlyWindowStart 相同的 record, +// monthly_window_resets_at 始终等于 windowStart+30d,不随 now 漂移。 +func TestLazyZeroQuotaForResponse_MonthlyResetsAt_NotDrifting(t *testing.T) { + windowStart := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) + wantResetsAt := windowStart.Add(30 * 24 * time.Hour).Format(time.RFC3339) + + r := service.UserPlatformQuotaRecord{ + Platform: "openai", + MonthlyUsageUSD: 5.0, + MonthlyWindowStart: &windowStart, + } + + // 第一次调用:now = windowStart + 5d + now1 := windowStart.Add(5 * 24 * time.Hour) + out1 := LazyZeroQuotaForResponse(r, now1, false) + resetsAt1, ok1 := out1["monthly_window_resets_at"] + if !ok1 || resetsAt1 == nil { + t.Fatal("first call: monthly_window_resets_at should be set for active window") + } + s1, ok := resetsAt1.(*string) + if !ok || s1 == nil { + t.Fatalf("first call: monthly_window_resets_at should be *string, got %T", resetsAt1) + } + if *s1 != wantResetsAt { + t.Errorf("first call: want %s, got %s", wantResetsAt, *s1) + } + + // 第二次调用:now = windowStart + 10d(不同 now,但 resetsAt 应不变) + now2 := windowStart.Add(10 * 24 * time.Hour) + out2 := LazyZeroQuotaForResponse(r, now2, false) + resetsAt2, ok2 := out2["monthly_window_resets_at"] + if !ok2 || resetsAt2 == nil { + t.Fatal("second call: monthly_window_resets_at should be set for active window") + } + s2, ok := resetsAt2.(*string) + if !ok || s2 == nil { + t.Fatalf("second call: monthly_window_resets_at should be *string, got %T", resetsAt2) + } + if *s2 != wantResetsAt { + t.Errorf("second call: want %s, got %s", wantResetsAt, *s2) + } + + // 两次结果必须相等 + if *s1 != *s2 { + t.Errorf("resetsAt drifted between calls: %s vs %s", *s1, *s2) + } +} + +// TestNeedsDailyReset_FollowsServerTimezone 验证日窗口过期判断按全局时区(北京 0 点)而非 UTC。 +func TestNeedsDailyReset_FollowsServerTimezone(t *testing.T) { + if err := timezone.Init("Asia/Shanghai"); err != nil { + t.Fatalf("Init: %v", err) + } + t.Cleanup(func() { _ = timezone.Init("UTC") }) + + // now = 2026-05-25 23:00 UTC = 2026-05-26 07:00 +08(北京 5/26) + now := time.Date(2026, 5, 25, 23, 0, 0, 0, time.UTC) + + // start = 2026-05-25 10:00 UTC = 2026-05-25 18:00 +08(北京 5/25)→ 应判定为过期 + startPrevBeijingDay := time.Date(2026, 5, 25, 10, 0, 0, 0, time.UTC) + if !NeedsDailyReset(&startPrevBeijingDay, now) { + t.Error("上一个北京日的窗口应判定为过期") + } + + // start = 2026-05-25 20:00 UTC = 2026-05-26 04:00 +08(北京 5/26 同日)→ 不应过期 + startSameBeijingDay := time.Date(2026, 5, 25, 20, 0, 0, 0, time.UTC) + if NeedsDailyReset(&startSameBeijingDay, now) { + t.Error("同一北京日的窗口不应判定为过期") + } +} + +// TestNextDailyResetTime_FollowsServerTimezone 验证下次日重置 = 次日北京 0 点。 +func TestNextDailyResetTime_FollowsServerTimezone(t *testing.T) { + if err := timezone.Init("Asia/Shanghai"); err != nil { + t.Fatalf("Init: %v", err) + } + t.Cleanup(func() { _ = timezone.Init("UTC") }) + + now := time.Date(2026, 5, 25, 23, 0, 0, 0, time.UTC) // 北京 5/26 07:00 + want := time.Date(2026, 5, 27, 0, 0, 0, 0, timezone.Location()) // 北京 5/27 00:00 + if got := nextDailyResetTime(now); !got.Equal(want) { + t.Errorf("nextDailyResetTime = %v, want %v", got, want) + } +} + +// TestNextWeeklyResetTime_FollowsServerTimezone 验证下次周重置 = 下周一北京 0 点。 +func TestNextWeeklyResetTime_FollowsServerTimezone(t *testing.T) { + if err := timezone.Init("Asia/Shanghai"); err != nil { + t.Fatalf("Init: %v", err) + } + t.Cleanup(func() { _ = timezone.Init("UTC") }) + + // 北京 2026-05-26(周二)→ 下周一是 2026-06-01 + now := time.Date(2026, 5, 25, 23, 0, 0, 0, time.UTC) // 北京 5/26 07:00 周二 + want := time.Date(2026, 6, 1, 0, 0, 0, 0, timezone.Location()) + if got := nextWeeklyResetTime(now); !got.Equal(want) { + t.Errorf("nextWeeklyResetTime = %v, want %v", got, want) + } +} diff --git a/backend/internal/handler/user_handler.go b/backend/internal/handler/user_handler.go index 95cb1482..d4b10b2b 100644 --- a/backend/internal/handler/user_handler.go +++ b/backend/internal/handler/user_handler.go @@ -3,8 +3,10 @@ package handler import ( "context" "strings" + "time" "github.com/Wei-Shaw/sub2api/internal/handler/dto" + "github.com/Wei-Shaw/sub2api/internal/handler/quotaview" "github.com/Wei-Shaw/sub2api/internal/pkg/response" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" @@ -14,11 +16,12 @@ import ( // UserHandler handles user-related requests type UserHandler struct { - userService *service.UserService - authService *service.AuthService - emailService *service.EmailService - emailCache service.EmailCache - affiliateService *service.AffiliateService + userService *service.UserService + authService *service.AuthService + emailService *service.EmailService + emailCache service.EmailCache + affiliateService *service.AffiliateService + userPlatformQuotaRepo service.UserPlatformQuotaRepository } // NewUserHandler creates a new UserHandler @@ -28,16 +31,44 @@ func NewUserHandler( emailService *service.EmailService, emailCache service.EmailCache, affiliateService *service.AffiliateService, + userPlatformQuotaRepo service.UserPlatformQuotaRepository, ) *UserHandler { return &UserHandler{ - userService: userService, - authService: authService, - emailService: emailService, - emailCache: emailCache, - affiliateService: affiliateService, + userService: userService, + authService: authService, + emailService: emailService, + emailCache: emailCache, + affiliateService: affiliateService, + userPlatformQuotaRepo: userPlatformQuotaRepo, } } +// GetMyPlatformQuotas GET /user/platform-quotas +// 返回当前 JWT 用户的 platform quota 状态。 +// D14: 对每条记录逐档判断窗口过期,过期档位 usage=0、window_resets_at=null(不写 DB) +func (h *UserHandler) GetMyPlatformQuotas(c *gin.Context) { + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + response.Unauthorized(c, "User not authenticated") + return + } + if h.userPlatformQuotaRepo == nil { + response.Success(c, map[string]any{"platform_quotas": []any{}}) + return + } + records, err := h.userPlatformQuotaRepo.ListByUser(c.Request.Context(), subject.UserID) + if err != nil { + response.ErrorFrom(c, err) + return + } + now := time.Now().UTC() + out := make([]map[string]any, 0, len(records)) + for _, r := range records { + out = append(out, quotaview.LazyZeroQuotaForResponse(r, now, false)) + } + response.Success(c, map[string]any{"platform_quotas": out}) +} + // ChangePasswordRequest represents the change password request payload type ChangePasswordRequest struct { OldPassword string `json:"old_password" binding:"required"` diff --git a/backend/internal/handler/user_handler_test.go b/backend/internal/handler/user_handler_test.go index ffca86dc..41647802 100644 --- a/backend/internal/handler/user_handler_test.go +++ b/backend/internal/handler/user_handler_test.go @@ -87,8 +87,12 @@ func (s *userHandlerRepoStub) ListWithFilters(context.Context, pagination.Pagina func (s *userHandlerRepoStub) UpdateBalance(context.Context, int64, float64) error { return nil } func (s *userHandlerRepoStub) DeductBalance(context.Context, int64, float64) error { return nil } func (s *userHandlerRepoStub) UpdateConcurrency(context.Context, int64, int) error { return nil } -func (s *userHandlerRepoStub) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil } -func (s *userHandlerRepoStub) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil } +func (s *userHandlerRepoStub) BatchSetConcurrency(context.Context, []int64, int) (int, error) { + return 0, nil +} +func (s *userHandlerRepoStub) BatchAddConcurrency(context.Context, []int64, int) (int, error) { + return 0, nil +} func (s *userHandlerRepoStub) ExistsByEmail(context.Context, string) (bool, error) { return false, nil } func (s *userHandlerRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) { return 0, nil @@ -144,7 +148,7 @@ func TestUserHandlerUpdateProfileReturnsAvatarURL(t *testing.T) { Status: service.StatusActive, }, } - handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil) + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil, nil) body := []byte(`{"avatar_url":"https://cdn.example.com/avatar.png"}`) recorder := httptest.NewRecorder() @@ -202,7 +206,7 @@ func TestUserHandlerGetProfileReturnsIdentitySummaries(t *testing.T) { }, }, } - handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil) + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil, nil) recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) @@ -272,20 +276,20 @@ func TestUserHandlerGetProfileReturnsLegacyCompatibilityFields(t *testing.T) { AvatarURL: "https://cdn.example.com/linuxdo.png", AvatarSource: "remote_url", }, - identities: []service.UserAuthIdentityRecord{ - { - ProviderType: "linuxdo", - ProviderKey: "linuxdo", - ProviderSubject: "linuxdo-subject-21", - VerifiedAt: &verifiedAt, - Metadata: map[string]any{ - "username": "linuxdo-handle", - "avatar_url": "https://cdn.example.com/linuxdo.png", - }, + identities: []service.UserAuthIdentityRecord{ + { + ProviderType: "linuxdo", + ProviderKey: "linuxdo", + ProviderSubject: "linuxdo-subject-21", + VerifiedAt: &verifiedAt, + Metadata: map[string]any{ + "username": "linuxdo-handle", + "avatar_url": "https://cdn.example.com/linuxdo.png", }, }, - } - handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil) + }, + } + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil, nil) recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) @@ -364,7 +368,7 @@ func TestUserHandlerGetProfileDoesNotInferEditedProfileSourcesWithoutMatchingIde }, }, } - handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil) + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil, nil) recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) @@ -513,8 +517,8 @@ func TestUserHandlerBindEmailIdentityReturnsProfileResponse(t *testing.T) { }, } emailService := service.NewEmailService(nil, emailCache) - authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil, nil) - handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil) + authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil, nil, nil) + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil, nil) body := []byte(`{"email":"new@example.com","verify_code":"123456","password":"new-password"}`) recorder := httptest.NewRecorder() @@ -568,7 +572,7 @@ func TestUserHandlerUnbindIdentityReturnsUpdatedProfile(t *testing.T) { }, }, } - handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil) + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil, nil) recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) @@ -627,8 +631,8 @@ func TestUserHandlerUnbindIdentityRevokesAllUserSessionsWhenAuthServiceConfigure ExpireHour: 1, }, } - authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil) - handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil) + authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil, nil) + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil, nil) recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) @@ -670,8 +674,8 @@ func TestUserHandlerUnbindIdentityDoesNotRevokeSessionsWhenNothingWasUnbound(t * ExpireHour: 1, }, } - authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil) - handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil) + authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil, nil) + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil, nil) recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) @@ -714,8 +718,8 @@ func TestUserHandlerBindEmailIdentityRejectsWrongCurrentPasswordForBoundEmail(t }, } emailService := service.NewEmailService(nil, emailCache) - authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil, nil) - handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil) + authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil, nil, nil) + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil, nil) body := []byte(`{"email":"new@example.com","verify_code":"123456","password":"wrong-password"}`) recorder := httptest.NewRecorder() @@ -752,7 +756,7 @@ func TestUserHandlerStartIdentityBindingReturnsAuthorizeURL(t *testing.T) { Status: service.StatusActive, }, } - handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil) + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil, nil) body := []byte(`{"provider":"wechat","redirect_to":"/settings/profile"}`) recorder := httptest.NewRecorder() diff --git a/backend/internal/handler/user_platform_quotas_handler_test.go b/backend/internal/handler/user_platform_quotas_handler_test.go new file mode 100644 index 00000000..96d13fd3 --- /dev/null +++ b/backend/internal/handler/user_platform_quotas_handler_test.go @@ -0,0 +1,212 @@ +//go:build unit + +package handler + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/handler/quotaview" + "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" +) + +// fakeQuotaRepoForUserHandler 实现 service.UserPlatformQuotaRepository 最小子集 +type fakeQuotaRepoForUserHandler struct { + service.UserPlatformQuotaRepository + records []service.UserPlatformQuotaRecord +} + +func (f *fakeQuotaRepoForUserHandler) ListByUser(_ context.Context, _ int64) ([]service.UserPlatformQuotaRecord, error) { + return f.records, nil +} + +func TestGetMyPlatformQuotas_EmptyReturns200WithEmptyArray(t *testing.T) { + repo := &fakeQuotaRepoForUserHandler{records: nil} + h := &UserHandler{userPlatformQuotaRepo: repo} + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/user/platform-quotas", nil) + c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 42}) + h.GetMyPlatformQuotas(c) + if w.Code != 200 { + t.Fatalf("expected 200, got %d. body: %s", w.Code, w.Body.String()) + } + var body struct { + Code int `json:"code"` + Data struct { + PlatformQuotas []any `json:"platform_quotas"` + } `json:"data"` + } + if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil { + t.Fatalf("unmarshal error: %v, body: %s", err, w.Body.String()) + } + if body.Code != 0 { + t.Errorf("expected code=0, got %d", body.Code) + } + if body.Data.PlatformQuotas == nil { + // nil 和 empty slice 均视为可接受(JSON 可能序列化为 null 或 []) + // 此断言只验证 HTTP 200 + code=0 即可 + } +} + +func TestGetMyPlatformQuotas_D14_LazyZeroForExpiredWindow(t *testing.T) { + pastStart := time.Now().UTC().AddDate(0, 0, -2) + daily := 5.0 + repo := &fakeQuotaRepoForUserHandler{records: []service.UserPlatformQuotaRecord{{ + UserID: 42, + Platform: "anthropic", + DailyLimitUSD: &daily, + DailyUsageUSD: 3.0, + DailyWindowStart: &pastStart, + }}} + h := &UserHandler{userPlatformQuotaRepo: repo} + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/user/platform-quotas", nil) + c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 42}) + h.GetMyPlatformQuotas(c) + + if w.Code != 200 { + t.Fatalf("expected 200, got %d. body: %s", w.Code, w.Body.String()) + } + + // 解析 response,验证过期 daily 的 usage_usd=0 且 window_resets_at=null + body := w.Body.String() + if !strings.Contains(body, `"daily_usage_usd":0`) { + t.Errorf("expected daily_usage_usd:0 in body, got: %s", body) + } + if !strings.Contains(body, `"daily_window_resets_at":null`) { + t.Errorf("expected daily_window_resets_at:null in body, got: %s", body) + } +} + +func TestGetMyPlatformQuotas_NilRepo_Returns200Empty(t *testing.T) { + h := &UserHandler{userPlatformQuotaRepo: nil} + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/user/platform-quotas", nil) + c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 99}) + h.GetMyPlatformQuotas(c) + if w.Code != 200 { + t.Fatalf("expected 200, got %d", w.Code) + } +} + +func TestGetMyPlatformQuotas_NoAuth_Returns401(t *testing.T) { + h := &UserHandler{userPlatformQuotaRepo: nil} + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/user/platform-quotas", nil) + // 不设置 auth subject + h.GetMyPlatformQuotas(c) + if w.Code != 401 { + t.Fatalf("expected 401, got %d", w.Code) + } +} + +func TestLazyZeroQuotaForResponse_UserViewStripsWindowStart(t *testing.T) { + start := time.Now().UTC().Add(-1 * time.Hour) + r := service.UserPlatformQuotaRecord{ + Platform: "anthropic", + DailyUsageUSD: 1.0, + DailyWindowStart: &start, + } + out := quotaview.LazyZeroQuotaForResponse(r, time.Now().UTC(), false) + if _, ok := out["daily_window_start"]; ok { + t.Error("user view should not include daily_window_start") + } +} + +func TestLazyZeroQuotaForResponse_AdminViewIncludesWindowStart(t *testing.T) { + start := time.Now().UTC().Add(-1 * time.Hour) + r := service.UserPlatformQuotaRecord{ + Platform: "anthropic", + DailyWindowStart: &start, + } + out := quotaview.LazyZeroQuotaForResponse(r, time.Now().UTC(), true) + if _, ok := out["daily_window_start"]; !ok { + t.Error("admin view should include daily_window_start") + } +} + +func TestLazyZeroQuotaForResponse_ActiveWindowPreservesUsage(t *testing.T) { + // 今天的窗口起始时间(不过期):按全局时区取当天 0 点,与 view 层同口径 + now := time.Now() + today := timezone.StartOfDay(now) + usage := 2.5 + r := service.UserPlatformQuotaRecord{ + Platform: "openai", + DailyUsageUSD: usage, + DailyWindowStart: &today, + } + out := quotaview.LazyZeroQuotaForResponse(r, now, false) + if out["daily_usage_usd"] != usage { + t.Errorf("expected daily_usage_usd=%v, got %v", usage, out["daily_usage_usd"]) + } + // 活跃窗口应有 resets_at(非 nil) + if out["daily_window_resets_at"] == nil { + t.Error("active window should have daily_window_resets_at set") + } +} + +func TestNeedsDailyReset_NilStart_ReturnsFalse(t *testing.T) { + if quotaview.NeedsDailyReset(nil, time.Now().UTC()) { + t.Error("nil start should not need reset") + } +} + +func TestNeedsDailyReset_OldStart_ReturnsTrue(t *testing.T) { + old := time.Now().UTC().AddDate(0, 0, -1) + if !quotaview.NeedsDailyReset(&old, time.Now().UTC()) { + t.Error("yesterday start should need daily reset") + } +} + +func TestNeedsWeeklyReset_NilStart_ReturnsFalse(t *testing.T) { + if quotaview.NeedsWeeklyReset(nil, time.Now().UTC()) { + t.Error("nil start should not need weekly reset") + } +} + +func TestNeedsMonthlyReset_NilStart_ReturnsFalse(t *testing.T) { + if quotaview.NeedsMonthlyReset(nil, time.Now().UTC()) { + t.Error("nil start should not need monthly reset") + } +} + +// TestNeedsMonthlyReset_30DayRolling 验证 30 天滚动语义(C-NEW-1)。 +func TestNeedsMonthlyReset_30DayRolling_Expired(t *testing.T) { + start := time.Now().UTC().Add(-31 * 24 * time.Hour) // 31 天前,已过期 + if !quotaview.NeedsMonthlyReset(&start, time.Now().UTC()) { + t.Error("31 days ago should need monthly reset (30-day rolling)") + } +} + +func TestNeedsMonthlyReset_30DayRolling_Active(t *testing.T) { + start := time.Now().UTC().Add(-15 * 24 * time.Hour) // 15 天前,窗口有效 + if quotaview.NeedsMonthlyReset(&start, time.Now().UTC()) { + t.Error("15 days ago should NOT need monthly reset (30-day rolling, still active)") + } +} + +// TestNeedsMonthlyReset_CrossMonthBoundary 验证跨自然月时 30 天未满不重置(旧自然月语义会提前重置)。 +func TestNeedsMonthlyReset_CrossMonthBoundary(t *testing.T) { + // 窗口起始 4 月 20 日;5 月 1 日仅过了 11 天,不足 30 天,不应重置 + windowStart := time.Date(2026, 4, 20, 0, 0, 0, 0, time.UTC) + now := time.Date(2026, 5, 1, 0, 0, 0, 0, time.UTC) + if quotaview.NeedsMonthlyReset(&windowStart, now) { + t.Error("cross-month boundary within 30 days should NOT trigger reset (30-day rolling)") + } +} diff --git a/backend/internal/pkg/timezone/timezone_test.go b/backend/internal/pkg/timezone/timezone_test.go index ac9cdde6..610b1bb9 100644 --- a/backend/internal/pkg/timezone/timezone_test.go +++ b/backend/internal/pkg/timezone/timezone_test.go @@ -135,3 +135,29 @@ func TestDSTAwareness(t *testing.T) { _ = Now() _ = StartOfDay(Now()) } + +func TestStartOfWeek_Boundaries(t *testing.T) { + if err := Init("Asia/Shanghai"); err != nil { + t.Fatalf("Init: %v", err) + } + t.Cleanup(func() { _ = Init("UTC") }) + + loc := Location() + wantMon := time.Date(2026, 5, 18, 0, 0, 0, 0, loc) // 2026-05-18 是周一 + + cases := []struct { + name string + in time.Time + }{ + {"friday", time.Date(2026, 5, 22, 14, 30, 0, 0, loc)}, + {"sunday", time.Date(2026, 5, 24, 10, 0, 0, 0, loc)}, + {"monday-self", time.Date(2026, 5, 18, 9, 15, 30, 0, loc)}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + if got := StartOfWeek(c.in); !got.Equal(wantMon) { + t.Errorf("StartOfWeek(%v) = %v, want %v", c.in, got, wantMon) + } + }) + } +} diff --git a/backend/internal/repository/billing_cache.go b/backend/internal/repository/billing_cache.go index 6922b4c8..60dae954 100644 --- a/backend/internal/repository/billing_cache.go +++ b/backend/internal/repository/billing_cache.go @@ -328,3 +328,174 @@ func (c *billingCache) InvalidateAPIKeyRateLimit(ctx context.Context, keyID int6 key := billingRateLimitKey(keyID) return c.rdb.Del(ctx, key).Err() } + +// ============================================ +// user × platform quota 缓存 +// ============================================ + +// userPlatformQuotaCacheKey 构造 Redis key +func userPlatformQuotaCacheKey(userID int64, platform string) string { + return fmt.Sprintf("billing:user_platform_quota:%d:%s", userID, platform) +} + +func (c *billingCache) GetUserPlatformQuotaCache(ctx context.Context, userID int64, platform string) (*service.UserPlatformQuotaCacheEntry, bool, error) { + key := userPlatformQuotaCacheKey(userID, platform) + fields := []string{ + "daily_usage", "weekly_usage", "monthly_usage", "version", "schema_version", + "daily_limit", "weekly_limit", "monthly_limit", + "daily_window_start", "weekly_window_start", "monthly_window_start", + } + vals, err := c.rdb.HMGet(ctx, key, fields...).Result() + if err != nil { + return nil, false, err + } + // 前4个全为nil → key 不存在 + if vals[0] == nil && vals[1] == nil && vals[2] == nil && vals[3] == nil { + return nil, false, nil + } + parseFloat := func(v any) float64 { + if v == nil { + return 0 + } + s, ok := v.(string) + if !ok { + return 0 + } + f, _ := strconv.ParseFloat(s, 64) + return f + } + parseFloatPtr := func(v any) *float64 { + if v == nil { + return nil + } + s, ok := v.(string) + if !ok || s == "" { + return nil + } + f, err := strconv.ParseFloat(s, 64) + if err != nil { + return nil + } + return &f + } + parseTimePtr := func(v any) *time.Time { + if v == nil { + return nil + } + s, ok := v.(string) + if !ok || s == "" { + return nil + } + n, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return nil + } + t := time.Unix(n, 0).UTC() + return &t + } + parseInt64 := func(v any) int64 { + if v == nil { + return 0 + } + s, ok := v.(string) + if !ok { + return 0 + } + n, _ := strconv.ParseInt(s, 10, 64) + return n + } + return &service.UserPlatformQuotaCacheEntry{ + DailyUsageUSD: parseFloat(vals[0]), + WeeklyUsageUSD: parseFloat(vals[1]), + MonthlyUsageUSD: parseFloat(vals[2]), + Version: parseInt64(vals[3]), + SchemaVersion: parseInt64(vals[4]), + DailyLimitUSD: parseFloatPtr(vals[5]), + WeeklyLimitUSD: parseFloatPtr(vals[6]), + MonthlyLimitUSD: parseFloatPtr(vals[7]), + DailyWindowStart: parseTimePtr(vals[8]), + WeeklyWindowStart: parseTimePtr(vals[9]), + MonthlyWindowStart: parseTimePtr(vals[10]), + }, true, nil +} + +func (c *billingCache) SetUserPlatformQuotaCache(ctx context.Context, userID int64, platform string, entry *service.UserPlatformQuotaCacheEntry, ttl time.Duration) error { + if entry == nil { + return nil + } + key := userPlatformQuotaCacheKey(userID, platform) + pipe := c.rdb.TxPipeline() + + // 浮点可空字段:nil → 空字符串(读取时 parseFloatPtr 返回 nil,表示无限额) + fmtFloatPtr := func(p *float64) string { + if p == nil { + return "" + } + return strconv.FormatFloat(*p, 'f', -1, 64) + } + // time.Time 可空字段:nil → 空字符串;有值 → unix 秒 + fmtTimePtr := func(p *time.Time) string { + if p == nil { + return "" + } + return strconv.FormatInt(p.Unix(), 10) + } + + pipe.HSet(ctx, key, + "daily_usage", entry.DailyUsageUSD, + "weekly_usage", entry.WeeklyUsageUSD, + "monthly_usage", entry.MonthlyUsageUSD, + "version", entry.Version, + "schema_version", entry.SchemaVersion, + "daily_limit", fmtFloatPtr(entry.DailyLimitUSD), + "weekly_limit", fmtFloatPtr(entry.WeeklyLimitUSD), + "monthly_limit", fmtFloatPtr(entry.MonthlyLimitUSD), + "daily_window_start", fmtTimePtr(entry.DailyWindowStart), + "weekly_window_start", fmtTimePtr(entry.WeeklyWindowStart), + "monthly_window_start", fmtTimePtr(entry.MonthlyWindowStart), + ) + pipe.Expire(ctx, key, ttl) + _, err := pipe.Exec(ctx) + return err +} + +func (c *billingCache) DeleteUserPlatformQuotaCache(ctx context.Context, userID int64, platform string) error { + return c.rdb.Del(ctx, userPlatformQuotaCacheKey(userID, platform)).Err() +} + +// updateUserPlatformQuotaUsageScript 缓存累加:EXISTS + schema_version 双重守卫。 +// 旧版 entry(schema_version != ARGV[3],包括缺字段的 0 值)不参与累加,由上层走 DB fallback 后 +// SetCache 重建为新版 entry —— 若此处仍累加,上层覆盖时会丢失这部分增量,导致 Redis usage 比真实偏小。 +// key 不存在同样跳过(由下次 SetCache 重建)。 +// KEYS[1] = hash key +// ARGV[1] = cost (string float) +// ARGV[2] = ttl seconds +// ARGV[3] = expected schema_version (Go 侧 UserPlatformQuotaCacheSchemaV1) +const updateUserPlatformQuotaUsageScript = ` +if redis.call("EXISTS", KEYS[1]) == 0 then + return 0 +end +local ver = redis.call("HGET", KEYS[1], "schema_version") +if ver == false or tonumber(ver) ~= tonumber(ARGV[3]) then + return 0 +end +redis.call("HINCRBYFLOAT", KEYS[1], "daily_usage", ARGV[1]) +redis.call("HINCRBYFLOAT", KEYS[1], "weekly_usage", ARGV[1]) +redis.call("HINCRBYFLOAT", KEYS[1], "monthly_usage", ARGV[1]) +redis.call("HINCRBY", KEYS[1], "version", 1) +redis.call("EXPIRE", KEYS[1], ARGV[2]) +return 1 +` + +func (c *billingCache) IncrUserPlatformQuotaUsageCache(ctx context.Context, userID int64, platform string, cost float64, ttl time.Duration) error { + key := userPlatformQuotaCacheKey(userID, platform) + _, err := c.rdb.Eval(ctx, updateUserPlatformQuotaUsageScript, []string{key}, + strconv.FormatFloat(cost, 'f', -1, 64), + int(ttl.Seconds()), + service.UserPlatformQuotaCacheSchemaV1, + ).Result() + if err != nil && !errors.Is(err, redis.Nil) { + return err + } + return nil +} diff --git a/backend/internal/repository/billing_cache_user_platform_quota_test.go b/backend/internal/repository/billing_cache_user_platform_quota_test.go new file mode 100644 index 00000000..8d49fd31 --- /dev/null +++ b/backend/internal/repository/billing_cache_user_platform_quota_test.go @@ -0,0 +1,134 @@ +//go:build unit + +package repository + +import ( + "context" + "testing" + "time" + + "github.com/alicebob/miniredis/v2" + "github.com/redis/go-redis/v9" + + "github.com/Wei-Shaw/sub2api/internal/service" +) + +func newMiniRedisCache(t *testing.T) (*billingCache, *miniredis.Miniredis) { + t.Helper() + mr := miniredis.RunT(t) + rdb := redis.NewClient(&redis.Options{Addr: mr.Addr()}) + return &billingCache{rdb: rdb}, mr +} + +func TestUserPlatformQuotaCache_GetMissReturnsNotFound(t *testing.T) { + c, _ := newMiniRedisCache(t) + entry, ok, err := c.GetUserPlatformQuotaCache(context.Background(), 1, "anthropic") + if err != nil { + t.Fatal(err) + } + if ok || entry != nil { + t.Errorf("expected miss, got ok=%v entry=%v", ok, entry) + } +} + +func TestUserPlatformQuotaCache_SetThenGet(t *testing.T) { + c, _ := newMiniRedisCache(t) + ctx := context.Background() + dailyLimit := 20.0 + ts := time.Date(2024, 5, 1, 0, 0, 0, 0, time.UTC) + in := &service.UserPlatformQuotaCacheEntry{ + DailyUsageUSD: 1.5, + WeeklyUsageUSD: 3.0, + MonthlyUsageUSD: 10.0, + Version: 7, + SchemaVersion: service.UserPlatformQuotaCacheSchemaV1, + DailyLimitUSD: &dailyLimit, + DailyWindowStart: &ts, + } + if err := c.SetUserPlatformQuotaCache(ctx, 1, "openai", in, time.Minute); err != nil { + t.Fatal(err) + } + got, ok, err := c.GetUserPlatformQuotaCache(ctx, 1, "openai") + if err != nil || !ok { + t.Fatalf("get: ok=%v err=%v", ok, err) + } + if got.DailyUsageUSD != 1.5 || got.WeeklyUsageUSD != 3.0 || got.MonthlyUsageUSD != 10.0 || got.Version != 7 { + t.Errorf("got = %+v, want %+v", got, in) + } + if got.SchemaVersion != service.UserPlatformQuotaCacheSchemaV1 { + t.Errorf("SchemaVersion = %d, want %d", got.SchemaVersion, service.UserPlatformQuotaCacheSchemaV1) + } + if got.DailyLimitUSD == nil || *got.DailyLimitUSD != dailyLimit { + t.Errorf("DailyLimitUSD = %v, want %v", got.DailyLimitUSD, dailyLimit) + } + if got.DailyWindowStart == nil || !got.DailyWindowStart.Equal(ts) { + t.Errorf("DailyWindowStart = %v, want %v", got.DailyWindowStart, ts) + } +} + +func TestUserPlatformQuotaCache_NilLimitSetThenGet(t *testing.T) { + c, _ := newMiniRedisCache(t) + ctx := context.Background() + in := &service.UserPlatformQuotaCacheEntry{ + DailyUsageUSD: 1.0, + SchemaVersion: service.UserPlatformQuotaCacheSchemaV1, + // DailyLimitUSD nil → 无限额 + } + if err := c.SetUserPlatformQuotaCache(ctx, 1, "openai", in, time.Minute); err != nil { + t.Fatal(err) + } + got, ok, err := c.GetUserPlatformQuotaCache(ctx, 1, "openai") + if err != nil || !ok { + t.Fatalf("get: ok=%v err=%v", ok, err) + } + if got.DailyLimitUSD != nil { + t.Errorf("DailyLimitUSD should be nil for unlimited, got %v", got.DailyLimitUSD) + } +} + +func TestUserPlatformQuotaCache_IncrMissIsNoop(t *testing.T) { + c, _ := newMiniRedisCache(t) + if err := c.IncrUserPlatformQuotaUsageCache(context.Background(), 1, "openai", 0.5, time.Minute); err != nil { + t.Fatal(err) + } + _, ok, _ := c.GetUserPlatformQuotaCache(context.Background(), 1, "openai") + if ok { + t.Error("expected key absent after no-op incr") + } +} + +func TestUserPlatformQuotaCache_IncrHitAccumulates(t *testing.T) { + c, _ := newMiniRedisCache(t) + ctx := context.Background() + // SchemaVersion 必须显式设为 V1,否则 Lua 脚本会因 schema 不匹配而 return 0,跳过累加。 + _ = c.SetUserPlatformQuotaCache(ctx, 1, "openai", &service.UserPlatformQuotaCacheEntry{ + Version: 1, + SchemaVersion: service.UserPlatformQuotaCacheSchemaV1, + }, time.Minute) + if err := c.IncrUserPlatformQuotaUsageCache(ctx, 1, "openai", 0.5, time.Minute); err != nil { + t.Fatal(err) + } + if err := c.IncrUserPlatformQuotaUsageCache(ctx, 1, "openai", 0.25, time.Minute); err != nil { + t.Fatal(err) + } + got, _, _ := c.GetUserPlatformQuotaCache(ctx, 1, "openai") + if got.DailyUsageUSD != 0.75 || got.WeeklyUsageUSD != 0.75 || got.MonthlyUsageUSD != 0.75 { + t.Errorf("got %+v, want daily/weekly/monthly=0.75", got) + } + if got.Version != 3 { + t.Errorf("version = %d, want 3 (initial 1 + 2 incr)", got.Version) + } +} + +func TestUserPlatformQuotaCache_Delete(t *testing.T) { + c, _ := newMiniRedisCache(t) + ctx := context.Background() + _ = c.SetUserPlatformQuotaCache(ctx, 1, "openai", &service.UserPlatformQuotaCacheEntry{Version: 1}, time.Minute) + if err := c.DeleteUserPlatformQuotaCache(ctx, 1, "openai"); err != nil { + t.Fatal(err) + } + _, ok, _ := c.GetUserPlatformQuotaCache(ctx, 1, "openai") + if ok { + t.Error("expected miss after delete") + } +} diff --git a/backend/internal/repository/user_platform_quota_adapter_test.go b/backend/internal/repository/user_platform_quota_adapter_test.go new file mode 100644 index 00000000..a55d2e9c --- /dev/null +++ b/backend/internal/repository/user_platform_quota_adapter_test.go @@ -0,0 +1,91 @@ +//go:build unit + +package repository + +import ( + "context" + "errors" + "testing" + "time" +) + +type fakeRepoForAdapter struct { + upsertCalledWith []UserPlatformQuotaRecord + upsertCalledUserID int64 + upsertErr error + resetCalledWith [4]any // userID, platform, window, newStart + resetErr error +} + +func (f *fakeRepoForAdapter) BulkInsertInitial(_ context.Context, _ []UserPlatformQuotaRecord) error { + return nil +} +func (f *fakeRepoForAdapter) GetByUserPlatform(_ context.Context, _ int64, _ string) (*UserPlatformQuotaRecord, error) { + return nil, nil +} +func (f *fakeRepoForAdapter) ListByUser(_ context.Context, _ int64) ([]UserPlatformQuotaRecord, error) { + return nil, nil +} +func (f *fakeRepoForAdapter) IncrementUsageWithReset(_ context.Context, _ int64, _ string, _ float64, _ time.Time) error { + return nil +} +func (f *fakeRepoForAdapter) ResetExpiredWindow(_ context.Context, userID int64, platform string, window string, newStart time.Time) error { + f.resetCalledWith = [4]any{userID, platform, window, newStart} + return f.resetErr +} +func (f *fakeRepoForAdapter) UpsertForUser(_ context.Context, userID int64, records []UserPlatformQuotaRecord) error { + f.upsertCalledUserID = userID + f.upsertCalledWith = records + return f.upsertErr +} + +func TestGenericAdapter_UpsertForUser_ForwardsRecords(t *testing.T) { + fake := &fakeRepoForAdapter{} + adapter := NewUserPlatformQuotaServiceAdapter(fake) + + err := adapter.UpsertForUser(context.Background(), 42, nil) + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + if fake.upsertCalledUserID != 42 { + t.Errorf("forwarded userID = %d, want 42", fake.upsertCalledUserID) + } +} + +func TestGenericAdapter_UpsertForUser_PropagatesError(t *testing.T) { + wantErr := errors.New("boom") + fake := &fakeRepoForAdapter{upsertErr: wantErr} + adapter := NewUserPlatformQuotaServiceAdapter(fake) + + err := adapter.UpsertForUser(context.Background(), 1, nil) + if !errors.Is(err, wantErr) { + t.Errorf("expected %v, got %v", wantErr, err) + } +} + +func TestGenericAdapter_ResetExpiredWindow_ForwardsAllParams(t *testing.T) { + fake := &fakeRepoForAdapter{} + adapter := NewUserPlatformQuotaServiceAdapter(fake) + + now := time.Date(2026, 5, 23, 10, 0, 0, 0, time.UTC) + if err := adapter.ResetExpiredWindow(context.Background(), 7, "openai", "weekly", now); err != nil { + t.Fatalf("unexpected err: %v", err) + } + if fake.resetCalledWith[0].(int64) != 7 || + fake.resetCalledWith[1].(string) != "openai" || + fake.resetCalledWith[2].(string) != "weekly" || + !fake.resetCalledWith[3].(time.Time).Equal(now) { + t.Errorf("forwarded params mismatch: %+v", fake.resetCalledWith) + } +} + +func TestGenericAdapter_ResetExpiredWindow_PropagatesError(t *testing.T) { + wantErr := errors.New("not found") + fake := &fakeRepoForAdapter{resetErr: wantErr} + adapter := NewUserPlatformQuotaServiceAdapter(fake) + + err := adapter.ResetExpiredWindow(context.Background(), 1, "a", "daily", time.Now()) + if !errors.Is(err, wantErr) { + t.Errorf("expected %v, got %v", wantErr, err) + } +} diff --git a/backend/internal/repository/user_platform_quota_repo.go b/backend/internal/repository/user_platform_quota_repo.go new file mode 100644 index 00000000..1e2e7f51 --- /dev/null +++ b/backend/internal/repository/user_platform_quota_repo.go @@ -0,0 +1,416 @@ +package repository + +import ( + "context" + "fmt" + "strings" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/userplatformquota" + "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" +) + +// UserPlatformQuotaRecord 是 repository 层的传输结构体, +// 与 ent.UserPlatformQuota 实体解耦,供业务层使用。 +type UserPlatformQuotaRecord struct { + UserID int64 + Platform string + DailyLimitUSD *float64 + WeeklyLimitUSD *float64 + MonthlyLimitUSD *float64 + DailyUsageUSD float64 + WeeklyUsageUSD float64 + MonthlyUsageUSD float64 + DailyWindowStart *time.Time + WeeklyWindowStart *time.Time + MonthlyWindowStart *time.Time +} + +// ErrUserPlatformQuotaNotFound 用于 ResetExpiredWindow 等需要"必须命中已有记录"的方法。 +var ErrUserPlatformQuotaNotFound = fmt.Errorf("user platform quota record not found") + +// UserPlatformQuotaRepository 定义用户平台配额的数据访问接口。 +type UserPlatformQuotaRepository interface { + // BulkInsertInitial 幂等批量插入初始配额记录(ON CONFLICT DO NOTHING)。 + BulkInsertInitial(ctx context.Context, records []UserPlatformQuotaRecord) error + // GetByUserPlatform 查询单条配额记录,未找到时返回 (nil, nil)。 + GetByUserPlatform(ctx context.Context, userID int64, platform string) (*UserPlatformQuotaRecord, error) + // ListByUser 查询用户的所有平台配额记录(排除软删除)。 + ListByUser(ctx context.Context, userID int64) ([]UserPlatformQuotaRecord, error) + // IncrementUsageWithReset 原子地累加用量,若窗口已过期则先重置再累加。 + IncrementUsageWithReset(ctx context.Context, userID int64, platform string, cost float64, now time.Time) error + // ResetExpiredWindow 重置指定窗口(daily/weekly/monthly)的用量与起始时间。 + ResetExpiredWindow(ctx context.Context, userID int64, platform string, window string, newStart time.Time) error + // UpsertForUser 全量替换该用户所有平台限额配置(详见 service.UserPlatformQuotaRepository.UpsertForUser)。 + UpsertForUser(ctx context.Context, userID int64, records []UserPlatformQuotaRecord) error +} + +type userPlatformQuotaRepository struct { + client *dbent.Client +} + +// NewUserPlatformQuotaRepository 创建 UserPlatformQuotaRepository 实现。 +func NewUserPlatformQuotaRepository(client *dbent.Client) UserPlatformQuotaRepository { + return &userPlatformQuotaRepository{client: client} +} + +// BulkInsertInitial 用原生 SQL ON CONFLICT 实现幂等批量插入(带条件 limit 覆盖)。 +// 仅插入 limit_usd 与元数据,usage_usd 用 DB 默认 0,window_start 留 NULL。 +// FK 约束要求 user_id 在 users 表中存在,调用方负责保证。 +// +// 冲突策略:CASE WHEN existing.*_limit_usd IS NULL THEN EXCLUDED.*_limit_usd ELSE existing ... +// - 若 IncrementUsageWithReset 因时序问题已先建行(limit 全 NULL), +// 此处会把注册时的默认 limit 写入,避免该用户在该平台永久无限额。 +// - 若管理员已通过 UpsertForUser 设置了非 NULL 个性化 limit,**保留不动** +// —— 旧实现无条件 EXCLUDED 覆盖会丢失个性化配置。 +// - 不会改 usage_usd / window_start,保留累计的用量。 +// - 仅命中 deleted_at IS NULL 的活跃记录(partial unique index 作用域)。 +func (r *userPlatformQuotaRepository) BulkInsertInitial(ctx context.Context, records []UserPlatformQuotaRecord) error { + if len(records) == 0 { + return nil + } + + client := clientFromContext(ctx, r.client) + + var sb strings.Builder + _, _ = sb.WriteString("INSERT INTO user_platform_quotas (user_id, platform, daily_limit_usd, weekly_limit_usd, monthly_limit_usd, daily_usage_usd, weekly_usage_usd, monthly_usage_usd, created_at, updated_at) VALUES ") + args := make([]any, 0, len(records)*6) + // 统一时间戳:避免循环内多次 time.Now() 让同一批记录的 created_at/updated_at + // 出现亚毫秒级偏差(与 UpsertForUser 的 now := time.Now() 风格一致)。 + now := time.Now() + for i, rec := range records { + base := i * 6 + if i > 0 { + _, _ = sb.WriteString(",") + } + fmt.Fprintf(&sb, "($%d,$%d,$%d,$%d,$%d,0,0,0,$%d,$%d)", + base+1, base+2, base+3, base+4, base+5, base+6, base+6) + args = append(args, + rec.UserID, rec.Platform, + rec.DailyLimitUSD, rec.WeeklyLimitUSD, rec.MonthlyLimitUSD, + now, + ) + } + // 精确命中 partial unique index(deleted_at IS NULL),避免对软删记录的歧义冲突。 + // 条件覆盖:仅在现有 limit 为 NULL 时才写入 EXCLUDED,否则保留现有非 NULL 值。 + // - 修复 IncrementUsageWithReset 已用 NULL limit 建行的场景(NULL → 注册默认) + // - 保护管理员通过 UpsertForUser 设置的个性化 limit 不被静默覆盖 + _, _ = sb.WriteString(` ON CONFLICT (user_id, platform) WHERE deleted_at IS NULL + DO UPDATE SET + daily_limit_usd = COALESCE(user_platform_quotas.daily_limit_usd, EXCLUDED.daily_limit_usd), + weekly_limit_usd = COALESCE(user_platform_quotas.weekly_limit_usd, EXCLUDED.weekly_limit_usd), + monthly_limit_usd = COALESCE(user_platform_quotas.monthly_limit_usd, EXCLUDED.monthly_limit_usd), + updated_at = EXCLUDED.updated_at`) + + _, err := client.ExecContext(ctx, sb.String(), args...) + return err +} + +// GetByUserPlatform 通过 ent 查询单条配额(排除软删除)。未找到返回 (nil, nil)。 +func (r *userPlatformQuotaRepository) GetByUserPlatform(ctx context.Context, userID int64, platform string) (*UserPlatformQuotaRecord, error) { + client := clientFromContext(ctx, r.client) + entity, err := client.UserPlatformQuota.Query(). + Where( + userplatformquota.UserIDEQ(userID), + userplatformquota.PlatformEQ(platform), + userplatformquota.DeletedAtIsNil(), + ). + Only(ctx) + if dbent.IsNotFound(err) { + return nil, nil + } + if err != nil { + return nil, err + } + return entQuotaToRecord(entity), nil +} + +// ListByUser 查询用户的所有平台配额记录(排除软删除)。 +func (r *userPlatformQuotaRepository) ListByUser(ctx context.Context, userID int64) ([]UserPlatformQuotaRecord, error) { + client := clientFromContext(ctx, r.client) + rows, err := client.UserPlatformQuota.Query(). + Where( + userplatformquota.UserIDEQ(userID), + userplatformquota.DeletedAtIsNil(), + ). + All(ctx) + if err != nil { + return nil, err + } + out := make([]UserPlatformQuotaRecord, 0, len(rows)) + for _, e := range rows { + out = append(out, *entQuotaToRecord(e)) + } + return out, nil +} + +// IncrementUsageWithReset 原子累加 cost 到 (user, platform) 三个窗口的 *_usage_usd。 +// 行为: +// - 若记录存在:在事务内 SELECT FOR UPDATE,按 (prev_window_start vs current_window_start) +// 判断是否需要重置(不同 = 重置为 cost;相同 = 累加 cost) +// - 若记录不存在(fail-open create 分支):插入新记录,**limit 字段保留 nil(无限制)** +// —— 这是预期行为:billing 链路不能因 quota 表缺失而阻断请求,未注册路径 +// 的用户 quota 默认放行,由调度层指标观测 + 后台对账补建 limit +// +// 上层正常路径(注册时 BulkInsertInitial)保证 limit 在记录创建时就被写入。 +func (r *userPlatformQuotaRepository) IncrementUsageWithReset(ctx context.Context, userID int64, platform string, cost float64, now time.Time) error { + return r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error { + existing, err := txClient.UserPlatformQuota.Query(). + Where( + userplatformquota.UserIDEQ(userID), + userplatformquota.PlatformEQ(platform), + userplatformquota.DeletedAtIsNil(), + ). + ForUpdate(). + Only(txCtx) + if dbent.IsNotFound(err) { + // fail-open 建行:limit_* 保留 NULL(无限额)。 + // 用 ON CONFLICT DO UPDATE 累加,而非裸 INSERT:并发下另一请求可能在本事务 + // SELECT FOR UPDATE 之后、INSERT 之前刚建行,裸 INSERT 会撞 partial unique index + // 致事务回滚、本次 cost 丢失;DO UPDATE 把 cost 累加到既有 usage 上。 + // 写法与本文件 insertLimitsRow / BulkInsertInitial 的 ON CONFLICT 一致。 + const insertSQL = `INSERT INTO user_platform_quotas + (user_id, platform, daily_usage_usd, weekly_usage_usd, monthly_usage_usd, + daily_window_start, weekly_window_start, monthly_window_start, created_at, updated_at) + VALUES ($1, $2, $3, $3, $3, $4, $5, $6, $7, $7) + ON CONFLICT (user_id, platform) WHERE deleted_at IS NULL DO UPDATE SET + daily_usage_usd = user_platform_quotas.daily_usage_usd + EXCLUDED.daily_usage_usd, + weekly_usage_usd = user_platform_quotas.weekly_usage_usd + EXCLUDED.weekly_usage_usd, + monthly_usage_usd = user_platform_quotas.monthly_usage_usd + EXCLUDED.monthly_usage_usd, + updated_at = EXCLUDED.updated_at` + // $6 = now:30 天滚动月度窗口以当前时刻为起始 + _, e := txClient.ExecContext(txCtx, insertSQL, + userID, platform, cost, + timezone.StartOfDay(now), timezone.StartOfWeek(now), now, now) + return e + } + if err != nil { + return err + } + + newDaily := maybeReset(existing.DailyUsageUsd, existing.DailyWindowStart, timezone.StartOfDay(now), cost) + newWeekly := maybeReset(existing.WeeklyUsageUsd, existing.WeeklyWindowStart, timezone.StartOfWeek(now), cost) + // 30 天滚动月度窗口:过期时重置为 cost 并以 now 为新起始,否则累加保留原起始 + newMonthly, newMonthlyStart := monthlyMaybeReset(existing.MonthlyUsageUsd, existing.MonthlyWindowStart, cost, now) + + _, e := existing.Update(). + SetDailyUsageUsd(newDaily). + SetWeeklyUsageUsd(newWeekly). + SetMonthlyUsageUsd(newMonthly). + SetDailyWindowStart(timezone.StartOfDay(now)). + SetWeeklyWindowStart(timezone.StartOfWeek(now)). + SetMonthlyWindowStart(newMonthlyStart). // 30 天滚动:仅过期时更新起始 + Save(txCtx) + return e + }) +} + +// ResetExpiredWindow 无条件重置指定窗口(daily/weekly/monthly)的用量与起始时间。 +// +// ⚠️ 命名警告(NOT a "check-then-reset" helper): +// +// 名字里的 "Expired" 是历史遗留,**实现并不校验窗口是否真的过期**。 +// 任何调用都会无条件把对应窗口的 *_usage_usd 清零并重写 *_window_start。 +// 目前唯一合法 caller 是 admin POST /reset 接口(管理员强制归零)。 +// +// 如果你想要"仅在窗口过期才重置"的语义,请直接使用 IncrementUsageWithReset +// 的内部判断(maybeReset / monthlyMaybeReset),或新增独立函数; +// 不要复用这里的函数,否则会出现"明明窗口未过期,用量却被清零"的隐蔽 bug。 +// +// 未命中活跃记录时返回 ErrUserPlatformQuotaNotFound。 +func (r *userPlatformQuotaRepository) ResetExpiredWindow(ctx context.Context, userID int64, platform string, window string, newStart time.Time) error { + client := clientFromContext(ctx, r.client) + upd := client.UserPlatformQuota.Update(). + Where( + userplatformquota.UserIDEQ(userID), + userplatformquota.PlatformEQ(platform), + userplatformquota.DeletedAtIsNil(), + ) + switch window { + case "daily": + upd = upd.SetDailyUsageUsd(0).SetDailyWindowStart(newStart) + case "weekly": + upd = upd.SetWeeklyUsageUsd(0).SetWeeklyWindowStart(newStart) + case "monthly": + upd = upd.SetMonthlyUsageUsd(0).SetMonthlyWindowStart(newStart) + default: + return fmt.Errorf("unknown window %q", window) + } + n, err := upd.Save(ctx) + if err != nil { + return err + } + if n == 0 { + return ErrUserPlatformQuotaNotFound + } + return nil +} + +// withTx 在事务中执行 fn,若 ctx 中已有事务则复用。 +func (r *userPlatformQuotaRepository) withTx(ctx context.Context, fn func(txCtx context.Context, txClient *dbent.Client) error) error { + if tx := dbent.TxFromContext(ctx); tx != nil { + return fn(ctx, tx.Client()) + } + + tx, err := r.client.Tx(ctx) + if err != nil { + return fmt.Errorf("begin user_platform_quota transaction: %w", err) + } + defer func() { _ = tx.Rollback() }() + + txCtx := dbent.NewTxContext(ctx, tx) + if err := fn(txCtx, tx.Client()); err != nil { + return err + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("commit user_platform_quota transaction: %w", err) + } + return nil +} + +// entQuotaToRecord 将 ent entity 映射为 repository record。 +// 注意 ent 生成字段名为 DailyLimitUsd(非 DailyLimitUSD)。 +func entQuotaToRecord(e *dbent.UserPlatformQuota) *UserPlatformQuotaRecord { + return &UserPlatformQuotaRecord{ + UserID: e.UserID, + Platform: e.Platform, + DailyLimitUSD: e.DailyLimitUsd, + WeeklyLimitUSD: e.WeeklyLimitUsd, + MonthlyLimitUSD: e.MonthlyLimitUsd, + DailyUsageUSD: e.DailyUsageUsd, + WeeklyUsageUSD: e.WeeklyUsageUsd, + MonthlyUsageUSD: e.MonthlyUsageUsd, + DailyWindowStart: e.DailyWindowStart, + WeeklyWindowStart: e.WeeklyWindowStart, + MonthlyWindowStart: e.MonthlyWindowStart, + } +} + +// maybeReset 判断是否需要重置窗口用量: +// - 若 prevStart 为 nil 或与 currStart 不同,表示窗口已过期,返回 cost(重置) +// - 否则返回 prevUsage + cost(累加) +func maybeReset(prevUsage float64, prevStart *time.Time, currStart time.Time, cost float64) float64 { + if prevStart == nil || !prevStart.Equal(currStart) { + return cost + } + return prevUsage + cost +} + +// monthlyMaybeReset 判断 30 天滚动月度窗口是否需要重置。 +// 过期条件:prevStart 为 nil 或 now - prevStart >= 30×24h(与订阅模式 NeedsMonthlyReset 语义一致)。 +// 过期时重置为 cost,否则累加。返回 (newUsage, newWindowStart)。 +func monthlyMaybeReset(prevUsage float64, prevStart *time.Time, cost float64, now time.Time) (float64, time.Time) { + if prevStart == nil || now.Sub(*prevStart) >= 30*24*time.Hour { + return cost, now + } + return prevUsage + cost, *prevStart +} + +// UpsertForUser 全量替换该用户的所有平台限额(事务内): +// 1. 软删除未在 records 中出现的所有 active 行 +// 2. 对每条 record 尝试 UPDATE(含 deleted_at = NULL 兼容重激活); +// UPDATE 行数为 0 时 INSERT 新行 +// +// 仅改 *_limit_usd + deleted_at + updated_at,保留 *_usage_usd / *_window_start。 +func (r *userPlatformQuotaRepository) UpsertForUser(ctx context.Context, userID int64, records []UserPlatformQuotaRecord) error { + return r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error { + platforms := make([]string, 0, len(records)) + for _, rec := range records { + platforms = append(platforms, rec.Platform) + } + now := time.Now() + if err := softDeleteMissingPlatforms(txCtx, txClient, userID, platforms, now); err != nil { + return err + } + for _, rec := range records { + affected, err := updateLimitsRow(txCtx, txClient, userID, rec, now) + if err != nil { + return err + } + if affected == 0 { + if err := insertLimitsRow(txCtx, txClient, userID, rec, now); err != nil { + return err + } + } + } + return nil + }) +} + +// softDeleteMissingPlatforms 软删除该用户所有不在 keepPlatforms 中的 active 行。 +// keepPlatforms 为空时 → 软删用户所有 active 行。 +// now 由调用方传入,与 updateLimitsRow / insertLimitsRow 共享同一个 Go time.Now(), +// 保证事务内所有时间戳一致(避免 Postgres NOW() 与 Go time.Now() 的微小偏差)。 +func softDeleteMissingPlatforms(ctx context.Context, client *dbent.Client, userID int64, keepPlatforms []string, now time.Time) error { + var ( + query string + args []any + ) + if len(keepPlatforms) == 0 { + query = `UPDATE user_platform_quotas SET deleted_at = $2, updated_at = $2 + WHERE user_id = $1 AND deleted_at IS NULL` + args = []any{userID, now} + } else { + placeholders := make([]string, len(keepPlatforms)) + args = make([]any, 0, len(keepPlatforms)+2) + args = append(args, userID, now) + for i, p := range keepPlatforms { + placeholders[i] = fmt.Sprintf("$%d", i+3) + args = append(args, p) + } + query = fmt.Sprintf(`UPDATE user_platform_quotas SET deleted_at = $2, updated_at = $2 + WHERE user_id = $1 AND deleted_at IS NULL AND platform NOT IN (%s)`, + strings.Join(placeholders, ",")) + } + _, err := client.ExecContext(ctx, query, args...) + return err +} + +// updateLimitsRow 尝试 UPDATE active 行(deleted_at IS NULL),返回受影响行数。 +// 仅更新 active 行:若存在多条历史软删记录,加 deleted_at IS NULL 守卫可避免 +// 批量重激活导致的 partial unique index(userplatformquota_user_id_platform_uq)冲突。 +// affected=0 时由调用方 UpsertForUser 走 insertLimitsRow 路径创建新行。 +func updateLimitsRow(ctx context.Context, client *dbent.Client, userID int64, rec UserPlatformQuotaRecord, now time.Time) (int64, error) { + const query = `UPDATE user_platform_quotas + SET daily_limit_usd = $1, weekly_limit_usd = $2, monthly_limit_usd = $3, + deleted_at = NULL, updated_at = $4 + WHERE user_id = $5 AND platform = $6 AND deleted_at IS NULL` + res, err := client.ExecContext(ctx, query, + rec.DailyLimitUSD, rec.WeeklyLimitUSD, rec.MonthlyLimitUSD, now, + userID, rec.Platform) + if err != nil { + return 0, err + } + return res.RowsAffected() +} + +// insertLimitsRow 插入新限额行(usage 默认 0,window_start 默认 NULL)。 +// 带 ON CONFLICT ... DO NOTHING 守卫:防止两个并发请求同时为同一 user/platform 新建行时 +// 触发 unique constraint 违反(userplatformquota_user_id_platform_uq 部分唯一索引)。 +// affected=0 时说明另一个并发请求刚完成 INSERT,fallback 到 updateLimitsRow 覆写 limits 值。 +func insertLimitsRow(ctx context.Context, client *dbent.Client, userID int64, rec UserPlatformQuotaRecord, now time.Time) error { + const query = `INSERT INTO user_platform_quotas + (user_id, platform, daily_limit_usd, weekly_limit_usd, monthly_limit_usd, + daily_usage_usd, weekly_usage_usd, monthly_usage_usd, created_at, updated_at) + VALUES ($1, $2, $3, $4, $5, 0, 0, 0, $6, $6) + ON CONFLICT (user_id, platform) WHERE deleted_at IS NULL DO NOTHING` + res, err := client.ExecContext(ctx, query, + userID, rec.Platform, + rec.DailyLimitUSD, rec.WeeklyLimitUSD, rec.MonthlyLimitUSD, + now) + if err != nil { + return err + } + affected, err := res.RowsAffected() + if err != nil { + return err + } + if affected == 0 { + // 并发情形:另一请求已插入该行,fallback 到 UPDATE 覆写 limits 值(last-writer-wins)。 + _, err = updateLimitsRow(ctx, client, userID, rec, now) + return err + } + return nil +} diff --git a/backend/internal/repository/user_platform_quota_repo_integration_test.go b/backend/internal/repository/user_platform_quota_repo_integration_test.go new file mode 100644 index 00000000..f02eeaa9 --- /dev/null +++ b/backend/internal/repository/user_platform_quota_repo_integration_test.go @@ -0,0 +1,269 @@ +//go:build integration + +package repository + +import ( + "context" + "errors" + "fmt" + "testing" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +// mustCreateUserForQuota 在指定 client 上创建测试用户(满足 FK 约束)。 +func mustCreateUserForQuota(t *testing.T, client *dbent.Client) int64 { + t.Helper() + u := mustCreateUser(t, client, &service.User{ + Email: fmt.Sprintf("quota-test-%d@example.com", time.Now().UnixNano()), + }) + return u.ID +} + +func TestUserPlatformQuotaRepository_BulkInsertInitial_Idempotent(t *testing.T) { + ctx := context.Background() + tx := testEntTx(t) + txCtx := dbent.NewTxContext(ctx, tx) + client := tx.Client() + + userID := mustCreateUserForQuota(t, client) + + repo := NewUserPlatformQuotaRepository(client) + + daily := 5.0 + records := []UserPlatformQuotaRecord{ + {UserID: userID, Platform: "anthropic", DailyLimitUSD: &daily}, + {UserID: userID, Platform: "openai"}, + } + + // 第一次插入 + require.NoError(t, repo.BulkInsertInitial(txCtx, records), "first insert") + // 第二次插入应为 no-op(ON CONFLICT DO NOTHING) + require.NoError(t, repo.BulkInsertInitial(txCtx, records), "second insert (idempotent)") + + list, err := repo.ListByUser(txCtx, userID) + require.NoError(t, err, "list") + require.Len(t, list, 2, "expected 2 records after idempotent insert") + + // 校验 daily_limit_usd 保留 + var anthropicRec *UserPlatformQuotaRecord + for i := range list { + if list[i].Platform == "anthropic" { + anthropicRec = &list[i] + } + } + require.NotNil(t, anthropicRec, "anthropic record should exist") + require.NotNil(t, anthropicRec.DailyLimitUSD, "daily limit should be set") + require.InDelta(t, 5.0, *anthropicRec.DailyLimitUSD, 1e-9, "daily limit should be 5.0") +} + +func TestUserPlatformQuotaRepository_BulkInsertInitial_Empty(t *testing.T) { + ctx := context.Background() + tx := testEntTx(t) + txCtx := dbent.NewTxContext(ctx, tx) + client := tx.Client() + + repo := NewUserPlatformQuotaRepository(client) + // 空切片不应报错 + require.NoError(t, repo.BulkInsertInitial(txCtx, nil)) + require.NoError(t, repo.BulkInsertInitial(txCtx, []UserPlatformQuotaRecord{})) +} + +func TestUserPlatformQuotaRepository_GetByUserPlatform(t *testing.T) { + ctx := context.Background() + tx := testEntTx(t) + txCtx := dbent.NewTxContext(ctx, tx) + client := tx.Client() + + userID := mustCreateUserForQuota(t, client) + + repo := NewUserPlatformQuotaRepository(client) + + // 未插入时应返回 nil + rec, err := repo.GetByUserPlatform(txCtx, userID, "anthropic") + require.NoError(t, err, "get before insert should not error") + require.Nil(t, rec, "get before insert should return nil") + + // 插入后查询 + daily := 10.0 + require.NoError(t, repo.BulkInsertInitial(txCtx, []UserPlatformQuotaRecord{ + {UserID: userID, Platform: "anthropic", DailyLimitUSD: &daily}, + })) + + rec, err = repo.GetByUserPlatform(txCtx, userID, "anthropic") + require.NoError(t, err) + require.NotNil(t, rec) + require.Equal(t, userID, rec.UserID) + require.Equal(t, "anthropic", rec.Platform) + require.NotNil(t, rec.DailyLimitUSD) + require.InDelta(t, 10.0, *rec.DailyLimitUSD, 1e-9) +} + +func TestUserPlatformQuotaRepository_IncrementUsageWithReset_SameWindow(t *testing.T) { + ctx := context.Background() + + // IncrementUsageWithReset 内部自己开事务,使用独立 ent client 确保跨事务可见 + client := testEntClient(t) + + userID := mustCreateUserForQuota(t, client) + + repo := NewUserPlatformQuotaRepository(client) + now := time.Date(2026, 5, 22, 10, 0, 0, 0, time.UTC) // 周五 + + // 首次调用:应新建记录 + require.NoError(t, repo.IncrementUsageWithReset(ctx, userID, "anthropic", 1.5, now)) + + rec, err := repo.GetByUserPlatform(ctx, userID, "anthropic") + require.NoError(t, err) + require.NotNil(t, rec) + require.InDelta(t, 1.5, rec.DailyUsageUSD, 1e-9, "initial daily usage") + require.InDelta(t, 1.5, rec.WeeklyUsageUSD, 1e-9, "initial weekly usage") + require.InDelta(t, 1.5, rec.MonthlyUsageUSD, 1e-9, "initial monthly usage") + + // 同一天再次调用:应累加 + require.NoError(t, repo.IncrementUsageWithReset(ctx, userID, "anthropic", 0.5, now)) + + rec, err = repo.GetByUserPlatform(ctx, userID, "anthropic") + require.NoError(t, err) + require.InDelta(t, 2.0, rec.DailyUsageUSD, 1e-9, "accumulated daily usage") + require.InDelta(t, 2.0, rec.WeeklyUsageUSD, 1e-9, "accumulated weekly usage") + require.InDelta(t, 2.0, rec.MonthlyUsageUSD, 1e-9, "accumulated monthly usage") +} + +func TestUserPlatformQuotaRepository_IncrementUsageWithReset_DailyReset(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + userID := mustCreateUserForQuota(t, client) + + repo := NewUserPlatformQuotaRepository(client) + + day1 := time.Date(2026, 5, 22, 10, 0, 0, 0, time.UTC) // 周五(同一周、同一月) + day2 := time.Date(2026, 5, 23, 10, 0, 0, 0, time.UTC) // 周六(同一周、同一月) + + require.NoError(t, repo.IncrementUsageWithReset(ctx, userID, "anthropic", 3.0, day1)) + require.NoError(t, repo.IncrementUsageWithReset(ctx, userID, "anthropic", 1.0, day2)) + + rec, err := repo.GetByUserPlatform(ctx, userID, "anthropic") + require.NoError(t, err) + require.InDelta(t, 1.0, rec.DailyUsageUSD, 1e-9, "daily should reset to 1.0") + require.InDelta(t, 4.0, rec.WeeklyUsageUSD, 1e-9, "weekly should accumulate to 4.0 (same week)") + require.InDelta(t, 4.0, rec.MonthlyUsageUSD, 1e-9, "monthly should accumulate to 4.0 (same month)") +} + +func TestUserPlatformQuotaRepository_IncrementUsageWithReset_WeeklyReset(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + userID := mustCreateUserForQuota(t, client) + + repo := NewUserPlatformQuotaRepository(client) + + // 5月22日(周五)和 5月25日(下周一),不同周 + fri := time.Date(2026, 5, 22, 10, 0, 0, 0, time.UTC) + nextMon := time.Date(2026, 5, 25, 10, 0, 0, 0, time.UTC) // 下一周周一 + + require.NoError(t, repo.IncrementUsageWithReset(ctx, userID, "openai", 5.0, fri)) + require.NoError(t, repo.IncrementUsageWithReset(ctx, userID, "openai", 2.0, nextMon)) + + rec, err := repo.GetByUserPlatform(ctx, userID, "openai") + require.NoError(t, err) + require.InDelta(t, 2.0, rec.DailyUsageUSD, 1e-9, "daily resets to new cost") + require.InDelta(t, 2.0, rec.WeeklyUsageUSD, 1e-9, "weekly resets (new week)") + require.InDelta(t, 7.0, rec.MonthlyUsageUSD, 1e-9, "monthly accumulates (same month)") +} + +func TestUserPlatformQuotaRepository_ResetExpiredWindow(t *testing.T) { + ctx := context.Background() + tx := testEntTx(t) + txCtx := dbent.NewTxContext(ctx, tx) + client := tx.Client() + + userID := mustCreateUserForQuota(t, client) + + repo := NewUserPlatformQuotaRepository(client) + + // 先通过 ent 直接建一条记录 + _, err := client.UserPlatformQuota.Create(). + SetUserID(userID). + SetPlatform("gemini"). + SetDailyUsageUsd(10.0). + SetWeeklyUsageUsd(20.0). + SetMonthlyUsageUsd(50.0). + SetDailyWindowStart(time.Date(2026, 5, 21, 0, 0, 0, 0, time.UTC)). + SetWeeklyWindowStart(time.Date(2026, 5, 18, 0, 0, 0, 0, time.UTC)). + SetMonthlyWindowStart(time.Date(2026, 5, 1, 0, 0, 0, 0, time.UTC)). + Save(txCtx) + require.NoError(t, err) + + newStart := time.Date(2026, 5, 22, 0, 0, 0, 0, time.UTC) + require.NoError(t, repo.ResetExpiredWindow(txCtx, userID, "gemini", "daily", newStart)) + + rec, err := repo.GetByUserPlatform(txCtx, userID, "gemini") + require.NoError(t, err) + require.InDelta(t, 0.0, rec.DailyUsageUSD, 1e-9, "daily usage reset to 0") + require.NotNil(t, rec.DailyWindowStart) + require.True(t, rec.DailyWindowStart.Equal(newStart), "daily window start updated") + // 其他窗口不变 + require.InDelta(t, 20.0, rec.WeeklyUsageUSD, 1e-9, "weekly usage unchanged") + require.InDelta(t, 50.0, rec.MonthlyUsageUSD, 1e-9, "monthly usage unchanged") +} + +func TestUserPlatformQuotaRepository_ResetExpiredWindow_UnknownWindow(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + + repo := NewUserPlatformQuotaRepository(client) + err := repo.ResetExpiredWindow(ctx, 999, "anthropic", "yearly", time.Now()) + require.Error(t, err, "unknown window should return error") +} + +func TestUserPlatformQuotaRepository_BulkInsertInitial_MultiRow(t *testing.T) { + ctx := context.Background() + tx := testEntTx(t) + txCtx := dbent.NewTxContext(ctx, tx) + client := tx.Client() + + userID := mustCreateUserForQuota(t, client) + repo := NewUserPlatformQuotaRepository(client) + + d1, d2, d3 := 5.0, 10.0, 15.0 + records := []UserPlatformQuotaRecord{ + {UserID: userID, Platform: "anthropic", DailyLimitUSD: &d1}, + {UserID: userID, Platform: "openai", DailyLimitUSD: &d2}, + {UserID: userID, Platform: "gemini", DailyLimitUSD: &d3}, + } + require.NoError(t, repo.BulkInsertInitial(txCtx, records), "multi-row insert failed") + + list, err := repo.ListByUser(txCtx, userID) + require.NoError(t, err) + require.Len(t, list, 3, "expected 3 rows, got %d", len(list)) + + // 验证 limit 值与传入一致(防占位符串位) + byPlatform := map[string]*UserPlatformQuotaRecord{} + for i := range list { + byPlatform[list[i].Platform] = &list[i] + } + require.NotNil(t, byPlatform["anthropic"], "anthropic record should exist") + require.NotNil(t, byPlatform["anthropic"].DailyLimitUSD, "anthropic daily limit should be set") + require.InDelta(t, 5.0, *byPlatform["anthropic"].DailyLimitUSD, 1e-9, "anthropic daily_limit = want 5.0") + + require.NotNil(t, byPlatform["openai"], "openai record should exist") + require.NotNil(t, byPlatform["openai"].DailyLimitUSD, "openai daily limit should be set") + require.InDelta(t, 10.0, *byPlatform["openai"].DailyLimitUSD, 1e-9, "openai daily_limit = want 10.0") + + require.NotNil(t, byPlatform["gemini"], "gemini record should exist") + require.NotNil(t, byPlatform["gemini"].DailyLimitUSD, "gemini daily limit should be set") + require.InDelta(t, 15.0, *byPlatform["gemini"].DailyLimitUSD, 1e-9, "gemini daily_limit = want 15.0") +} + +func TestUserPlatformQuotaRepository_ResetExpiredWindow_NotFoundReturnsSentinel(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + repo := NewUserPlatformQuotaRepository(client) + + err := repo.ResetExpiredWindow(ctx, 99999, "anthropic", "daily", time.Now()) + require.True(t, errors.Is(err, ErrUserPlatformQuotaNotFound), + "expected ErrUserPlatformQuotaNotFound, got %v", err) +} diff --git a/backend/internal/repository/user_platform_quota_repo_test.go b/backend/internal/repository/user_platform_quota_repo_test.go new file mode 100644 index 00000000..84377353 --- /dev/null +++ b/backend/internal/repository/user_platform_quota_repo_test.go @@ -0,0 +1,103 @@ +//go:build unit + +package repository + +import ( + "os" + "strings" + "testing" + "time" +) + +func TestMaybeReset(t *testing.T) { + start := time.Date(2026, 5, 22, 0, 0, 0, 0, time.UTC) + other := start.AddDate(0, 0, -1) + cases := []struct { + name string + prevUsage float64 + prevStart *time.Time + currStart time.Time + cost float64 + want float64 + }{ + {"nil prev start resets", 10, nil, start, 1.5, 1.5}, + {"different start resets", 10, &other, start, 1.5, 1.5}, + {"same start accumulates", 10, &start, start, 1.5, 11.5}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + if got := maybeReset(c.prevUsage, c.prevStart, c.currStart, c.cost); got != c.want { + t.Errorf("maybeReset = %v, want %v", got, c.want) + } + }) + } +} + +// TestMonthlyMaybeReset_NilStart 验证 prevStart=nil 时重置。 +func TestMonthlyMaybeReset_NilStart(t *testing.T) { + now := time.Date(2026, 5, 22, 12, 0, 0, 0, time.UTC) + usage, start := monthlyMaybeReset(10.0, nil, 1.5, now) + if usage != 1.5 { + t.Errorf("usage = %v, want 1.5", usage) + } + if !start.Equal(now) { + t.Errorf("start = %v, want %v", start, now) + } +} + +// TestMonthlyMaybeReset_Expired 验证窗口满 30 天时重置(30 天恰好到期)。 +func TestMonthlyMaybeReset_Expired(t *testing.T) { + windowStart := time.Date(2026, 4, 22, 12, 0, 0, 0, time.UTC) + // now = windowStart + 30d(刚好到期) + now := windowStart.Add(30 * 24 * time.Hour) + usage, start := monthlyMaybeReset(8.0, &windowStart, 2.0, now) + if usage != 2.0 { + t.Errorf("usage = %v, want 2.0 (reset)", usage) + } + if !start.Equal(now) { + t.Errorf("start = %v, want %v (new window)", start, now) + } +} + +// TestMonthlyMaybeReset_CrossMonthBoundary 验证跨自然月时也使用 30 天滚动(不提前重置)。 +// 旧行为:5 月 1 日跨月立即重置;新行为:窗口起始 4 月 20 日,5 月 1 日仅过了 11 天,应累加。 +func TestMonthlyMaybeReset_CrossMonthBoundary(t *testing.T) { + windowStart := time.Date(2026, 4, 20, 0, 0, 0, 0, time.UTC) + // 5 月 1 日:距起始 11 天,不足 30 天,应累加 + now := time.Date(2026, 5, 1, 0, 0, 0, 0, time.UTC) + usage, start := monthlyMaybeReset(5.0, &windowStart, 1.0, now) + if usage != 6.0 { + t.Errorf("usage = %v, want 6.0 (accumulate, not reset at month boundary)", usage) + } + if !start.Equal(windowStart) { + t.Errorf("start = %v, want %v (preserved)", start, windowStart) + } +} + +// TestMonthlyMaybeReset_Active 验证窗口内正常累加。 +func TestMonthlyMaybeReset_Active(t *testing.T) { + windowStart := time.Date(2026, 5, 1, 0, 0, 0, 0, time.UTC) + // 15 天内,窗口有效 + now := windowStart.Add(15 * 24 * time.Hour) + usage, start := monthlyMaybeReset(3.0, &windowStart, 0.5, now) + if usage != 3.5 { + t.Errorf("usage = %v, want 3.5", usage) + } + if !start.Equal(windowStart) { + t.Errorf("start = %v, want %v", start, windowStart) + } +} + +// TestUpdateLimitsRowQuery_HasDeletedAtGuard 通过读取源文件验证 updateLimitsRow +// 的 SQL WHERE 子句包含 deleted_at IS NULL 守卫(I-NEW-1)。 +// 此防回归测试可在无 DB 的 CI 环境中运行,防止意外删除该守卫。 +func TestUpdateLimitsRowQuery_HasDeletedAtGuard(t *testing.T) { + src, err := os.ReadFile("user_platform_quota_repo.go") + if err != nil { + t.Fatalf("failed to read source file: %v", err) + } + const guard = "AND deleted_at IS NULL" + if !strings.Contains(string(src), guard) { + t.Errorf("updateLimitsRow SQL must contain %q to prevent bulk reactivation of soft-deleted rows (I-NEW-1)", guard) + } +} diff --git a/backend/internal/repository/user_platform_quota_service_adapter.go b/backend/internal/repository/user_platform_quota_service_adapter.go new file mode 100644 index 00000000..7495cd26 --- /dev/null +++ b/backend/internal/repository/user_platform_quota_service_adapter.go @@ -0,0 +1,206 @@ +package repository + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" +) + +// userPlatformQuotaServiceAdapter 将 repository 层的 userPlatformQuotaRepository +// 适配为 service.UserPlatformQuotaRepository 接口(返回 *service.UserPlatformQuotaRecord)。 +type userPlatformQuotaServiceAdapter struct { + inner *userPlatformQuotaRepository +} + +// NewUserPlatformQuotaServiceAdapter 将 UserPlatformQuotaRepository 实现包装为 +// 满足 service.UserPlatformQuotaRepository 接口的适配器。 +func NewUserPlatformQuotaServiceAdapter(repo UserPlatformQuotaRepository) service.UserPlatformQuotaRepository { + impl, ok := repo.(*userPlatformQuotaRepository) + if !ok { + // 非标准实现(如测试 fake),通过通用适配器包装 + return &genericUserPlatformQuotaAdapter{inner: repo} + } + return &userPlatformQuotaServiceAdapter{inner: impl} +} + +func (a *userPlatformQuotaServiceAdapter) GetByUserPlatform(ctx context.Context, userID int64, platform string) (*service.UserPlatformQuotaRecord, error) { + rec, err := a.inner.GetByUserPlatform(ctx, userID, platform) + if err != nil || rec == nil { + return nil, err + } + return toServiceRecord(rec), nil +} + +// IncrementUsageWithReset 原子累加 cost 到 (user, platform) 三个窗口的用量。 +func (a *userPlatformQuotaServiceAdapter) IncrementUsageWithReset(ctx context.Context, userID int64, platform string, cost float64, now time.Time) error { + return a.inner.IncrementUsageWithReset(ctx, userID, platform, cost, now) +} + +// ListByUser 查询用户的所有平台配额记录。 +func (a *userPlatformQuotaServiceAdapter) ListByUser(ctx context.Context, userID int64) ([]service.UserPlatformQuotaRecord, error) { + rows, err := a.inner.ListByUser(ctx, userID) + if err != nil { + return nil, err + } + out := make([]service.UserPlatformQuotaRecord, len(rows)) + for i, r := range rows { + out[i] = service.UserPlatformQuotaRecord{ + UserID: r.UserID, + Platform: r.Platform, + DailyLimitUSD: r.DailyLimitUSD, + WeeklyLimitUSD: r.WeeklyLimitUSD, + MonthlyLimitUSD: r.MonthlyLimitUSD, + DailyUsageUSD: r.DailyUsageUSD, + WeeklyUsageUSD: r.WeeklyUsageUSD, + MonthlyUsageUSD: r.MonthlyUsageUSD, + DailyWindowStart: r.DailyWindowStart, + WeeklyWindowStart: r.WeeklyWindowStart, + MonthlyWindowStart: r.MonthlyWindowStart, + } + } + return out, nil +} + +// BulkInsertInitial 将 service.UserPlatformQuotaRecord 切片转换后调用底层 repo。 +func (a *userPlatformQuotaServiceAdapter) BulkInsertInitial(ctx context.Context, records []service.UserPlatformQuotaRecord) error { + repoRecords := make([]UserPlatformQuotaRecord, len(records)) + for i, r := range records { + repoRecords[i] = UserPlatformQuotaRecord{ + UserID: r.UserID, + Platform: r.Platform, + DailyLimitUSD: r.DailyLimitUSD, + WeeklyLimitUSD: r.WeeklyLimitUSD, + MonthlyLimitUSD: r.MonthlyLimitUSD, + } + } + return a.inner.BulkInsertInitial(ctx, repoRecords) +} + +// UpsertForUser 全量替换该用户所有平台限额。 +func (a *userPlatformQuotaServiceAdapter) UpsertForUser(ctx context.Context, userID int64, records []service.UserPlatformQuotaRecord) error { + repoRecords := toRepoRecords(records) + return a.inner.UpsertForUser(ctx, userID, repoRecords) +} + +// ResetExpiredWindow 转发至 repository.ResetExpiredWindow,并将 repository sentinel 包装为 service sentinel。 +func (a *userPlatformQuotaServiceAdapter) ResetExpiredWindow(ctx context.Context, userID int64, platform string, window string, newStart time.Time) error { + err := a.inner.ResetExpiredWindow(ctx, userID, platform, window, newStart) + if errors.Is(err, ErrUserPlatformQuotaNotFound) { + return fmt.Errorf("%w: %w", service.ErrUserPlatformQuotaNotFound, err) + } + return err +} + +// genericUserPlatformQuotaAdapter 通过通用接口适配(用于测试 fake 或非标准实现)。 +type genericUserPlatformQuotaAdapter struct { + inner UserPlatformQuotaRepository +} + +func (a *genericUserPlatformQuotaAdapter) GetByUserPlatform(ctx context.Context, userID int64, platform string) (*service.UserPlatformQuotaRecord, error) { + rec, err := a.inner.GetByUserPlatform(ctx, userID, platform) + if err != nil || rec == nil { + return nil, err + } + return toServiceRecord(rec), nil +} + +// IncrementUsageWithReset 原子累加 cost(通用 adapter 实现)。 +func (a *genericUserPlatformQuotaAdapter) IncrementUsageWithReset(ctx context.Context, userID int64, platform string, cost float64, now time.Time) error { + return a.inner.IncrementUsageWithReset(ctx, userID, platform, cost, now) +} + +// ListByUser 查询用户的所有平台配额记录(通用 adapter 实现)。 +func (a *genericUserPlatformQuotaAdapter) ListByUser(ctx context.Context, userID int64) ([]service.UserPlatformQuotaRecord, error) { + rows, err := a.inner.ListByUser(ctx, userID) + if err != nil { + return nil, err + } + out := make([]service.UserPlatformQuotaRecord, len(rows)) + for i, r := range rows { + out[i] = service.UserPlatformQuotaRecord{ + UserID: r.UserID, + Platform: r.Platform, + DailyLimitUSD: r.DailyLimitUSD, + WeeklyLimitUSD: r.WeeklyLimitUSD, + MonthlyLimitUSD: r.MonthlyLimitUSD, + DailyUsageUSD: r.DailyUsageUSD, + WeeklyUsageUSD: r.WeeklyUsageUSD, + MonthlyUsageUSD: r.MonthlyUsageUSD, + DailyWindowStart: r.DailyWindowStart, + WeeklyWindowStart: r.WeeklyWindowStart, + MonthlyWindowStart: r.MonthlyWindowStart, + } + } + return out, nil +} + +// BulkInsertInitial 将 service.UserPlatformQuotaRecord 切片转换后调用底层 generic repo。 +func (a *genericUserPlatformQuotaAdapter) BulkInsertInitial(ctx context.Context, records []service.UserPlatformQuotaRecord) error { + repoRecords := make([]UserPlatformQuotaRecord, len(records)) + for i, r := range records { + repoRecords[i] = UserPlatformQuotaRecord{ + UserID: r.UserID, + Platform: r.Platform, + DailyLimitUSD: r.DailyLimitUSD, + WeeklyLimitUSD: r.WeeklyLimitUSD, + MonthlyLimitUSD: r.MonthlyLimitUSD, + } + } + return a.inner.BulkInsertInitial(ctx, repoRecords) +} + +// UpsertForUser 全量替换(通用 adapter 实现)。 +func (a *genericUserPlatformQuotaAdapter) UpsertForUser(ctx context.Context, userID int64, records []service.UserPlatformQuotaRecord) error { + repoRecords := toRepoRecords(records) + return a.inner.UpsertForUser(ctx, userID, repoRecords) +} + +// ResetExpiredWindow 转发至 repository.ResetExpiredWindow(通用 adapter),并包装 sentinel。 +func (a *genericUserPlatformQuotaAdapter) ResetExpiredWindow(ctx context.Context, userID int64, platform string, window string, newStart time.Time) error { + err := a.inner.ResetExpiredWindow(ctx, userID, platform, window, newStart) + if errors.Is(err, ErrUserPlatformQuotaNotFound) { + return fmt.Errorf("%w: %w", service.ErrUserPlatformQuotaNotFound, err) + } + return err +} + +// toServiceRecord 将 repository.UserPlatformQuotaRecord 转换为 service.UserPlatformQuotaRecord。 +func toServiceRecord(rec *UserPlatformQuotaRecord) *service.UserPlatformQuotaRecord { + return &service.UserPlatformQuotaRecord{ + UserID: rec.UserID, + Platform: rec.Platform, + DailyLimitUSD: rec.DailyLimitUSD, + WeeklyLimitUSD: rec.WeeklyLimitUSD, + MonthlyLimitUSD: rec.MonthlyLimitUSD, + DailyUsageUSD: rec.DailyUsageUSD, + WeeklyUsageUSD: rec.WeeklyUsageUSD, + MonthlyUsageUSD: rec.MonthlyUsageUSD, + DailyWindowStart: rec.DailyWindowStart, + WeeklyWindowStart: rec.WeeklyWindowStart, + MonthlyWindowStart: rec.MonthlyWindowStart, + } +} + +// toRepoRecords 将 service.UserPlatformQuotaRecord 切片转换为 repository.UserPlatformQuotaRecord(含 limit 字段,含 usage/window_start)。 +func toRepoRecords(records []service.UserPlatformQuotaRecord) []UserPlatformQuotaRecord { + out := make([]UserPlatformQuotaRecord, len(records)) + for i, r := range records { + out[i] = UserPlatformQuotaRecord{ + UserID: r.UserID, + Platform: r.Platform, + DailyLimitUSD: r.DailyLimitUSD, + WeeklyLimitUSD: r.WeeklyLimitUSD, + MonthlyLimitUSD: r.MonthlyLimitUSD, + DailyUsageUSD: r.DailyUsageUSD, + WeeklyUsageUSD: r.WeeklyUsageUSD, + MonthlyUsageUSD: r.MonthlyUsageUSD, + DailyWindowStart: r.DailyWindowStart, + WeeklyWindowStart: r.WeeklyWindowStart, + MonthlyWindowStart: r.MonthlyWindowStart, + } + } + return out +} diff --git a/backend/internal/repository/user_platform_quota_upsert_test.go b/backend/internal/repository/user_platform_quota_upsert_test.go new file mode 100644 index 00000000..db8e6f47 --- /dev/null +++ b/backend/internal/repository/user_platform_quota_upsert_test.go @@ -0,0 +1,148 @@ +//go:build integration + +package repository + +import ( + "context" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/ent/userplatformquota" + "github.com/stretchr/testify/require" +) + +func TestUpsertForUser_NewUserInsertsAllRecords(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + userID := mustCreateUserForQuota(t, client) + repo := NewUserPlatformQuotaRepository(client) + + daily := 10.0 + weekly := 50.0 + monthly := 200.0 + records := []UserPlatformQuotaRecord{ + {UserID: userID, Platform: "anthropic", DailyLimitUSD: &daily, WeeklyLimitUSD: &weekly, MonthlyLimitUSD: &monthly}, + {UserID: userID, Platform: "openai", DailyLimitUSD: &daily}, + } + require.NoError(t, repo.UpsertForUser(ctx, userID, records)) + + got, err := repo.ListByUser(ctx, userID) + require.NoError(t, err) + require.Len(t, got, 2) +} + +func TestUpsertForUser_PartialUpdateSoftDeletesMissingPlatforms(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + userID := mustCreateUserForQuota(t, client) + repo := NewUserPlatformQuotaRepository(client) + + d1 := 10.0 + d2 := 20.0 + require.NoError(t, repo.UpsertForUser(ctx, userID, []UserPlatformQuotaRecord{ + {UserID: userID, Platform: "anthropic", DailyLimitUSD: &d1}, + {UserID: userID, Platform: "openai", DailyLimitUSD: &d1}, + })) + require.NoError(t, repo.UpsertForUser(ctx, userID, []UserPlatformQuotaRecord{ + {UserID: userID, Platform: "anthropic", DailyLimitUSD: &d2}, + {UserID: userID, Platform: "gemini", DailyLimitUSD: &d1}, + })) + + active, err := repo.ListByUser(ctx, userID) + require.NoError(t, err) + platforms := map[string]float64{} + for _, r := range active { + require.NotNil(t, r.DailyLimitUSD) + platforms[r.Platform] = *r.DailyLimitUSD + } + require.Len(t, platforms, 2) + require.InDelta(t, 20.0, platforms["anthropic"], 1e-9) + require.InDelta(t, 10.0, platforms["gemini"], 1e-9) + _, openaiActive := platforms["openai"] + require.False(t, openaiActive, "openai should be soft-deleted") +} + +func TestUpsertForUser_PreservesUsageAndWindowStart(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + userID := mustCreateUserForQuota(t, client) + repo := NewUserPlatformQuotaRepository(client) + + d := 10.0 + require.NoError(t, repo.UpsertForUser(ctx, userID, []UserPlatformQuotaRecord{ + {UserID: userID, Platform: "anthropic", DailyLimitUSD: &d}, + })) + + now := time.Date(2026, 5, 22, 10, 0, 0, 0, time.UTC) + require.NoError(t, repo.IncrementUsageWithReset(ctx, userID, "anthropic", 3.5, now)) + + newD := 50.0 + require.NoError(t, repo.UpsertForUser(ctx, userID, []UserPlatformQuotaRecord{ + {UserID: userID, Platform: "anthropic", DailyLimitUSD: &newD}, + })) + + rec, err := repo.GetByUserPlatform(ctx, userID, "anthropic") + require.NoError(t, err) + require.NotNil(t, rec) + require.InDelta(t, 50.0, *rec.DailyLimitUSD, 1e-9, "limit should update") + require.InDelta(t, 3.5, rec.DailyUsageUSD, 1e-9, "usage must be preserved") + require.NotNil(t, rec.DailyWindowStart, "window_start must be preserved") +} + +func TestUpsertForUser_ReactivatesSoftDeleted(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + userID := mustCreateUserForQuota(t, client) + repo := NewUserPlatformQuotaRepository(client) + + d := 10.0 + require.NoError(t, repo.UpsertForUser(ctx, userID, []UserPlatformQuotaRecord{ + {UserID: userID, Platform: "anthropic", DailyLimitUSD: &d}, + })) + require.NoError(t, repo.UpsertForUser(ctx, userID, []UserPlatformQuotaRecord{})) + + gone, err := repo.GetByUserPlatform(ctx, userID, "anthropic") + require.NoError(t, err) + require.Nil(t, gone, "anthropic should be soft-deleted (not active)") + + d2 := 20.0 + require.NoError(t, repo.UpsertForUser(ctx, userID, []UserPlatformQuotaRecord{ + {UserID: userID, Platform: "anthropic", DailyLimitUSD: &d2}, + })) + + back, err := repo.GetByUserPlatform(ctx, userID, "anthropic") + require.NoError(t, err) + require.NotNil(t, back, "anthropic should be active again") + require.InDelta(t, 20.0, *back.DailyLimitUSD, 1e-9) + + allRows, err := client.UserPlatformQuota.Query(). + Where(userplatformquota.UserIDEQ(userID), userplatformquota.PlatformEQ("anthropic")). + All(ctx) + require.NoError(t, err) + activeCount := 0 + for _, r := range allRows { + if r.DeletedAt == nil { + activeCount++ + } + } + require.Equal(t, 1, activeCount, "should have exactly one active row") +} + +func TestUpsertForUser_EmptyClearsAll(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + userID := mustCreateUserForQuota(t, client) + repo := NewUserPlatformQuotaRepository(client) + + d := 10.0 + require.NoError(t, repo.UpsertForUser(ctx, userID, []UserPlatformQuotaRecord{ + {UserID: userID, Platform: "anthropic", DailyLimitUSD: &d}, + {UserID: userID, Platform: "openai", DailyLimitUSD: &d}, + })) + + require.NoError(t, repo.UpsertForUser(ctx, userID, []UserPlatformQuotaRecord{})) + + got, err := repo.ListByUser(ctx, userID) + require.NoError(t, err) + require.Empty(t, got) +} diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index 3c0ee9cb..2d1b04e3 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -93,6 +93,8 @@ var ProviderSet = wire.NewSet( NewChannelMonitorRequestTemplateRepository, NewContentModerationRepository, NewAffiliateRepository, + NewUserPlatformQuotaRepository, // T14: user × platform quota + NewUserPlatformQuotaServiceAdapter, // T14: adapter → service.UserPlatformQuotaRepository // Cache implementations NewGatewayCache, diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index fcf37ed7..8bc9e280 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -798,6 +798,14 @@ func TestAPIContracts(t *testing.T) { "force_email_on_third_party_signup": false, "default_concurrency": 5, "default_balance": 1.25, + "default_platform_quotas": {"anthropic":{"daily":null,"weekly":null,"monthly":null},"antigravity":{"daily":null,"weekly":null,"monthly":null},"gemini":{"daily":null,"weekly":null,"monthly":null},"openai":{"daily":null,"weekly":null,"monthly":null}}, + "auth_source_default_email_platform_quotas": null, + "auth_source_default_github_platform_quotas": null, + "auth_source_default_google_platform_quotas": null, + "auth_source_default_linuxdo_platform_quotas": null, + "auth_source_default_oidc_platform_quotas": null, + "auth_source_default_wechat_platform_quotas": null, + "auth_source_default_dingtalk_platform_quotas": null, "affiliate_rebate_rate": 20, "affiliate_rebate_freeze_hours": 0, "affiliate_rebate_duration_days": 0, @@ -1025,6 +1033,14 @@ func TestAPIContracts(t *testing.T) { "purchase_subscription_url": "", "table_default_page_size": 20, "table_page_size_options": [10, 20, 50], + "default_platform_quotas": {"anthropic":{"daily":null,"weekly":null,"monthly":null},"antigravity":{"daily":null,"weekly":null,"monthly":null},"gemini":{"daily":null,"weekly":null,"monthly":null},"openai":{"daily":null,"weekly":null,"monthly":null}}, + "auth_source_default_email_platform_quotas": null, + "auth_source_default_github_platform_quotas": null, + "auth_source_default_google_platform_quotas": null, + "auth_source_default_linuxdo_platform_quotas": null, + "auth_source_default_oidc_platform_quotas": null, + "auth_source_default_wechat_platform_quotas": null, + "auth_source_default_dingtalk_platform_quotas": null, "custom_menu_items": [], "custom_endpoints": [], "default_concurrency": 0, diff --git a/backend/internal/server/middleware/admin_auth_test.go b/backend/internal/server/middleware/admin_auth_test.go index 3fbbb716..303d0db8 100644 --- a/backend/internal/server/middleware/admin_auth_test.go +++ b/backend/internal/server/middleware/admin_auth_test.go @@ -20,7 +20,7 @@ func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) { gin.SetMode(gin.TestMode) cfg := &config.Config{JWT: config.JWTConfig{Secret: "test-secret", ExpireHour: 1}} - authService := service.NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil) + authService := service.NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil, nil) admin := &service.User{ ID: 1, diff --git a/backend/internal/server/middleware/jwt_auth_test.go b/backend/internal/server/middleware/jwt_auth_test.go index a643d3bc..e9922e54 100644 --- a/backend/internal/server/middleware/jwt_auth_test.go +++ b/backend/internal/server/middleware/jwt_auth_test.go @@ -60,7 +60,7 @@ func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthSer cfg.JWT.AccessTokenExpireMinutes = 60 userRepo := &stubJWTUserRepo{users: users} - authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil) + authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil, nil) userSvc := service.NewUserService(userRepo, nil, nil, nil) mw := NewJWTAuthMiddleware(authSvc, userSvc) @@ -143,7 +143,7 @@ func TestJWTAuth_ValidToken_TouchesLastActive(t *testing.T) { cfg.JWT.AccessTokenExpireMinutes = 60 userRepo := &stubJWTUserRepo{users: map[int64]*service.User{1: user}} - authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil) + authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil, nil) userSvc := service.NewUserService(userRepo, nil, nil, nil) toucher := &recordingActivityToucher{} diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index e036cf32..5d62a7ab 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -241,6 +241,9 @@ func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) { users.POST("/:id/replace-group", h.Admin.User.ReplaceGroup) users.GET("/:id/rpm-status", h.Admin.User.GetUserRPMStatus) users.POST("/batch-concurrency", h.Admin.User.BatchUpdateConcurrency) + users.GET("/:id/platform-quotas", h.Admin.User.GetUserPlatformQuotas) + users.PUT("/:id/platform-quotas", h.Admin.User.UpdateUserPlatformQuotas) + users.POST("/:id/platform-quotas/reset", h.Admin.User.ResetUserPlatformQuotaWindow) // User attribute values users.GET("/:id/attributes", h.Admin.UserAttribute.GetUserAttributes) diff --git a/backend/internal/server/routes/user.go b/backend/internal/server/routes/user.go index e79d3ee3..07ae33de 100644 --- a/backend/internal/server/routes/user.go +++ b/backend/internal/server/routes/user.go @@ -32,6 +32,7 @@ func RegisterUserRoutes( user.DELETE("/account-bindings/:provider", h.User.UnbindIdentity) user.POST("/auth-identities/bind/start", h.User.StartIdentityBinding) user.GET("/api-keys/:id/usage/daily", h.Usage.GetMyAPIKeyDailyUsage) + user.GET("/platform-quotas", h.User.GetMyPlatformQuotas) // 通知邮箱管理 notifyEmail := user.Group("/notify-email") diff --git a/backend/internal/service/admin_service_delete_test.go b/backend/internal/service/admin_service_delete_test.go index 2f764d67..d01b11e6 100644 --- a/backend/internal/service/admin_service_delete_test.go +++ b/backend/internal/service/admin_service_delete_test.go @@ -459,6 +459,22 @@ func (s *billingCacheStub) InvalidateAPIKeyRateLimit(ctx context.Context, keyID panic("unexpected InvalidateAPIKeyRateLimit call") } +func (s *billingCacheStub) GetUserPlatformQuotaCache(ctx context.Context, userID int64, platform string) (*UserPlatformQuotaCacheEntry, bool, error) { + panic("unexpected GetUserPlatformQuotaCache call") +} + +func (s *billingCacheStub) SetUserPlatformQuotaCache(ctx context.Context, userID int64, platform string, entry *UserPlatformQuotaCacheEntry, ttl time.Duration) error { + panic("unexpected SetUserPlatformQuotaCache call") +} + +func (s *billingCacheStub) DeleteUserPlatformQuotaCache(ctx context.Context, userID int64, platform string) error { + panic("unexpected DeleteUserPlatformQuotaCache call") +} + +func (s *billingCacheStub) IncrUserPlatformQuotaUsageCache(ctx context.Context, userID int64, platform string, cost float64, ttl time.Duration) error { + panic("unexpected IncrUserPlatformQuotaUsageCache call") +} + func waitForInvalidations(t *testing.T, ch <-chan subscriptionInvalidateCall, expected int) []subscriptionInvalidateCall { t.Helper() calls := make([]subscriptionInvalidateCall, 0, expected) diff --git a/backend/internal/service/auth_email_oauth_auto.go b/backend/internal/service/auth_email_oauth_auto.go index 4db845c2..f86df14c 100644 --- a/backend/internal/service/auth_email_oauth_auto.go +++ b/backend/internal/service/auth_email_oauth_auto.go @@ -189,6 +189,8 @@ func (s *AuthService) createEmailOAuthUser(ctx context.Context, email, username, } s.postAuthUserBootstrap(ctx, user, providerType, false) s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") + // snapshot user × platform quota(fail-open) + _ = s.snapshotPlatformQuotaDefaults(ctx, user.ID, &grantPlan) s.bindOAuthAffiliate(ctx, user.ID, affiliateCode) if invitationRedeemCode != nil { if err := s.useOAuthRegistrationInvitation(ctx, invitationRedeemCode.ID, user.ID); err != nil { diff --git a/backend/internal/service/auth_email_oauth_auto_test.go b/backend/internal/service/auth_email_oauth_auto_test.go new file mode 100644 index 00000000..d128e6bb --- /dev/null +++ b/backend/internal/service/auth_email_oauth_auto_test.go @@ -0,0 +1,88 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +func newEmailOAuthAutoAuthService( + userRepo UserRepository, + settings map[string]string, + quotaRepo UserPlatformQuotaRepository, +) *AuthService { + cfg := &config.Config{ + JWT: config.JWTConfig{ + Secret: "test-secret", + ExpireHour: 1, + AccessTokenExpireMinutes: 60, + RefreshTokenExpireDays: 7, + }, + Default: config.DefaultConfig{ + UserBalance: 3.5, + UserConcurrency: 2, + }, + } + + settingService := NewSettingService(&settingRepoStub{values: settings}, cfg) + + return NewAuthService( + nil, // entClient — nil, updateUserSignupSource early return + userRepo, + nil, // redeemRepo — invitationCode="" 时不触发 + &refreshTokenCacheStub{}, + cfg, + settingService, + nil, // emailService + nil, // turnstileService + nil, // emailQueueService + nil, // promoService + nil, // defaultSubAssigner — nil, assignSubscriptions early return + nil, // affiliateService — nil, bindOAuthAffiliate early return + quotaRepo, + ) +} + +func TestEmailOAuthAuto_SnapshotsPlatformQuotaDefaults(t *testing.T) { + userRepo := &userRepoStub{nextID: 88} + quotaRepo := &userPlatformQuotaRepoStub{} + + svc := newEmailOAuthAutoAuthService( + userRepo, + map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyDefaultPlatformQuotas: `{"gemini": {"monthly": 100.0}}`, + }, + quotaRepo, + ) + + user, err := svc.createEmailOAuthUser( + context.Background(), + "newoauth@example.com", + "newoauth", + "github", + "", // invitationCode + "", // affiliateCode + ) + require.NoError(t, err) + require.NotNil(t, user) + require.Equal(t, int64(88), user.ID) + + require.Len(t, quotaRepo.bulkInsertCalls, 1, "createEmailOAuthUser must snapshot platform quotas via BulkInsertInitial") + + records := quotaRepo.bulkInsertCalls[0] + var geminiRecord *UserPlatformQuotaRecord + for i := range records { + if records[i].Platform == "gemini" { + geminiRecord = &records[i] + break + } + } + require.NotNil(t, geminiRecord, "expected gemini platform record") + require.NotNil(t, geminiRecord.MonthlyLimitUSD) + require.InDelta(t, 100.0, *geminiRecord.MonthlyLimitUSD, 0.0001) +} diff --git a/backend/internal/service/auth_oauth_email_flow.go b/backend/internal/service/auth_oauth_email_flow.go index cf0be652..24d0eeee 100644 --- a/backend/internal/service/auth_oauth_email_flow.go +++ b/backend/internal/service/auth_oauth_email_flow.go @@ -283,6 +283,8 @@ func (s *AuthService) FinalizeOAuthEmailAccount( s.updateOAuthSignupSource(ctx, user.ID, signupSource) grantPlan := s.resolveSignupGrantPlan(ctx, signupSource) s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") + // snapshot user × platform quota(fail-open) + _ = s.snapshotPlatformQuotaDefaults(ctx, user.ID, &grantPlan) s.bindOAuthAffiliate(ctx, user.ID, affiliateCode) return nil } diff --git a/backend/internal/service/auth_oauth_email_flow_test.go b/backend/internal/service/auth_oauth_email_flow_test.go index 3c02587b..5a78a348 100644 --- a/backend/internal/service/auth_oauth_email_flow_test.go +++ b/backend/internal/service/auth_oauth_email_flow_test.go @@ -112,6 +112,7 @@ func newOAuthEmailFlowAuthService( refreshTokenCache RefreshTokenCache, settings map[string]string, emailCache EmailCache, + quotaRepo UserPlatformQuotaRepository, // 新增 ) *AuthService { cfg := &config.Config{ JWT: config.JWTConfig{ @@ -142,6 +143,7 @@ func newOAuthEmailFlowAuthService( nil, nil, nil, + quotaRepo, // 替换原来的 nil ) } @@ -175,6 +177,7 @@ func TestRegisterOAuthEmailAccountRollsBackCreatedUserWhenTokenPairGenerationFai SettingKeyEmailVerifyEnabled: "true", }, emailCache, + nil, ) tokenPair, user, err := authService.RegisterOAuthEmailAccount( @@ -215,6 +218,7 @@ func TestRegisterOAuthEmailAccountSetsNormalizedSignupSourceOnCreatedUser(t *tes SettingKeyEmailVerifyEnabled: "true", }, emailCache, + nil, ) tokenPair, user, err := authService.RegisterOAuthEmailAccount( @@ -274,6 +278,7 @@ func TestRegisterOAuthEmailAccountKeepsGitHubAndGoogleSignupSource(t *testing.T) SettingKeyEmailVerifyEnabled: "true", }, emailCache, + nil, ) tokenPair, user, err := authService.RegisterOAuthEmailAccount( @@ -313,6 +318,7 @@ func TestRegisterOAuthEmailAccountFallsBackUnknownSignupSourceToEmail(t *testing SettingKeyEmailVerifyEnabled: "true", }, emailCache, + nil, ) tokenPair, user, err := authService.RegisterOAuthEmailAccount( @@ -360,6 +366,7 @@ func TestRollbackOAuthEmailAccountCreationRestoresInvitationUsage(t *testing.T) SettingKeyInvitationCodeEnabled: "true", }, &emailCacheStub{}, + nil, ) err := authService.RollbackOAuthEmailAccountCreation(context.Background(), 42, "INVITE123") @@ -382,6 +389,7 @@ func TestRollbackOAuthEmailAccountCreationPropagatesDeleteError(t *testing.T) { SettingKeyRegistrationEnabled: "true", }, &emailCacheStub{}, + nil, ) err := authService.RollbackOAuthEmailAccountCreation(context.Background(), 42, "") @@ -389,3 +397,54 @@ func TestRollbackOAuthEmailAccountCreationPropagatesDeleteError(t *testing.T) { require.Error(t, err) require.Contains(t, err.Error(), "delete created oauth user") } + +func TestFinalizeOAuthEmailAccount_SnapshotsPlatformQuotaDefaults(t *testing.T) { + userRepo := &userRepoStub{nextID: 99} + quotaRepo := &userPlatformQuotaRepoStub{} + + authService := newOAuthEmailFlowAuthService( + userRepo, + nil, + &refreshTokenCacheStub{}, + map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyEmailVerifyEnabled: "true", + SettingKeyDefaultPlatformQuotas: `{"anthropic": {"daily": 5.5}}`, + }, + &emailCacheStub{}, + quotaRepo, + ) + + user := &User{ + ID: 99, + Email: "newuser@example.com", + Role: RoleUser, + Status: StatusActive, + SignupSource: "oidc", + } + + err := authService.FinalizeOAuthEmailAccount( + context.Background(), + user, + "", + "oidc", + "", + ) + + require.NoError(t, err) + + require.Len(t, quotaRepo.bulkInsertCalls, 1, "snapshotPlatformQuotaDefaults must call BulkInsertInitial once on successful OAuth signup") + + records := quotaRepo.bulkInsertCalls[0] + var anthropicRecord *UserPlatformQuotaRecord + for i := range records { + if records[i].Platform == "anthropic" { + anthropicRecord = &records[i] + break + } + } + require.NotNil(t, anthropicRecord, "expected anthropic platform record") + require.Equal(t, int64(99), anthropicRecord.UserID) + require.NotNil(t, anthropicRecord.DailyLimitUSD) + require.InDelta(t, 5.5, *anthropicRecord.DailyLimitUSD, 0.0001) +} diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index 4e5b7b94..e4fa876c 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -62,18 +62,19 @@ type JWTClaims struct { // AuthService 认证服务 type AuthService struct { - entClient *dbent.Client - userRepo UserRepository - redeemRepo RedeemCodeRepository - refreshTokenCache RefreshTokenCache - cfg *config.Config - settingService *SettingService - emailService *EmailService - turnstileService *TurnstileService - emailQueueService *EmailQueueService - promoService *PromoService - affiliateService *AffiliateService - defaultSubAssigner DefaultSubscriptionAssigner + entClient *dbent.Client + userRepo UserRepository + redeemRepo RedeemCodeRepository + refreshTokenCache RefreshTokenCache + cfg *config.Config + settingService *SettingService + emailService *EmailService + turnstileService *TurnstileService + emailQueueService *EmailQueueService + promoService *PromoService + affiliateService *AffiliateService + defaultSubAssigner DefaultSubscriptionAssigner + userPlatformQuotaRepo UserPlatformQuotaRepository } type DefaultSubscriptionAssigner interface { @@ -81,9 +82,10 @@ type DefaultSubscriptionAssigner interface { } type signupGrantPlan struct { - Balance float64 - Concurrency int - Subscriptions []DefaultSubscriptionSetting + Balance float64 + Concurrency int + Subscriptions []DefaultSubscriptionSetting + PlatformQuotas map[string]*DefaultPlatformQuotaSetting } // NewAuthService 创建认证服务实例 @@ -100,20 +102,22 @@ func NewAuthService( promoService *PromoService, defaultSubAssigner DefaultSubscriptionAssigner, affiliateService *AffiliateService, + userPlatformQuotaRepo UserPlatformQuotaRepository, ) *AuthService { return &AuthService{ - entClient: entClient, - userRepo: userRepo, - redeemRepo: redeemRepo, - refreshTokenCache: refreshTokenCache, - cfg: cfg, - settingService: settingService, - emailService: emailService, - turnstileService: turnstileService, - emailQueueService: emailQueueService, - promoService: promoService, - affiliateService: affiliateService, - defaultSubAssigner: defaultSubAssigner, + entClient: entClient, + userRepo: userRepo, + redeemRepo: redeemRepo, + refreshTokenCache: refreshTokenCache, + cfg: cfg, + settingService: settingService, + emailService: emailService, + turnstileService: turnstileService, + emailQueueService: emailQueueService, + promoService: promoService, + affiliateService: affiliateService, + defaultSubAssigner: defaultSubAssigner, + userPlatformQuotaRepo: userPlatformQuotaRepo, } } @@ -226,6 +230,8 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw } s.postAuthUserBootstrap(ctx, user, "email", true) s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") + // snapshot user × platform quota(fail-open) + _ = s.snapshotPlatformQuotaDefaults(ctx, user.ID, &grantPlan) if s.affiliateService != nil { if _, err := s.affiliateService.EnsureUserAffiliate(ctx, user.ID); err != nil { logger.LegacyPrintf("service.auth", "[Auth] Failed to initialize affiliate profile for user %d: %v", user.ID, err) @@ -535,6 +541,8 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username user = newUser s.postAuthUserBootstrap(ctx, user, signupSource, false) s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") + // snapshot user × platform quota(fail-open) + _ = s.snapshotPlatformQuotaDefaults(ctx, user.ID, &grantPlan) } } else { logger.LegacyPrintf("service.auth", "[Auth] Database error during oauth login: %v", err) @@ -685,6 +693,8 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema user = newUser s.postAuthUserBootstrap(ctx, user, signupSource, false) s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") + // snapshot user × platform quota(fail-open) + _ = s.snapshotPlatformQuotaDefaults(ctx, user.ID, &grantPlan) s.bindOAuthAffiliate(ctx, user.ID, affiliateCode) } } else { @@ -703,6 +713,8 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema user = newUser s.postAuthUserBootstrap(ctx, user, signupSource, false) s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") + // snapshot user × platform quota(fail-open) + _ = s.snapshotPlatformQuotaDefaults(ctx, user.ID, &grantPlan) s.bindOAuthAffiliate(ctx, user.ID, affiliateCode) if invitationRedeemCode != nil { if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil { @@ -764,18 +776,39 @@ func (s *AuthService) resolveSignupGrantPlan(ctx context.Context, signupSource s plan.Concurrency = s.settingService.GetDefaultConcurrency(ctx) plan.Subscriptions = s.settingService.GetDefaultSubscriptions(ctx) + // ============ 全局 quota 装载(必须在 ResolveAuthSourceGrantSettings 之前) ============ + // 无论 auth source 是否 enabled,全局层都要先装载,确保 !enabled 早退路径也携带全局 quota。 + if quotas, err := s.settingService.GetDefaultPlatformQuotas(ctx); err == nil { + plan.PlatformQuotas = quotas + } else { + logger.LegacyPrintf("service.auth", "[Auth] Warning: load default platform quotas failed: %v (fail-open)", err) + } + // ============================================================================================ + resolved, enabled, err := s.settingService.ResolveAuthSourceGrantSettings(ctx, signupSource, false) if err != nil { logger.LegacyPrintf("service.auth", "[Auth] Failed to load auth source signup defaults for %s: %v", signupSource, err) return plan } if !enabled { - return plan + return plan // plan.PlatformQuotas 已含全局层 } plan.Balance = resolved.Balance plan.Concurrency = resolved.Concurrency plan.Subscriptions = resolved.Subscriptions + + // ============ auth source quota merge(仅在 enabled 分支内) ============ + asQuotas := s.settingService.GetAuthSourcePlatformQuotas(ctx, signupSource) + if plan.PlatformQuotas != nil { + for platform, patch := range asQuotas { + if dst := plan.PlatformQuotas[platform]; dst != nil { + mergePlatformQuotaDefaults(dst, patch) + } + } + } + // ============================================================================== + return plan } @@ -1586,3 +1619,29 @@ func resolvedTokenVersion(user *User) int64 { fingerprint := int64(binary.BigEndian.Uint64(sum[:8]) & 0x7fffffffffffffff) return user.TokenVersion ^ fingerprint } + +// snapshotPlatformQuotaDefaults 把 plan.PlatformQuotas(4 platform × 3 window)以 +// BulkInsertInitial 形式写入 user_platform_quotas 表。失败 fail-open(仅 warn log)。 +func (s *AuthService) snapshotPlatformQuotaDefaults(ctx context.Context, userID int64, plan *signupGrantPlan) error { + if s.userPlatformQuotaRepo == nil || plan == nil || len(plan.PlatformQuotas) == 0 { + return nil + } + records := make([]UserPlatformQuotaRecord, 0, len(plan.PlatformQuotas)) + for platform, q := range plan.PlatformQuotas { + rec := UserPlatformQuotaRecord{ + UserID: userID, + Platform: platform, + } + if q != nil { + rec.DailyLimitUSD = q.DailyLimitUSD + rec.WeeklyLimitUSD = q.WeeklyLimitUSD + rec.MonthlyLimitUSD = q.MonthlyLimitUSD + } + records = append(records, rec) + } + if err := s.userPlatformQuotaRepo.BulkInsertInitial(ctx, records); err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Warning: snapshot platform quota failed user=%d: %v (fail-open)", userID, err) + return nil // fail-open:返回 nil,让调用方继续 + } + return nil +} diff --git a/backend/internal/service/auth_service_email_bind_test.go b/backend/internal/service/auth_service_email_bind_test.go index 8f03f857..87867395 100644 --- a/backend/internal/service/auth_service_email_bind_test.go +++ b/backend/internal/service/auth_service_email_bind_test.go @@ -110,7 +110,7 @@ CREATE TABLE IF NOT EXISTS user_provider_default_grants ( emailSvc = service.NewEmailService(settingRepo, emailCache) } - svc := service.NewAuthService(client, repo, nil, refreshTokenCache, cfg, settingSvc, emailSvc, nil, nil, nil, defaultSubAssigner, nil) + svc := service.NewAuthService(client, repo, nil, refreshTokenCache, cfg, settingSvc, emailSvc, nil, nil, nil, defaultSubAssigner, nil, nil) return svc, repo, client } @@ -467,7 +467,7 @@ func TestAuthServiceBindEmailIdentity_RevokesExistingAccessAndRefreshTokens(t *t }, } emailService := service.NewEmailService(nil, cache) - svc := service.NewAuthService(nil, userRepo, nil, refreshTokenCache, cfg, nil, emailService, nil, nil, nil, nil, nil) + svc := service.NewAuthService(nil, userRepo, nil, refreshTokenCache, cfg, nil, emailService, nil, nil, nil, nil, nil, nil) oldTokenPair, err := svc.GenerateTokenPair(ctx, &service.User{ ID: 41, @@ -810,8 +810,8 @@ func (s *emailBindUserRepoStub) UpdateUserLastActiveAt(context.Context, int64, t } func (s *emailBindUserRepoStub) UpdateBalance(context.Context, int64, float64) error { return nil } -func (s *emailBindUserRepoStub) DeductBalance(context.Context, int64, float64) error { return nil } -func (s *emailBindUserRepoStub) UpdateConcurrency(context.Context, int64, int) error { return nil } +func (s *emailBindUserRepoStub) DeductBalance(context.Context, int64, float64) error { return nil } +func (s *emailBindUserRepoStub) UpdateConcurrency(context.Context, int64, int) error { return nil } func (s *emailBindUserRepoStub) ExistsByEmail(_ context.Context, email string) (bool, error) { s.mu.Lock() @@ -820,8 +820,12 @@ func (s *emailBindUserRepoStub) ExistsByEmail(_ context.Context, email string) ( return ok, nil } -func (s *emailBindUserRepoStub) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil } -func (s *emailBindUserRepoStub) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil } +func (s *emailBindUserRepoStub) BatchSetConcurrency(context.Context, []int64, int) (int, error) { + return 0, nil +} +func (s *emailBindUserRepoStub) BatchAddConcurrency(context.Context, []int64, int) (int, error) { + return 0, nil +} func (s *emailBindUserRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) { return 0, nil diff --git a/backend/internal/service/auth_service_identity_sync_test.go b/backend/internal/service/auth_service_identity_sync_test.go index 53048b92..59ed8e52 100644 --- a/backend/internal/service/auth_service_identity_sync_test.go +++ b/backend/internal/service/auth_service_identity_sync_test.go @@ -137,7 +137,7 @@ CREATE TABLE IF NOT EXISTS user_provider_default_grants ( values: settings, }, cfg) - svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, nil, nil, nil, nil, defaultSubAssigner, nil) + svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, nil, nil, nil, nil, defaultSubAssigner, nil, nil) return svc, repo, client } diff --git a/backend/internal/service/auth_service_platform_quota_test.go b/backend/internal/service/auth_service_platform_quota_test.go new file mode 100644 index 00000000..f58dc48c --- /dev/null +++ b/backend/internal/service/auth_service_platform_quota_test.go @@ -0,0 +1,157 @@ +//go:build unit + +package service + +import ( + "context" + "fmt" + "testing" + "time" +) + +// fakeInsertRecorder 记录 BulkInsertInitial 调用,实现 UserPlatformQuotaRepository port。 +type fakeInsertRecorder struct { + records []UserPlatformQuotaRecord + err error +} + +func (f *fakeInsertRecorder) GetByUserPlatform(_ context.Context, _ int64, _ string) (*UserPlatformQuotaRecord, error) { + return nil, nil +} + +func (f *fakeInsertRecorder) BulkInsertInitial(_ context.Context, recs []UserPlatformQuotaRecord) error { + if f.err != nil { + return f.err + } + f.records = append(f.records, recs...) + return nil +} + +func (f *fakeInsertRecorder) IncrementUsageWithReset(_ context.Context, _ int64, _ string, _ float64, _ time.Time) error { + return nil +} + +func (f *fakeInsertRecorder) ListByUser(_ context.Context, _ int64) ([]UserPlatformQuotaRecord, error) { + return nil, nil +} + +func (f *fakeInsertRecorder) UpsertForUser(_ context.Context, _ int64, _ []UserPlatformQuotaRecord) error { + return nil +} + +func (f *fakeInsertRecorder) ResetExpiredWindow(_ context.Context, _ int64, _ string, _ string, _ time.Time) error { + return nil +} + +func TestSnapshotPlatformQuotaDefaults_PassesToRepoBulkInsert(t *testing.T) { + fakeRepo := &fakeInsertRecorder{} + s := &AuthService{userPlatformQuotaRepo: fakeRepo} + + five := 5.0 + plan := &signupGrantPlan{ + PlatformQuotas: map[string]*DefaultPlatformQuotaSetting{ + "anthropic": {DailyLimitUSD: &five}, + "openai": {}, + "gemini": {}, + "antigravity": {}, + }, + } + if err := s.snapshotPlatformQuotaDefaults(context.Background(), 999, plan); err != nil { + t.Fatal(err) + } + if len(fakeRepo.records) != 4 { + t.Errorf("expected 4 records, got %d", len(fakeRepo.records)) + } + found := false + for _, r := range fakeRepo.records { + if r.UserID == 999 && r.Platform == "anthropic" && r.DailyLimitUSD != nil && *r.DailyLimitUSD == 5 { + found = true + } + } + if !found { + t.Error("anthropic daily = 5 not snapshotted") + } +} + +func TestSnapshotPlatformQuotaDefaults_NilPlanIsNoop(t *testing.T) { + fakeRepo := &fakeInsertRecorder{} + s := &AuthService{userPlatformQuotaRepo: fakeRepo} + if err := s.snapshotPlatformQuotaDefaults(context.Background(), 1, nil); err != nil { + t.Errorf("nil plan should be noop, got %v", err) + } + if len(fakeRepo.records) != 0 { + t.Errorf("expected no records, got %d", len(fakeRepo.records)) + } +} + +func TestSnapshotPlatformQuotaDefaults_RepoErrorFailsOpen(t *testing.T) { + fakeRepo := &fakeInsertRecorder{err: fmt.Errorf("db down")} + s := &AuthService{userPlatformQuotaRepo: fakeRepo} + five := 5.0 + plan := &signupGrantPlan{ + PlatformQuotas: map[string]*DefaultPlatformQuotaSetting{ + "anthropic": {DailyLimitUSD: &five}, + }, + } + if err := s.snapshotPlatformQuotaDefaults(context.Background(), 1, plan); err != nil { + t.Errorf("fail-open: expected nil even on repo error, got %v", err) + } +} + +func TestSnapshotPlatformQuotaDefaults_NilRepoIsNoop(t *testing.T) { + s := &AuthService{userPlatformQuotaRepo: nil} + five := 5.0 + plan := &signupGrantPlan{ + PlatformQuotas: map[string]*DefaultPlatformQuotaSetting{"a": {DailyLimitUSD: &five}}, + } + if err := s.snapshotPlatformQuotaDefaults(context.Background(), 1, plan); err != nil { + t.Errorf("nil repo should be noop, got %v", err) + } +} + +// resolveSignupGrantPlan 测试:依赖完整的 AuthService 构造,需要 SettingService(含 settingRepoStub)。 +// settingRepoStub 已在 auth_service_register_test.go 中定义,同 package 可直接使用。 +func TestResolveSignupGrantPlan_GlobalQuotaLoadedBeforeAuthSource(t *testing.T) { + // 全局 quota JSON key(新格式) + settings := map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyDefaultPlatformQuotas: `{ + "anthropic": {"daily": 10, "weekly": 50, "monthly": 200}, + "openai": {"daily": 5, "weekly": 25, "monthly": 100}, + "gemini": {"daily": 5, "weekly": 25, "monthly": 100}, + "antigravity": {"daily": 5, "weekly": 25, "monthly": 100} + }`, + } + svc := newAuthService(nil, settings, nil, nil) + plan := svc.resolveSignupGrantPlan(context.Background(), "email") + if plan.PlatformQuotas == nil { + t.Fatal("expected PlatformQuotas to be non-nil after loading global quota KVs") + } + q := plan.PlatformQuotas["anthropic"] + if q == nil { + t.Fatal("expected anthropic quota to be set") + } + if q.DailyLimitUSD == nil || *q.DailyLimitUSD != 10 { + t.Errorf("expected anthropic daily=10, got %v", q.DailyLimitUSD) + } +} + +// TestResolveSignupGrantPlan_DisabledAuthSourceStillCarriesGlobalQuota 验证 P1 约束: +// !enabled 早退路径仍携带全局 quota(GetDefaultPlatformQuotas 在 ResolveAuthSourceGrantSettings 之前)。 +func TestResolveSignupGrantPlan_DisabledAuthSourceStillCarriesGlobalQuota(t *testing.T) { + settings := map[string]string{ + SettingKeyRegistrationEnabled: "true", + // auth source 不配置(=> !enabled 路径) + SettingKeyDefaultPlatformQuotas: `{"anthropic": {"daily": 10, "weekly": 50, "monthly": 200}}`, + } + svc := newAuthService(nil, settings, nil, nil) + plan := svc.resolveSignupGrantPlan(context.Background(), "email") + // !enabled 路径:plan.PlatformQuotas 应已含全局层(不是 nil) + if plan.PlatformQuotas == nil { + t.Fatal("P1 violated: PlatformQuotas is nil even with global quota KVs set") + } + // P1 核心断言:disabled auth source 路径不能丢失全局 quota + if _, ok := plan.PlatformQuotas["anthropic"]; !ok { + t.Error("P1 violated: disabled auth source path dropped global platform quota") + } +} diff --git a/backend/internal/service/auth_service_register_test.go b/backend/internal/service/auth_service_register_test.go index ece02474..a7c0d260 100644 --- a/backend/internal/service/auth_service_register_test.go +++ b/backend/internal/service/auth_service_register_test.go @@ -73,6 +73,38 @@ type defaultSubscriptionAssignerStub struct { type refreshTokenCacheStub struct{} +type userPlatformQuotaRepoStub struct { + bulkInsertCalls [][]UserPlatformQuotaRecord + bulkInsertErr error +} + +func (s *userPlatformQuotaRepoStub) BulkInsertInitial(_ context.Context, records []UserPlatformQuotaRecord) error { + cloned := make([]UserPlatformQuotaRecord, len(records)) + copy(cloned, records) + s.bulkInsertCalls = append(s.bulkInsertCalls, cloned) + return s.bulkInsertErr +} + +func (s *userPlatformQuotaRepoStub) GetByUserPlatform(context.Context, int64, string) (*UserPlatformQuotaRecord, error) { + panic("unexpected GetByUserPlatform call") +} + +func (s *userPlatformQuotaRepoStub) ListByUser(context.Context, int64) ([]UserPlatformQuotaRecord, error) { + panic("unexpected ListByUser call") +} + +func (s *userPlatformQuotaRepoStub) IncrementUsageWithReset(context.Context, int64, string, float64, time.Time) error { + panic("unexpected IncrementUsageWithReset call") +} + +func (s *userPlatformQuotaRepoStub) UpsertForUser(context.Context, int64, []UserPlatformQuotaRecord) error { + panic("unexpected UpsertForUser call") +} + +func (s *userPlatformQuotaRepoStub) ResetExpiredWindow(context.Context, int64, string, string, time.Time) error { + panic("unexpected ResetExpiredWindow call") +} + func (s *defaultSubscriptionAssignerStub) AssignOrExtendSubscription(_ context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error) { if input != nil { s.calls = append(s.calls, *input) @@ -178,7 +210,7 @@ func (s *emailCacheStub) IncrNotifyCodeUserRate(ctx context.Context, userID int6 return 0, nil } -func newAuthService(repo *userRepoStub, settings map[string]string, emailCache EmailCache) *AuthService { +func newAuthService(repo *userRepoStub, settings map[string]string, emailCache EmailCache, quotaRepo UserPlatformQuotaRepository) *AuthService { cfg := &config.Config{ JWT: config.JWTConfig{ Secret: "test-secret", @@ -213,6 +245,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E nil, // promoService nil, // defaultSubAssigner nil, // affiliateService + quotaRepo, ) } @@ -220,7 +253,7 @@ func TestAuthService_Register_Disabled(t *testing.T) { repo := &userRepoStub{} service := newAuthService(repo, map[string]string{ SettingKeyRegistrationEnabled: "false", - }, nil) + }, nil, nil) _, _, err := service.Register(context.Background(), "user@test.com", "password") require.ErrorIs(t, err, ErrRegDisabled) @@ -229,19 +262,62 @@ func TestAuthService_Register_Disabled(t *testing.T) { func TestAuthService_Register_DisabledByDefault(t *testing.T) { // 当 settings 为 nil(设置项不存在)时,注册应该默认关闭 repo := &userRepoStub{} - service := newAuthService(repo, nil, nil) + service := newAuthService(repo, nil, nil, nil) _, _, err := service.Register(context.Background(), "user@test.com", "password") require.ErrorIs(t, err, ErrRegDisabled) } +func TestAuthService_Register_SnapshotsPlatformQuotaDefaults(t *testing.T) { + repo := &userRepoStub{nextID: 77} + quotaRepo := &userPlatformQuotaRepoStub{} + + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyDefaultPlatformQuotas: `{"openai": {"weekly": 12.34}}`, + }, nil, quotaRepo) + + _, user, err := service.Register(context.Background(), "newuser@test.com", "password") + require.NoError(t, err) + require.NotNil(t, user) + + require.Len(t, quotaRepo.bulkInsertCalls, 1) + + records := quotaRepo.bulkInsertCalls[0] + var openaiRecord *UserPlatformQuotaRecord + for i := range records { + if records[i].Platform == "openai" { + openaiRecord = &records[i] + break + } + } + require.NotNil(t, openaiRecord, "expected openai platform record") + require.Equal(t, int64(77), openaiRecord.UserID) + require.NotNil(t, openaiRecord.WeeklyLimitUSD) + require.InDelta(t, 12.34, *openaiRecord.WeeklyLimitUSD, 0.0001) +} + +func TestAuthService_Register_DoesNotSnapshotOnDisabled(t *testing.T) { + repo := &userRepoStub{} + quotaRepo := &userPlatformQuotaRepoStub{} + + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "false", + }, nil, quotaRepo) + + _, _, err := service.Register(context.Background(), "user@test.com", "password") + require.ErrorIs(t, err, ErrRegDisabled) + + require.Empty(t, quotaRepo.bulkInsertCalls, "registration rejected before user creation must not snapshot") +} + func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testing.T) { repo := &userRepoStub{} // 邮件验证开启但 emailCache 为 nil(emailService 未配置) service := newAuthService(repo, map[string]string{ SettingKeyRegistrationEnabled: "true", SettingKeyEmailVerifyEnabled: "true", - }, nil) + }, nil, nil) // 应返回服务不可用错误,而不是允许绕过验证 _, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code", "", "", "") @@ -254,7 +330,7 @@ func TestAuthService_Register_EmailVerifyRequired(t *testing.T) { service := newAuthService(repo, map[string]string{ SettingKeyRegistrationEnabled: "true", SettingKeyEmailVerifyEnabled: "true", - }, cache) + }, cache, nil) _, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "", "", "", "") require.ErrorIs(t, err, ErrEmailVerifyRequired) @@ -268,7 +344,7 @@ func TestAuthService_Register_EmailVerifyInvalid(t *testing.T) { service := newAuthService(repo, map[string]string{ SettingKeyRegistrationEnabled: "true", SettingKeyEmailVerifyEnabled: "true", - }, cache) + }, cache, nil) _, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong", "", "", "") require.ErrorIs(t, err, ErrInvalidVerifyCode) @@ -279,7 +355,7 @@ func TestAuthService_Register_EmailExists(t *testing.T) { repo := &userRepoStub{exists: true} service := newAuthService(repo, map[string]string{ SettingKeyRegistrationEnabled: "true", - }, nil) + }, nil, nil) _, _, err := service.Register(context.Background(), "user@test.com", "password") require.ErrorIs(t, err, ErrEmailExists) @@ -289,7 +365,7 @@ func TestAuthService_Register_CheckEmailError(t *testing.T) { repo := &userRepoStub{existsErr: errors.New("db down")} service := newAuthService(repo, map[string]string{ SettingKeyRegistrationEnabled: "true", - }, nil) + }, nil, nil) _, _, err := service.Register(context.Background(), "user@test.com", "password") require.ErrorIs(t, err, ErrServiceUnavailable) @@ -299,7 +375,7 @@ func TestAuthService_Register_ReservedEmail(t *testing.T) { repo := &userRepoStub{} service := newAuthService(repo, map[string]string{ SettingKeyRegistrationEnabled: "true", - }, nil) + }, nil, nil) _, _, err := service.Register(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "password") require.ErrorIs(t, err, ErrEmailReserved) @@ -310,7 +386,7 @@ func TestAuthService_Register_EmailSuffixNotAllowed(t *testing.T) { service := newAuthService(repo, map[string]string{ SettingKeyRegistrationEnabled: "true", SettingKeyRegistrationEmailSuffixWhitelist: `["@example.com","@company.com"]`, - }, nil) + }, nil, nil) _, _, err := service.Register(context.Background(), "user@other.com", "password") require.ErrorIs(t, err, ErrEmailSuffixNotAllowed) @@ -327,7 +403,7 @@ func TestAuthService_Register_EmailSuffixAllowed(t *testing.T) { service := newAuthService(repo, map[string]string{ SettingKeyRegistrationEnabled: "true", SettingKeyRegistrationEmailSuffixWhitelist: `["example.com"]`, - }, nil) + }, nil, nil) _, user, err := service.Register(context.Background(), "user@example.com", "password") require.NoError(t, err) @@ -340,7 +416,7 @@ func TestAuthService_SendVerifyCode_EmailSuffixNotAllowed(t *testing.T) { service := newAuthService(repo, map[string]string{ SettingKeyRegistrationEnabled: "true", SettingKeyRegistrationEmailSuffixWhitelist: `["@example.com","@company.com"]`, - }, nil) + }, nil, nil) err := service.SendVerifyCode(context.Background(), "user@other.com") require.ErrorIs(t, err, ErrEmailSuffixNotAllowed) @@ -354,7 +430,7 @@ func TestAuthService_Register_CreateError(t *testing.T) { repo := &userRepoStub{createErr: errors.New("create failed")} service := newAuthService(repo, map[string]string{ SettingKeyRegistrationEnabled: "true", - }, nil) + }, nil, nil) _, _, err := service.Register(context.Background(), "user@test.com", "password") require.ErrorIs(t, err, ErrServiceUnavailable) @@ -365,7 +441,7 @@ func TestAuthService_Register_CreateEmailExistsRace(t *testing.T) { repo := &userRepoStub{createErr: ErrEmailExists} service := newAuthService(repo, map[string]string{ SettingKeyRegistrationEnabled: "true", - }, nil) + }, nil, nil) _, _, err := service.Register(context.Background(), "user@test.com", "password") require.ErrorIs(t, err, ErrEmailExists) @@ -376,7 +452,7 @@ func TestAuthService_Register_Success(t *testing.T) { service := newAuthService(repo, map[string]string{ SettingKeyRegistrationEnabled: "true", SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false", - }, nil) + }, nil, nil) token, user, err := service.Register(context.Background(), "user@test.com", "password") require.NoError(t, err) @@ -394,7 +470,7 @@ func TestAuthService_Register_Success(t *testing.T) { func TestAuthService_ValidateToken_ExpiredReturnsClaimsWithError(t *testing.T) { repo := &userRepoStub{} - service := newAuthService(repo, nil, nil) + service := newAuthService(repo, nil, nil, nil) // 创建用户并生成 token user := &User{ @@ -436,7 +512,7 @@ func TestAuthService_RefreshToken_ExpiredTokenNoPanic(t *testing.T) { TokenVersion: 1, } repo := &userRepoStub{user: user} - service := newAuthService(repo, nil, nil) + service := newAuthService(repo, nil, nil, nil) // 创建过期 token service.cfg.JWT.ExpireHour = -1 @@ -453,7 +529,7 @@ func TestAuthService_RefreshToken_ExpiredTokenNoPanic(t *testing.T) { } func TestAuthService_GetAccessTokenExpiresIn_FallbackToExpireHour(t *testing.T) { - service := newAuthService(&userRepoStub{}, nil, nil) + service := newAuthService(&userRepoStub{}, nil, nil, nil) service.cfg.JWT.ExpireHour = 24 service.cfg.JWT.AccessTokenExpireMinutes = 0 @@ -461,7 +537,7 @@ func TestAuthService_GetAccessTokenExpiresIn_FallbackToExpireHour(t *testing.T) } func TestAuthService_GetAccessTokenExpiresIn_MinutesHasPriority(t *testing.T) { - service := newAuthService(&userRepoStub{}, nil, nil) + service := newAuthService(&userRepoStub{}, nil, nil, nil) service.cfg.JWT.ExpireHour = 24 service.cfg.JWT.AccessTokenExpireMinutes = 90 @@ -469,7 +545,7 @@ func TestAuthService_GetAccessTokenExpiresIn_MinutesHasPriority(t *testing.T) { } func TestAuthService_GenerateToken_UsesExpireHourWhenMinutesZero(t *testing.T) { - service := newAuthService(&userRepoStub{}, nil, nil) + service := newAuthService(&userRepoStub{}, nil, nil, nil) service.cfg.JWT.ExpireHour = 24 service.cfg.JWT.AccessTokenExpireMinutes = 0 @@ -494,7 +570,7 @@ func TestAuthService_GenerateToken_UsesExpireHourWhenMinutesZero(t *testing.T) { } func TestAuthService_GenerateToken_UsesMinutesWhenConfigured(t *testing.T) { - service := newAuthService(&userRepoStub{}, nil, nil) + service := newAuthService(&userRepoStub{}, nil, nil, nil) service.cfg.JWT.ExpireHour = 24 service.cfg.JWT.AccessTokenExpireMinutes = 90 @@ -525,7 +601,7 @@ func TestAuthService_Register_AssignsDefaultSubscriptions(t *testing.T) { SettingKeyRegistrationEnabled: "true", SettingKeyDefaultSubscriptions: `[{"group_id":11,"validity_days":30},{"group_id":12,"validity_days":7}]`, SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false", - }, nil) + }, nil, nil) service.defaultSubAssigner = assigner _, user, err := service.Register(context.Background(), "default-sub@test.com", "password") @@ -549,7 +625,7 @@ func TestAuthService_Register_UsesEmailAuthSourceDefaultsWhenGrantEnabled(t *tes SettingKeyAuthSourceDefaultEmailConcurrency: "7", SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`, SettingKeyAuthSourceDefaultEmailGrantOnSignup: "true", - }, nil) + }, nil, nil) service.defaultSubAssigner = assigner _, user, err := service.Register(context.Background(), "email-defaults@test.com", "password") @@ -572,7 +648,7 @@ func TestAuthService_Register_GrantOnSignupFalseFallsBackToGlobalDefaults(t *tes SettingKeyAuthSourceDefaultEmailConcurrency: "88", SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":32,"validity_days":9}]`, SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false", - }, nil) + }, nil, nil) service.defaultSubAssigner = assigner _, user, err := service.Register(context.Background(), "email-global@test.com", "password") @@ -595,7 +671,7 @@ func TestAuthService_Register_GrantOnSignupMergesSourceOverridesWithGlobalDefaul SettingKeyAuthSourceDefaultEmailConcurrency: "5", SettingKeyAuthSourceDefaultEmailSubscriptions: `[]`, SettingKeyAuthSourceDefaultEmailGrantOnSignup: "true", - }, nil) + }, nil, nil) service.defaultSubAssigner = assigner _, user, err := service.Register(context.Background(), "email-merged@test.com", "password") @@ -618,7 +694,7 @@ func TestAuthService_LoginOrRegisterOAuthWithTokenPair_UsesLinuxDoAuthSourceDefa SettingKeyAuthSourceDefaultLinuxDoConcurrency: "9", SettingKeyAuthSourceDefaultLinuxDoSubscriptions: `[{"group_id":22,"validity_days":14}]`, SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup: "true", - }, nil) + }, nil, nil) service.defaultSubAssigner = assigner service.refreshTokenCache = &refreshTokenCacheStub{} @@ -654,7 +730,7 @@ func TestAuthService_LoginOrRegisterOAuthWithTokenPair_ExistingUserDoesNotGrantA SettingKeyAuthSourceDefaultLinuxDoConcurrency: "9", SettingKeyAuthSourceDefaultLinuxDoSubscriptions: `[{"group_id":22,"validity_days":14}]`, SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup: "true", - }, nil) + }, nil, nil) service.defaultSubAssigner = assigner service.refreshTokenCache = &refreshTokenCacheStub{} @@ -677,7 +753,7 @@ func newAuthServiceWithDingTalkCfg(settings map[string]string, dtCfg config.Ding DingTalk: dtCfg, } settingService := NewSettingService(&settingRepoStub{values: settings}, cfg) - return NewAuthService(nil, nil, nil, nil, cfg, settingService, nil, nil, nil, nil, nil, nil) + return NewAuthService(nil, nil, nil, nil, cfg, settingService, nil, nil, nil, nil, nil, nil, nil) } // minDingTalkURLs 返回一个包含必填字段的基础 DingTalkConnectConfig(不设 Enabled/BypassRegistration/Policy)。 diff --git a/backend/internal/service/auth_service_turnstile_register_test.go b/backend/internal/service/auth_service_turnstile_register_test.go index 3512822f..ce99c44d 100644 --- a/backend/internal/service/auth_service_turnstile_register_test.go +++ b/backend/internal/service/auth_service_turnstile_register_test.go @@ -55,6 +55,7 @@ func newAuthServiceForRegisterTurnstileTest(settings map[string]string, verifier nil, // promoService nil, // defaultSubAssigner nil, // affiliateService + nil, // userPlatformQuotaRepo ) } diff --git a/backend/internal/service/billing_cache_service.go b/backend/internal/service/billing_cache_service.go index 050db55b..2b7c06ba 100644 --- a/backend/internal/service/billing_cache_service.go +++ b/backend/internal/service/billing_cache_service.go @@ -11,18 +11,31 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "golang.org/x/sync/singleflight" ) // 错误定义 // 注:ErrInsufficientBalance在redeem_service.go中定义 // 注:ErrDailyLimitExceeded/ErrWeeklyLimitExceeded/ErrMonthlyLimitExceeded在subscription_service.go中定义 +// errBillingCacheUnavailable 内部哨兵:用于 quota 校验路径在 cache==nil 时 +// 与"Redis 故障"走同一条 fail-open + DB 一次性检查的分支。 +var errBillingCacheUnavailable = fmt.Errorf("billing cache unavailable") + var ( ErrSubscriptionInvalid = infraerrors.Forbidden("SUBSCRIPTION_INVALID", "subscription is invalid or expired") ErrBillingServiceUnavailable = infraerrors.ServiceUnavailable("BILLING_SERVICE_ERROR", "Billing service temporarily unavailable. Please retry later.") // RPM 超限错误。gateway_handler 负责映射为 HTTP 429。 ErrGroupRPMExceeded = infraerrors.TooManyRequests("GROUP_RPM_EXCEEDED", "group requests-per-minute limit exceeded") ErrUserRPMExceeded = infraerrors.TooManyRequests("USER_RPM_EXCEEDED", "user requests-per-minute limit exceeded") + + // user × platform quota(HTTP 429 Too Many Requests + Retry-After header)。 + // 选用 429 而非 403:限额耗尽属于"暂时性资源用尽,重试可恢复"的场景(RFC 6585), + // 大量 SDK(如 OpenAI 兼容客户端)只对 429 触发自动退避并读取 Retry-After, + // 用 403 会被视为"权限不足,重试无意义"导致客户端直接报错且不退避。 + ErrUserPlatformDailyQuotaExhausted = infraerrors.TooManyRequests("USER_PLATFORM_DAILY_QUOTA_EXHAUSTED", "Daily usage quota exhausted for this platform.") + ErrUserPlatformWeeklyQuotaExhausted = infraerrors.TooManyRequests("USER_PLATFORM_WEEKLY_QUOTA_EXHAUSTED", "Weekly usage quota exhausted for this platform.") + ErrUserPlatformMonthlyQuotaExhausted = infraerrors.TooManyRequests("USER_PLATFORM_MONTHLY_QUOTA_EXHAUSTED", "Monthly usage quota exhausted for this platform.") ) // subscriptionCacheData 订阅缓存数据结构(内部使用) @@ -94,6 +107,7 @@ type BillingCacheService struct { userGroupRateRepo UserGroupRateRepository cfg *config.Config circuitBreaker *billingCircuitBreaker + userPlatformQuotaRepo UserPlatformQuotaRepository cacheWriteChan chan cacheWriteTask cacheWriteWg sync.WaitGroup @@ -101,6 +115,7 @@ type BillingCacheService struct { cacheWriteMu sync.RWMutex stopped atomic.Bool balanceLoadSF singleflight.Group + quotaLoadSF singleflight.Group // 丢弃日志节流计数器(减少高负载下日志噪音) cacheWriteDropFullCount uint64 cacheWriteDropFullLastLog int64 @@ -117,6 +132,7 @@ func NewBillingCacheService( userRPMCache UserRPMCache, userGroupRateRepo UserGroupRateRepository, cfg *config.Config, + userPlatformQuotaRepo UserPlatformQuotaRepository, ) *BillingCacheService { svc := &BillingCacheService{ cache: cache, @@ -126,6 +142,7 @@ func NewBillingCacheService( userRPMCache: userRPMCache, userGroupRateRepo: userGroupRateRepo, cfg: cfg, + userPlatformQuotaRepo: userPlatformQuotaRepo, } svc.circuitBreaker = newBillingCircuitBreaker(cfg.Billing.CircuitBreaker) svc.startCacheWriteWorkers() @@ -655,6 +672,30 @@ func (s *BillingCacheService) QueueUpdateAPIKeyRateLimitUsage(apiKeyID int64, co }) } +// IncrementUserPlatformQuotaUsage 同步累加 user × platform usage 到 Redis 缓存。 +// +// 设计:同步写入而非异步入队。同步写确保下次 preflight 立即看到最新 usage, +// 把 TOCTOU 超支窗口限制在并发 in-flight 请求数量内(而非随时间无限累积)。 +// 写延迟通常 < 1ms(本地 Redis),换取 quota 视图实时性的取舍合理。 +// +// Redis 写失败用 ALERT 级 log;DB 持久化由 caller 单独 goroutine 兜底(gateway_service.go)。 +func (s *BillingCacheService) IncrementUserPlatformQuotaUsage(userID int64, platform string, cost float64) { + if s.cache == nil { + return + } + if platform == "" || cost <= 0 { + return + } + ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout) + defer cancel() + ttl := time.Duration(s.cfg.Billing.UserPlatformQuotaCacheTTLSeconds) * time.Second + if err := s.cache.IncrUserPlatformQuotaUsageCache(ctx, userID, platform, cost, ttl); err != nil { + logger.LegacyPrintf("service.billing_cache", + "ALERT: incr user platform quota cache failed user=%d platform=%s cost=%f: %v", + userID, platform, cost, err) + } +} + // ============================================ // 统一检查方法 // ============================================ @@ -662,7 +703,8 @@ func (s *BillingCacheService) QueueUpdateAPIKeyRateLimitUsage(apiKeyID int64, co // CheckBillingEligibility 检查用户是否有资格发起请求 // 余额模式:检查缓存余额 > 0 // 订阅模式:检查缓存用量未超过限额(Group限额从参数传入) -func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *User, apiKey *APIKey, group *Group, subscription *UserSubscription) error { +// platform 为请求的目标平台(如 "anthropic"),传空串 "" 时跳过 user × platform quota 检查。 +func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *User, apiKey *APIKey, group *Group, subscription *UserSubscription, platform string) error { // 简易模式:跳过所有计费检查 if s.cfg.RunMode == config.RunModeSimple { return nil @@ -684,6 +726,13 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user } } + // user × platform quota 仅在 standard(余额)模式生效;订阅模式豁免 + if !isSubscriptionMode { + if err := s.checkUserPlatformQuotaEligibility(ctx, user.ID, platform); err != nil { + return err + } + } + // Check API Key rate limits (applies to both billing modes) if apiKey != nil && apiKey.HasRateLimits() { if err := s.checkAPIKeyRateLimits(ctx, apiKey); err != nil { @@ -975,3 +1024,257 @@ func circuitStateString(state billingCircuitBreakerState) string { return "unknown" } } + +// checkUserPlatformQuotaEligibility 在 standard 模式下检查 user × platform 日/周/月 quota。 +// 返回 nil = 允许;返回 ErrUserPlatform{Daily/Weekly/Monthly}QuotaExhausted = 拒绝(带 window_resets_at metadata)。 +// checkUserPlatformQuotaEligibility 检查用户在指定平台的 USD 配额。 +// +// 流程(Redis-first / DB-fallback): +// 1. 先读 Redis cache;若命中且 SchemaVersion==1,直接用 entry 中的 limits 和 window_start 做校验, +// 免除 DB 查询。 +// 2. cache MISS 或旧版 entry(SchemaVersion==0)→ 查 DB 回填完整 entry(含 limits/window_start)。 +// 3. Redis 故障(err != nil)→ fail-open,查 DB 做一次性检查,不回填。 +func (s *BillingCacheService) checkUserPlatformQuotaEligibility( + ctx context.Context, + userID int64, + platform string, +) error { + if platform == "" || s.userPlatformQuotaRepo == nil { + return nil + } + + // cache 未配置(如简化部署 / 单测路径)→ 直接走 DB 查询,避免 nil panic。 + // 其他 check* 方法(balance/subscription/rate-limit)也有类似守卫。 + var ( + entry *UserPlatformQuotaCacheEntry + ok bool + cacheErr error + ) + if s.cache != nil { + entry, ok, cacheErr = s.cache.GetUserPlatformQuotaCache(ctx, userID, platform) + } else { + // 标记为"cache 故障"分支:跳过 HIT 路径、不回填、走 DB 一次性检查 + cacheErr = errBillingCacheUnavailable + } + + // --- cache HIT with current schema → 直接用 entry,不查 DB --- + if cacheErr == nil && ok && entry != nil && entry.SchemaVersion == UserPlatformQuotaCacheSchemaV1 { + now := time.Now() + dailyUsage := entry.DailyUsageUSD + weeklyUsage := entry.WeeklyUsageUSD + monthlyUsage := entry.MonthlyUsageUSD + // 若窗口已更新(DB 已重置但 cache 尚未失效),将对应 usage 清零再做比较, + // 同时记录新窗口起点用于后续刷新 cache entry。 + // 本次请求用本地清零值继续判断;DB 层 IncrementUsageWithReset 已有窗口自愈能力, + // 持久化数据始终正确。 + windowExpired := false + newDailyStart := entry.DailyWindowStart + newWeeklyStart := entry.WeeklyWindowStart + newMonthlyStart := entry.MonthlyWindowStart + if quotaWindowExpired(entry.DailyWindowStart, timezone.StartOfDay(now)) { + dailyUsage = 0 + windowExpired = true + dayStart := timezone.StartOfDay(now) + newDailyStart = &dayStart + } + if quotaWindowExpired(entry.WeeklyWindowStart, timezone.StartOfWeek(now)) { + weeklyUsage = 0 + windowExpired = true + weekStart := timezone.StartOfWeek(now) + newWeeklyStart = &weekStart + } + if monthlyQuotaWindowExpired(entry.MonthlyWindowStart, now) { + monthlyUsage = 0 + windowExpired = true + monthStart := now + newMonthlyStart = &monthStart + } + // 检测到任意窗口过期:用 reset 后的 entry 覆盖 Redis(而非 Delete)。 + // 旧实现 Delete 后,期间到达的 IncrUserPlatformQuotaUsage 调用让 Lua 看到 + // EXISTS=0 直接 return 0,并发请求的 cost 永久丢失,直到下次 cache MISS 回填。 + // 改为 SetCache 原子覆盖:key 不断链,Lua INCR 可在新窗口 entry 上正确累加。 + // 超时 50ms:覆盖正常路径与可接受抖动;Redis 异常时 hot path 不阻塞超过此值。 + // 用 context.Background()+短超时,避免请求 ctx 取消导致刷新丢失。 + // 显式 setCancel()(而非 defer):缩短 context 生命周期,避免 defer 延迟到函数返回。 + if windowExpired && s.cache != nil { + refreshed := &UserPlatformQuotaCacheEntry{ + DailyUsageUSD: dailyUsage, + WeeklyUsageUSD: weeklyUsage, + MonthlyUsageUSD: monthlyUsage, + SchemaVersion: UserPlatformQuotaCacheSchemaV1, + DailyLimitUSD: entry.DailyLimitUSD, + WeeklyLimitUSD: entry.WeeklyLimitUSD, + MonthlyLimitUSD: entry.MonthlyLimitUSD, + DailyWindowStart: newDailyStart, + WeeklyWindowStart: newWeeklyStart, + MonthlyWindowStart: newMonthlyStart, + } + ttl := time.Duration(s.cfg.Billing.UserPlatformQuotaCacheTTLSeconds) * time.Second + setCtx, setCancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + if setErr := s.cache.SetUserPlatformQuotaCache(setCtx, userID, platform, refreshed, ttl); setErr != nil { + logger.LegacyPrintf("service.billing_cache", + "Warning: refresh expired user platform quota cache failed user=%d platform=%s: %v", + userID, platform, setErr) + } + setCancel() + } + if entry.DailyLimitUSD != nil && dailyUsage >= *entry.DailyLimitUSD { + return withWindowResetsMetadata(ErrUserPlatformDailyQuotaExhausted, nextDailyReset(now)) + } + if entry.WeeklyLimitUSD != nil && weeklyUsage >= *entry.WeeklyLimitUSD { + return withWindowResetsMetadata(ErrUserPlatformWeeklyQuotaExhausted, nextWeeklyReset(now)) + } + if entry.MonthlyLimitUSD != nil && monthlyUsage >= *entry.MonthlyLimitUSD { + return withWindowResetsMetadata(ErrUserPlatformMonthlyQuotaExhausted, nextMonthlyResetFrom(entry.MonthlyWindowStart, now)) + } + return nil + } + + // --- cache MISS、旧版 entry 或 Redis 故障 → 查 DB(singleflight 合并并发回源)--- + // 使用 DoChan 而非 Do:avoid sharing the first caller's ctx among all dedupe followers. + // 若第一个 caller 的 ctx 被取消(客户端断连),后续 caller 不受影响,仍由各自 ctx 控制超时。 + sfKey := strconv.FormatInt(userID, 10) + ":" + platform + ch := s.quotaLoadSF.DoChan(sfKey, func() (any, error) { + // 子查询用 detached context + 短超时,独立于任何 caller 的请求 ctx, + // 防止"第一个 caller ctx 取消"使所有 follower 一起 fail。 + bgCtx, bgCancel := context.WithTimeout(context.Background(), 3*time.Second) + defer bgCancel() + return s.userPlatformQuotaRepo.GetByUserPlatform(bgCtx, userID, platform) + }) + var ( + v any + dbErr error + ) + select { + case res := <-ch: + v, dbErr = res.Val, res.Err + case <-ctx.Done(): + // 当前 caller 的 ctx 被取消:fail-open,不阻断 (此请求已无意义)。 + logger.LegacyPrintf("service.billing_cache", "Warning: user platform quota check ctx cancelled user=%d platform=%s: %v (fail-open)", userID, platform, ctx.Err()) + return nil + } + if dbErr != nil { + logger.LegacyPrintf("service.billing_cache", "Warning: load user platform quota failed user=%d platform=%s: %v (fail-open)", userID, platform, dbErr) + return nil + } + rec, _ := v.(*UserPlatformQuotaRecord) + if rec == nil { + return nil + } + + now := time.Now() + dailyUsage := rec.DailyUsageUSD + weeklyUsage := rec.WeeklyUsageUSD + monthlyUsage := rec.MonthlyUsageUSD + if quotaWindowExpired(rec.DailyWindowStart, timezone.StartOfDay(now)) { + dailyUsage = 0 + } + if quotaWindowExpired(rec.WeeklyWindowStart, timezone.StartOfWeek(now)) { + weeklyUsage = 0 + } + if monthlyQuotaWindowExpired(rec.MonthlyWindowStart, now) { + monthlyUsage = 0 + } + + // Redis 故障时 fail-open:不回填,直接用 DB 数据做一次性检查 + if cacheErr != nil { + if rec.DailyLimitUSD != nil && dailyUsage >= *rec.DailyLimitUSD { + return withWindowResetsMetadata(ErrUserPlatformDailyQuotaExhausted, nextDailyReset(now)) + } + if rec.WeeklyLimitUSD != nil && weeklyUsage >= *rec.WeeklyLimitUSD { + return withWindowResetsMetadata(ErrUserPlatformWeeklyQuotaExhausted, nextWeeklyReset(now)) + } + if rec.MonthlyLimitUSD != nil && monthlyUsage >= *rec.MonthlyLimitUSD { + return withWindowResetsMetadata(ErrUserPlatformMonthlyQuotaExhausted, nextMonthlyResetFrom(rec.MonthlyWindowStart, now)) + } + return nil + } + + // cache MISS 或旧版 entry → 回填完整 entry(含 limits 和 window_start) + newEntry := &UserPlatformQuotaCacheEntry{ + DailyUsageUSD: dailyUsage, + WeeklyUsageUSD: weeklyUsage, + MonthlyUsageUSD: monthlyUsage, + SchemaVersion: UserPlatformQuotaCacheSchemaV1, + DailyLimitUSD: rec.DailyLimitUSD, + WeeklyLimitUSD: rec.WeeklyLimitUSD, + MonthlyLimitUSD: rec.MonthlyLimitUSD, + DailyWindowStart: rec.DailyWindowStart, + WeeklyWindowStart: rec.WeeklyWindowStart, + MonthlyWindowStart: rec.MonthlyWindowStart, + } + if s.cache != nil { + ttl := time.Duration(s.cfg.Billing.UserPlatformQuotaCacheTTLSeconds) * time.Second + // 与 HIT 过期回填路径(上文 SetCache 调用)保持一致:用 context.Background()+50ms, + // 避免请求 ctx 提前取消(客户端断连/上游超时)导致 cache 回填失败, + // 让下一次 preflight 仍然 MISS 并击穿到 DB(高并发下增大 DB 压力)。 + setCtx, setCancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + if setErr := s.cache.SetUserPlatformQuotaCache(setCtx, userID, platform, newEntry, ttl); setErr != nil { + logger.LegacyPrintf("service.billing_cache", "Warning: set user platform quota cache failed user=%d platform=%s: %v", userID, platform, setErr) + } + setCancel() + } + + if rec.DailyLimitUSD != nil && dailyUsage >= *rec.DailyLimitUSD { + return withWindowResetsMetadata(ErrUserPlatformDailyQuotaExhausted, nextDailyReset(now)) + } + if rec.WeeklyLimitUSD != nil && weeklyUsage >= *rec.WeeklyLimitUSD { + return withWindowResetsMetadata(ErrUserPlatformWeeklyQuotaExhausted, nextWeeklyReset(now)) + } + if rec.MonthlyLimitUSD != nil && monthlyUsage >= *rec.MonthlyLimitUSD { + return withWindowResetsMetadata(ErrUserPlatformMonthlyQuotaExhausted, nextMonthlyResetFrom(rec.MonthlyWindowStart, now)) + } + return nil +} + +// withWindowResetsMetadata 给 quota error 附加 window_resets_at metadata(RFC3339)。 +func withWindowResetsMetadata(err error, resetAt time.Time) error { + appErr, ok := err.(*infraerrors.ApplicationError) + if !ok || appErr == nil { + return err + } + return appErr.WithMetadata(map[string]string{ + "window_resets_at": resetAt.Format(time.RFC3339), + }) +} + +// nextDailyReset 计算下一个日窗口起点(次日全局时区 0 点)。 +// 必须与 timezone.StartOfDay 同口径,否则 Retry-After 会偏差。 +func nextDailyReset(now time.Time) time.Time { + return timezone.StartOfDay(now).AddDate(0, 0, 1) +} + +// nextWeeklyReset 计算下一个周窗口起点(下周一全局时区 0 点)。 +// 必须与 timezone.StartOfWeek 同口径,否则 Retry-After 会偏差。 +func nextWeeklyReset(now time.Time) time.Time { + return timezone.StartOfWeek(now).AddDate(0, 0, 7) +} + +// nextMonthlyResetFrom 返回 30 天滚动窗口的下次重置时间(start + 30d)。 +// start 为 nil(未初始化)或已过期(now-start >= 30d,与 monthlyQuotaWindowExpired 同口径)时 +// 退化为 now+30d:过期窗口会在下次 increment 时重置为 now,下次重置即 now+30d; +// 否则按 start 计算会得到一个过去的时间,使 Retry-After 落回 fallback 并触发客户端紧凑重试。 +func nextMonthlyResetFrom(start *time.Time, now time.Time) time.Time { + if start == nil || now.Sub(*start) >= 30*24*time.Hour { + return now.Add(30 * 24 * time.Hour) + } + return start.Add(30 * 24 * time.Hour) +} + +// quotaWindowExpired 判断窗口是否已过期:start 为 nil(未初始化)或在 currWindowStart 之前视为已过期。 +func quotaWindowExpired(start *time.Time, currWindowStart time.Time) bool { + if start == nil { + return true + } + return start.Before(currWindowStart) +} + +// monthlyQuotaWindowExpired 判断 30 天滚动月度窗口是否已过期。 +// 过期条件:now - start >= 30×24h(与订阅模式 NeedsMonthlyReset 语义一致)。 +// start 为 nil 时视为已过期(未初始化窗口)。 +func monthlyQuotaWindowExpired(start *time.Time, now time.Time) bool { + if start == nil { + return true + } + return now.Sub(*start) >= 30*24*time.Hour +} diff --git a/backend/internal/service/billing_cache_service_rpm_test.go b/backend/internal/service/billing_cache_service_rpm_test.go index de66136f..cb71b886 100644 --- a/backend/internal/service/billing_cache_service_rpm_test.go +++ b/backend/internal/service/billing_cache_service_rpm_test.go @@ -74,7 +74,7 @@ func newBillingServiceForRPM(t *testing.T, cache UserRPMCache, rateRepo UserGrou t.Helper() // 用 nil BillingCache 走 "无缓存" 分支,避免 CheckBillingEligibility 副作用。 // 我们只直接测 checkRPM。 - svc := NewBillingCacheService(nil, nil, nil, nil, cache, rateRepo, &config.Config{}) + svc := NewBillingCacheService(nil, nil, nil, nil, cache, rateRepo, &config.Config{}, nil) t.Cleanup(svc.Stop) return svc } diff --git a/backend/internal/service/billing_cache_service_singleflight_test.go b/backend/internal/service/billing_cache_service_singleflight_test.go index 962becf0..b443d97e 100644 --- a/backend/internal/service/billing_cache_service_singleflight_test.go +++ b/backend/internal/service/billing_cache_service_singleflight_test.go @@ -67,6 +67,22 @@ func (s *billingCacheMissStub) InvalidateAPIKeyRateLimit(ctx context.Context, ke return nil } +func (s *billingCacheMissStub) GetUserPlatformQuotaCache(ctx context.Context, userID int64, platform string) (*UserPlatformQuotaCacheEntry, bool, error) { + return nil, false, nil +} + +func (s *billingCacheMissStub) SetUserPlatformQuotaCache(ctx context.Context, userID int64, platform string, entry *UserPlatformQuotaCacheEntry, ttl time.Duration) error { + return nil +} + +func (s *billingCacheMissStub) DeleteUserPlatformQuotaCache(ctx context.Context, userID int64, platform string) error { + return nil +} + +func (s *billingCacheMissStub) IncrUserPlatformQuotaUsageCache(ctx context.Context, userID int64, platform string, cost float64, ttl time.Duration) error { + return nil +} + type balanceLoadUserRepoStub struct { mockUserRepo calls atomic.Int64 @@ -100,7 +116,7 @@ func TestBillingCacheServiceGetUserBalance_Singleflight(t *testing.T) { delay: 80 * time.Millisecond, balance: 12.34, } - svc := NewBillingCacheService(cache, userRepo, nil, nil, nil, nil, &config.Config{}) + svc := NewBillingCacheService(cache, userRepo, nil, nil, nil, nil, &config.Config{}, nil) t.Cleanup(svc.Stop) const goroutines = 16 diff --git a/backend/internal/service/billing_cache_service_test.go b/backend/internal/service/billing_cache_service_test.go index 849e24b8..bcd086fa 100644 --- a/backend/internal/service/billing_cache_service_test.go +++ b/backend/internal/service/billing_cache_service_test.go @@ -68,9 +68,25 @@ func (b *billingCacheWorkerStub) InvalidateAPIKeyRateLimit(ctx context.Context, return nil } +func (b *billingCacheWorkerStub) GetUserPlatformQuotaCache(ctx context.Context, userID int64, platform string) (*UserPlatformQuotaCacheEntry, bool, error) { + return nil, false, nil +} + +func (b *billingCacheWorkerStub) SetUserPlatformQuotaCache(ctx context.Context, userID int64, platform string, entry *UserPlatformQuotaCacheEntry, ttl time.Duration) error { + return nil +} + +func (b *billingCacheWorkerStub) DeleteUserPlatformQuotaCache(ctx context.Context, userID int64, platform string) error { + return nil +} + +func (b *billingCacheWorkerStub) IncrUserPlatformQuotaUsageCache(ctx context.Context, userID int64, platform string, cost float64, ttl time.Duration) error { + return nil +} + func TestBillingCacheServiceQueueHighLoad(t *testing.T) { cache := &billingCacheWorkerStub{} - svc := NewBillingCacheService(cache, nil, nil, nil, nil, nil, &config.Config{}) + svc := NewBillingCacheService(cache, nil, nil, nil, nil, nil, &config.Config{}, nil) t.Cleanup(svc.Stop) start := time.Now() @@ -92,7 +108,7 @@ func TestBillingCacheServiceQueueHighLoad(t *testing.T) { func TestBillingCacheServiceEnqueueAfterStopReturnsFalse(t *testing.T) { cache := &billingCacheWorkerStub{} - svc := NewBillingCacheService(cache, nil, nil, nil, nil, nil, &config.Config{}) + svc := NewBillingCacheService(cache, nil, nil, nil, nil, nil, &config.Config{}, nil) svc.Stop() enqueued := svc.enqueueCacheWrite(cacheWriteTask{ diff --git a/backend/internal/service/billing_cache_service_user_platform_quota_test.go b/backend/internal/service/billing_cache_service_user_platform_quota_test.go new file mode 100644 index 00000000..57697ddb --- /dev/null +++ b/backend/internal/service/billing_cache_service_user_platform_quota_test.go @@ -0,0 +1,595 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "sync" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" +) + +// fakeIncrCache 仅记录 IncrUserPlatformQuotaUsageCache 被调用的参数。 +type fakeIncrCache struct { + BillingCache + calls []incrCall +} + +type incrCall struct { + userID int64 + platform string + cost float64 + ttl time.Duration +} + +func (f *fakeIncrCache) IncrUserPlatformQuotaUsageCache(ctx context.Context, userID int64, platform string, cost float64, ttl time.Duration) error { + f.calls = append(f.calls, incrCall{userID, platform, cost, ttl}) + return nil +} + +// IncrementUserPlatformQuotaUsage 已改为同步直写,不再走 worker。 +// 测试验证:同步调用立即调到 cache.IncrUserPlatformQuotaUsageCache。 +func TestIncrementUserPlatformQuotaUsage_SyncCallsCache(t *testing.T) { + fake := &fakeIncrCache{} + cfg := &config.Config{} + cfg.Billing.UserPlatformQuotaCacheTTLSeconds = 120 + + s := &BillingCacheService{ + cache: fake, + cfg: cfg, + } + + s.IncrementUserPlatformQuotaUsage(101, "anthropic", 0.25) + s.IncrementUserPlatformQuotaUsage(101, "openai", 0.50) + + if len(fake.calls) != 2 { + t.Fatalf("expected 2 incr calls, got %d", len(fake.calls)) + } + if fake.calls[0] != (incrCall{101, "anthropic", 0.25, 120 * time.Second}) { + t.Errorf("call[0] = %+v", fake.calls[0]) + } + if fake.calls[1] != (incrCall{101, "openai", 0.50, 120 * time.Second}) { + t.Errorf("call[1] = %+v", fake.calls[1]) + } +} + +// ── T6 tests: checkUserPlatformQuotaEligibility ────────────────────────────── + +// fakeQuotaRepo 实现 UserPlatformQuotaRepository 最小子集 +type fakeQuotaRepo struct { + rec *UserPlatformQuotaRecord +} + +func (f *fakeQuotaRepo) GetByUserPlatform(_ context.Context, _ int64, _ string) (*UserPlatformQuotaRecord, error) { + return f.rec, nil +} + +func (f *fakeQuotaRepo) BulkInsertInitial(_ context.Context, _ []UserPlatformQuotaRecord) error { + return nil +} + +func (f *fakeQuotaRepo) IncrementUsageWithReset(_ context.Context, _ int64, _ string, _ float64, _ time.Time) error { + return nil +} + +func (f *fakeQuotaRepo) ListByUser(_ context.Context, _ int64) ([]UserPlatformQuotaRecord, error) { + return nil, nil +} + +func (f *fakeQuotaRepo) UpsertForUser(_ context.Context, _ int64, _ []UserPlatformQuotaRecord) error { + return nil +} + +func (f *fakeQuotaRepo) ResetExpiredWindow(_ context.Context, _ int64, _ string, _ string, _ time.Time) error { + return nil +} + +// fakeFullCache 同时支持 Get + Set + Incr + Delete。 +// mu 保护 entry 和 deleteCalls,防止异步 goroutine 与主 goroutine 之间的 data race。 +type fakeFullCache struct { + BillingCache + mu sync.Mutex + entry *UserPlatformQuotaCacheEntry + deleteCalls int +} + +// getDeleteCalls 线程安全地读取 deleteCalls。 +func (f *fakeFullCache) getDeleteCalls() int { + f.mu.Lock() + defer f.mu.Unlock() + return f.deleteCalls +} + +// getEntry 线程安全地读取 entry。 +func (f *fakeFullCache) getEntry() *UserPlatformQuotaCacheEntry { + f.mu.Lock() + defer f.mu.Unlock() + return f.entry +} + +func (f *fakeFullCache) GetUserPlatformQuotaCache(_ context.Context, _ int64, _ string) (*UserPlatformQuotaCacheEntry, bool, error) { + f.mu.Lock() + defer f.mu.Unlock() + if f.entry == nil { + return nil, false, nil + } + return f.entry, true, nil +} + +func (f *fakeFullCache) SetUserPlatformQuotaCache(_ context.Context, _ int64, _ string, e *UserPlatformQuotaCacheEntry, _ time.Duration) error { + f.mu.Lock() + defer f.mu.Unlock() + f.entry = e + return nil +} + +func (f *fakeFullCache) DeleteUserPlatformQuotaCache(_ context.Context, _ int64, _ string) error { + f.mu.Lock() + defer f.mu.Unlock() + f.deleteCalls++ + f.entry = nil + return nil +} + +func newServiceForPreflight(t *testing.T, repo UserPlatformQuotaRepository, cache BillingCache) *BillingCacheService { + t.Helper() + cfg := &config.Config{} + cfg.Billing.UserPlatformQuotaCacheTTLSeconds = 60 + return &BillingCacheService{ + cache: cache, + cfg: cfg, + userPlatformQuotaRepo: repo, + } +} + +// currentDayStart 返回全局时区当天 0 点(与生产 timezone.StartOfDay 同口径,确保窗口有效)。 +func currentDayStart() *time.Time { + s := timezone.StartOfDay(time.Now()) + return &s +} + +func TestCheckUserPlatformQuotaEligibility_AllowsWhenUnderLimit(t *testing.T) { + daily := 10.0 + repo := &fakeQuotaRepo{rec: &UserPlatformQuotaRecord{ + UserID: 1, Platform: "anthropic", DailyLimitUSD: &daily, + }} + cache := &fakeFullCache{entry: &UserPlatformQuotaCacheEntry{ + DailyUsageUSD: 4.5, + DailyLimitUSD: &daily, + DailyWindowStart: currentDayStart(), + SchemaVersion: UserPlatformQuotaCacheSchemaV1, + }} + s := newServiceForPreflight(t, repo, cache) + if err := s.checkUserPlatformQuotaEligibility(context.Background(), 1, "anthropic"); err != nil { + t.Errorf("expected nil, got %v", err) + } +} + +func TestCheckUserPlatformQuotaEligibility_DailyExhausted(t *testing.T) { + daily := 5.0 + repo := &fakeQuotaRepo{rec: &UserPlatformQuotaRecord{ + UserID: 1, Platform: "anthropic", DailyLimitUSD: &daily, + }} + cache := &fakeFullCache{entry: &UserPlatformQuotaCacheEntry{ + DailyUsageUSD: 5.0, + DailyLimitUSD: &daily, + DailyWindowStart: currentDayStart(), + SchemaVersion: UserPlatformQuotaCacheSchemaV1, + }} + s := newServiceForPreflight(t, repo, cache) + err := s.checkUserPlatformQuotaEligibility(context.Background(), 1, "anthropic") + if !errors.Is(err, ErrUserPlatformDailyQuotaExhausted) { + t.Errorf("expected ErrUserPlatformDailyQuotaExhausted, got %v", err) + } +} + +func TestCheckUserPlatformQuotaEligibility_NilLimitMeansUnlimited(t *testing.T) { + repo := &fakeQuotaRepo{rec: &UserPlatformQuotaRecord{ + UserID: 1, Platform: "anthropic", + }} + cache := &fakeFullCache{entry: &UserPlatformQuotaCacheEntry{ + DailyUsageUSD: 999, + DailyWindowStart: currentDayStart(), + SchemaVersion: UserPlatformQuotaCacheSchemaV1, + // DailyLimitUSD nil → 无限额 + }} + s := newServiceForPreflight(t, repo, cache) + if err := s.checkUserPlatformQuotaEligibility(context.Background(), 1, "anthropic"); err != nil { + t.Errorf("nil limits should be unlimited, got %v", err) + } +} + +func TestCheckUserPlatformQuotaEligibility_ZeroLimitImmediateBlock(t *testing.T) { + zero := 0.0 + repo := &fakeQuotaRepo{rec: &UserPlatformQuotaRecord{ + UserID: 1, Platform: "anthropic", DailyLimitUSD: &zero, + }} + cache := &fakeFullCache{entry: &UserPlatformQuotaCacheEntry{ + DailyUsageUSD: 0, + DailyLimitUSD: &zero, + DailyWindowStart: currentDayStart(), + SchemaVersion: UserPlatformQuotaCacheSchemaV1, + }} + s := newServiceForPreflight(t, repo, cache) + err := s.checkUserPlatformQuotaEligibility(context.Background(), 1, "anthropic") + if !errors.Is(err, ErrUserPlatformDailyQuotaExhausted) { + t.Errorf("expected daily exhausted for limit=0, got %v", err) + } +} + +func TestCheckUserPlatformQuotaEligibility_NoRecordMeansUnlimited(t *testing.T) { + repo := &fakeQuotaRepo{rec: nil} + cache := &fakeFullCache{} + s := newServiceForPreflight(t, repo, cache) + if err := s.checkUserPlatformQuotaEligibility(context.Background(), 1, "anthropic"); err != nil { + t.Errorf("no record = unlimited, got %v", err) + } +} + +// TestCheckUserPlatformQuotaEligibility_OldSchemaCacheMissTriggersDB 验证旧版 entry(SchemaVersion=0) +// 触发 DB 回退路径,并在 DB 数据判断配额是否超限。 +// DB record 需设置有效的 window_start,否则 quotaWindowExpired 会将 usage 归零(nil 窗口视为已过期)。 +func TestCheckUserPlatformQuotaEligibility_OldSchemaCacheMissTriggersDB(t *testing.T) { + daily := 5.0 + dayStart := currentDayStart() + repo := &fakeQuotaRepo{rec: &UserPlatformQuotaRecord{ + UserID: 1, Platform: "anthropic", DailyLimitUSD: &daily, DailyUsageUSD: 6.0, + DailyWindowStart: dayStart, + }} + // SchemaVersion=0(旧 entry),应走 DB 路径 + cache := &fakeFullCache{entry: &UserPlatformQuotaCacheEntry{DailyUsageUSD: 1.0}} + s := newServiceForPreflight(t, repo, cache) + err := s.checkUserPlatformQuotaEligibility(context.Background(), 1, "anthropic") + if !errors.Is(err, ErrUserPlatformDailyQuotaExhausted) { + t.Errorf("旧版 entry 应走 DB 路径并报 daily exhausted, got %v", err) + } +} + +// TestCheckUserPlatformQuotaEligibility_WindowExpiredInCache 验证 cache HIT 时若窗口已过期,usage 归零,用户放行。 +func TestCheckUserPlatformQuotaEligibility_WindowExpiredInCache(t *testing.T) { + daily := 5.0 + past := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) // 远古窗口起始,肯定已过期 + repo := &fakeQuotaRepo{rec: &UserPlatformQuotaRecord{ + UserID: 1, Platform: "anthropic", DailyLimitUSD: &daily, + }} + cache := &fakeFullCache{entry: &UserPlatformQuotaCacheEntry{ + DailyUsageUSD: 10.0, // 超限,但窗口已过期 + DailyLimitUSD: &daily, + DailyWindowStart: &past, + SchemaVersion: UserPlatformQuotaCacheSchemaV1, + }} + s := newServiceForPreflight(t, repo, cache) + err := s.checkUserPlatformQuotaEligibility(context.Background(), 1, "anthropic") + if err != nil { + t.Errorf("过期窗口应归零放行, got %v", err) + } +} + +// TestCheckUserPlatformQuotaEligibility_WindowExpiredRefreshesCache 验证: +// V1 HIT 路径检测到窗口过期时,用 reset 后的 entry 同步覆盖 Redis(而非 Delete): +// 1. 当前请求以本地清零值判断 → 放行 +// 2. cache entry 被替换为新 entry: usage 清零 + window_start 更新到当前窗口 +// limit 保留;这样并发 IncrUserPlatformQuotaUsage 的 Lua INCR 可正确累加到新窗口。 +func TestCheckUserPlatformQuotaEligibility_WindowExpiredRefreshesCache(t *testing.T) { + daily := 5.0 + // 远古窗口起始,确保 quotaWindowExpired 返回 true + past := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) + repo := &fakeQuotaRepo{rec: &UserPlatformQuotaRecord{ + UserID: 1, Platform: "anthropic", DailyLimitUSD: &daily, + }} + cache := &fakeFullCache{entry: &UserPlatformQuotaCacheEntry{ + DailyUsageUSD: 10.0, // 超限,但窗口已过期 → 应被本地清零后放行 + DailyLimitUSD: &daily, + DailyWindowStart: &past, + SchemaVersion: UserPlatformQuotaCacheSchemaV1, + }} + s := newServiceForPreflight(t, repo, cache) + + // 本次 check 应放行(本地清零后 usage=0 < limit=5) + err := s.checkUserPlatformQuotaEligibility(context.Background(), 1, "anthropic") + if err != nil { + t.Errorf("过期窗口应归零放行, got %v", err) + } + + // 验证 cache entry 已被刷新:usage 清零、limit 保留、window_start 更新到当前窗口 + refreshed := cache.getEntry() + if refreshed == nil { + t.Fatal("窗口过期后 cache entry 不应为 nil(应被 SetCache 覆盖,而非 Delete)") + } + if refreshed.DailyUsageUSD != 0 { + t.Errorf("刷新后 DailyUsageUSD = %v, want 0", refreshed.DailyUsageUSD) + } + if refreshed.DailyLimitUSD == nil || *refreshed.DailyLimitUSD != daily { + t.Errorf("刷新后 DailyLimitUSD = %v, want %v(保留)", refreshed.DailyLimitUSD, daily) + } + if refreshed.SchemaVersion != UserPlatformQuotaCacheSchemaV1 { + t.Errorf("刷新后 SchemaVersion = %d, want V1", refreshed.SchemaVersion) + } + if refreshed.DailyWindowStart == nil || refreshed.DailyWindowStart.Equal(past) { + t.Errorf("刷新后 DailyWindowStart = %v, 应更新到当前窗口而非保留 past=%v", refreshed.DailyWindowStart, past) + } +} + +// ── T5 tests: QueueUpdateUserPlatformQuotaUsage ─────────────────────────────── + +// ── C-NEW-1: monthlyQuotaWindowExpired 30 天滚动测试 ───────────────────────── + +func TestMonthlyQuotaWindowExpired_NilStart(t *testing.T) { + if !monthlyQuotaWindowExpired(nil, time.Now().UTC()) { + t.Error("nil start should be considered expired") + } +} + +func TestMonthlyQuotaWindowExpired_Expired(t *testing.T) { + start := time.Now().UTC().Add(-30 * 24 * time.Hour) + if !monthlyQuotaWindowExpired(&start, time.Now().UTC()) { + t.Error("start exactly 30 days ago should be expired") + } +} + +func TestMonthlyQuotaWindowExpired_Active(t *testing.T) { + start := time.Now().UTC().Add(-29 * 24 * time.Hour) + if monthlyQuotaWindowExpired(&start, time.Now().UTC()) { + t.Error("start 29 days ago should NOT be expired") + } +} + +// TestMonthlyQuotaWindowExpired_CrossMonthBoundary 验证跨自然月时 30 天未满不视为过期。 +func TestMonthlyQuotaWindowExpired_CrossMonthBoundary(t *testing.T) { + // 窗口起始 4 月 20 日;5 月 1 日只过了 11 天,不足 30 天 + start := time.Date(2026, 4, 20, 0, 0, 0, 0, time.UTC) + now := time.Date(2026, 5, 1, 0, 0, 0, 0, time.UTC) + if monthlyQuotaWindowExpired(&start, now) { + t.Error("11 days into window should NOT be expired (30-day rolling, not calendar month)") + } +} + +// TestNextMonthlyResetFrom 验证 30 天滚动重置时间计算。 +func TestNextMonthlyResetFrom_WithStart(t *testing.T) { + start := time.Date(2026, 5, 1, 10, 0, 0, 0, time.UTC) + want := start.Add(30 * 24 * time.Hour) + now := time.Date(2026, 5, 22, 0, 0, 0, 0, time.UTC) + got := nextMonthlyResetFrom(&start, now) + if !got.Equal(want) { + t.Errorf("nextMonthlyResetFrom = %v, want %v", got, want) + } +} + +func TestNextMonthlyResetFrom_NilStart(t *testing.T) { + now := time.Date(2026, 5, 22, 0, 0, 0, 0, time.UTC) + got := nextMonthlyResetFrom(nil, now) + want := now.Add(30 * 24 * time.Hour) + if !got.Equal(want) { + t.Errorf("nextMonthlyResetFrom(nil) = %v, want now+30d=%v", got, want) + } +} + +func TestNextMonthlyResetFrom_NilStart_NotEqualToNow(t *testing.T) { + now := time.Date(2026, 5, 22, 12, 0, 0, 0, time.UTC) + got := nextMonthlyResetFrom(nil, now) + want := now.Add(30 * 24 * time.Hour) + if !got.Equal(want) { + t.Errorf("nextMonthlyResetFrom(nil) = %v, want %v (now+30d)", got, want) + } + if got.Equal(now) { + t.Error("nextMonthlyResetFrom(nil) must not return now (should be now+30d)") + } +} + +// TestNextMonthlyResetFrom_ExpiredStart 验证窗口已过期(now-start >= 30d)时, +// 下次重置时间为 now+30d,而非 start+30d(后者已是过去时间,会让 Retry-After 落回 +// fallback 并触发客户端紧凑重试)。 +func TestNextMonthlyResetFrom_ExpiredStart(t *testing.T) { + start := time.Date(2026, 3, 1, 0, 0, 0, 0, time.UTC) + now := time.Date(2026, 5, 1, 0, 0, 0, 0, time.UTC) // 距 start 61 天,已过期 + got := nextMonthlyResetFrom(&start, now) + want := now.Add(30 * 24 * time.Hour) + if !got.Equal(want) { + t.Errorf("nextMonthlyResetFrom(expired) = %v, want now+30d=%v", got, want) + } + if !got.After(now) { + t.Error("expired window 的下次重置必须在 now 之后,不能是过去时间") + } +} + +func TestIncrementUserPlatformQuotaUsage_GuardsAgainstEmpty(t *testing.T) { + fake := &fakeIncrCache{} + cfg := &config.Config{} + cfg.Billing.UserPlatformQuotaCacheTTLSeconds = 60 + s := &BillingCacheService{ + cache: fake, + cfg: cfg, + } + + s.IncrementUserPlatformQuotaUsage(1, "", 0.5) // empty platform → noop + s.IncrementUserPlatformQuotaUsage(1, "openai", 0) // zero cost → noop + s.IncrementUserPlatformQuotaUsage(1, "openai", -0.1) // negative → noop + + if len(fake.calls) != 0 { + t.Errorf("expected 0 calls (all guarded), got %d", len(fake.calls)) + } +} + +// ── C-NEW-2: 订阅模式豁免 user×platform quota 检查 ────────────────────────── +// 通过直接调用 checkUserPlatformQuotaEligibility 验证: +// 1. standard 模式下 limit=0 → 拦截 +// 2. 订阅模式豁免通过 isSubscriptionMode 守卫体现 — 逻辑已在 CheckBillingEligibility 里加 !isSubscriptionMode 条件 +// 此处用单元测试直接验证底层 checkUserPlatformQuotaEligibility 的行为(quota 超限确实拦截), +// 而 subscription bypass 逻辑则在 CheckBillingEligibility 中通过条件判断保证,不绕过 sub eligibility 内部复杂依赖。 + +// fakeZeroQuotaCache 模拟 cache 命中且 daily limit=0(quota 耗尽)。 +type fakeZeroQuotaCache struct { + BillingCache + called bool +} + +func (f *fakeZeroQuotaCache) GetUserPlatformQuotaCache(_ context.Context, _ int64, _ string) (*UserPlatformQuotaCacheEntry, bool, error) { + f.called = true + daily := 0.0 + entry := &UserPlatformQuotaCacheEntry{ + DailyUsageUSD: 0, + DailyLimitUSD: &daily, + DailyWindowStart: func() *time.Time { t := time.Now().UTC(); return &t }(), + SchemaVersion: UserPlatformQuotaCacheSchemaV1, + } + return entry, true, nil +} + +func (f *fakeZeroQuotaCache) DeleteUserPlatformQuotaCache(_ context.Context, _ int64, _ string) error { + return nil +} + +// SetUserPlatformQuotaCache 在 weekly/monthly window_start 为 nil 时,checkUserPlatform... +// 会触发"窗口过期 → SetCache 刷新"分支。fake 用 noop 避免 panic。 +func (f *fakeZeroQuotaCache) SetUserPlatformQuotaCache(_ context.Context, _ int64, _ string, _ *UserPlatformQuotaCacheEntry, _ time.Duration) error { + return nil +} + +// GetSubscriptionCache 返回有效订阅(active、未过期、usage 远低于 limit), +// 用于支持 checkSubscriptionEligibility 通过,以便验证 quota 检查不被触发。 +func (f *fakeZeroQuotaCache) GetSubscriptionCache(_ context.Context, _ int64, _ int64) (*SubscriptionCacheData, error) { + return &SubscriptionCacheData{ + Status: SubscriptionStatusActive, + ExpiresAt: time.Now().Add(30 * 24 * time.Hour), + DailyUsage: 0, + WeeklyUsage: 0, + MonthlyUsage: 0, + }, nil +} + +func (f *fakeZeroQuotaCache) GetUserBalanceCache(_ context.Context, _ int64) (float64, bool, error) { + return 100.0, true, nil +} + +// TestCheckUserPlatformQuotaEligibility_StandardMode_BlocksWhenLimitZero 验证: +// standard 模式下 limit=0 的 platform quota 确实会被拦截(守卫底层逻辑正确)。 +func TestCheckUserPlatformQuotaEligibility_StandardMode_BlocksWhenLimitZero(t *testing.T) { + fake := &fakeZeroQuotaCache{} + cfg := &config.Config{} + cfg.Billing.UserPlatformQuotaCacheTTLSeconds = 60 + s := &BillingCacheService{ + cache: fake, + cfg: cfg, + userPlatformQuotaRepo: &fakeQuotaRepo{}, + } + err := s.checkUserPlatformQuotaEligibility(context.Background(), 1, "anthropic") + if !errors.Is(err, ErrUserPlatformDailyQuotaExhausted) { + t.Errorf("standard mode with limit=0 should return ErrUserPlatformDailyQuotaExhausted, got: %v", err) + } + if !fake.called { + t.Error("GetUserPlatformQuotaCache should have been called in standard mode") + } +} + +// TestCheckBillingEligibility_SubscriptionMode_BypassesPlatformQuota 验证(C-NEW-2): +// 订阅模式用户不受 user×platform quota 拦截,GetUserPlatformQuotaCache 不应被调用。 +func TestCheckBillingEligibility_SubscriptionMode_BypassesPlatformQuota(t *testing.T) { + fake := &fakeZeroQuotaCache{} // GetUserPlatformQuotaCache 返回 limit=0,若被调用则拦截 + cfg := &config.Config{} + cfg.Billing.UserPlatformQuotaCacheTTLSeconds = 60 + s := &BillingCacheService{ + cache: fake, + cfg: cfg, + userPlatformQuotaRepo: &fakeQuotaRepo{}, + } + + subGroup := &Group{ + ID: 10, + SubscriptionType: "subscription", + Status: "active", + // 无 DailyLimitUSD → checkSubscriptionEligibility 不会因超限失败 + } + sub := &UserSubscription{Status: "active"} + user := &User{ID: 42} + + err := s.CheckBillingEligibility(context.Background(), user, nil, subGroup, sub, "anthropic") + // 订阅模式下不应收到任何 user×platform quota 错误 + if errors.Is(err, ErrUserPlatformDailyQuotaExhausted) || + errors.Is(err, ErrUserPlatformWeeklyQuotaExhausted) || + errors.Is(err, ErrUserPlatformMonthlyQuotaExhausted) { + t.Errorf("subscription mode should bypass user×platform quota, got: %v", err) + } + // GetUserPlatformQuotaCache 不应被调用 + if fake.called { + t.Error("GetUserPlatformQuotaCache must NOT be called in subscription mode (C-NEW-2)") + } +} + +// TestCheckBillingEligibility_NonSubscriptionGroup_AppliesQuota 验证: +// 非订阅模式(group=nil)用户 platform quota 超限时被拦截,quota cache 被查询。 +func TestCheckBillingEligibility_NonSubscriptionGroup_AppliesQuota(t *testing.T) { + called := &fakeZeroQuotaCache{} + cfg := &config.Config{} + cfg.Billing.UserPlatformQuotaCacheTTLSeconds = 60 + s := &BillingCacheService{ + cache: called, + cfg: cfg, + userPlatformQuotaRepo: &fakeQuotaRepo{}, + } + err := s.checkUserPlatformQuotaEligibility(context.Background(), 99, "openai") + if !errors.Is(err, ErrUserPlatformDailyQuotaExhausted) { + t.Errorf("non-subscription mode quota check should block, got: %v", err) + } + if !called.called { + t.Error("GetUserPlatformQuotaCache should be consulted in non-subscription mode") + } +} + +// ── B-3: monthlyQuotaWindowExpired 30 天边界表驱动测试 ──────────────────────── +// 覆盖 4 个必须场景: +// 1. 恰好 30 天 → expired +// 2. 30*24h - 1ns → not expired +// 3. 跨月末(2024-02-28 → 2024-03-29T00:00:01Z)→ expired +// 4. 跨年(2024-12-15 → 2025-01-14T00:00:01Z)→ expired +// +// repo 层 monthlyMaybeReset 不可导出,通过 service 层 monthlyQuotaWindowExpired 间接覆盖。 +func TestMonthlyQuotaWindowExpired_BoundaryTable(t *testing.T) { + const thirtyDays = 30 * 24 * time.Hour + + cases := []struct { + name string + start time.Time + now time.Time + expired bool + }{ + { + name: "exactly 30 days → expired", + start: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), + now: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC).Add(thirtyDays), + expired: true, + }, + { + name: "30d minus 1ns → not expired", + start: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), + now: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC).Add(thirtyDays - 1), + expired: false, + }, + { + name: "cross month-end (Feb→Mar, 29d+1s) → expired", + start: time.Date(2024, 2, 28, 0, 0, 0, 0, time.UTC), + now: time.Date(2024, 3, 29, 0, 0, 1, 0, time.UTC), + expired: true, + }, + { + name: "cross year boundary (Dec→Jan, 30d+1s) → expired", + start: time.Date(2024, 12, 15, 0, 0, 0, 0, time.UTC), + now: time.Date(2025, 1, 14, 0, 0, 1, 0, time.UTC), + expired: true, + }, + } + + for _, tc := range cases { + tc := tc // capture range variable + t.Run(tc.name, func(t *testing.T) { + got := monthlyQuotaWindowExpired(&tc.start, tc.now) + if got != tc.expired { + t.Errorf("monthlyQuotaWindowExpired(start=%v, now=%v) = %v, want %v", + tc.start, tc.now, got, tc.expired) + } + }) + } +} diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go index 47975c8c..373502cf 100644 --- a/backend/internal/service/billing_service.go +++ b/backend/internal/service/billing_service.go @@ -6,6 +6,7 @@ import ( "fmt" "log" "strings" + "time" "github.com/Wei-Shaw/sub2api/internal/config" ) @@ -20,6 +21,32 @@ type APIKeyRateLimitCacheData struct { Window7d int64 `json:"window_7d"` } +// UserPlatformQuotaCacheEntry Redis hash 反序列化结果。 +// +// SchemaVersion 用于向后兼容: +// - 0(旧 entry,无 SchemaVersion 字段)→ 视为 cache MISS,强制 refresh +// - 1(当前版本)→ 包含 limits 和 window_start,可免 DB 查询 +// +// limit 字段为 nil 表示"无限额"(DB 中对应列为 NULL)。 +const UserPlatformQuotaCacheSchemaV1 = int64(1) + +type UserPlatformQuotaCacheEntry struct { + DailyUsageUSD float64 + WeeklyUsageUSD float64 + MonthlyUsageUSD float64 + Version int64 + SchemaVersion int64 + + // 以下字段仅在 SchemaVersion >= 1 时有效 + DailyLimitUSD *float64 + WeeklyLimitUSD *float64 + MonthlyLimitUSD *float64 + + DailyWindowStart *time.Time + WeeklyWindowStart *time.Time + MonthlyWindowStart *time.Time +} + // BillingCache defines cache operations for billing service type BillingCache interface { // Balance operations @@ -39,6 +66,13 @@ type BillingCache interface { SetAPIKeyRateLimit(ctx context.Context, keyID int64, data *APIKeyRateLimitCacheData) error UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error + + // user × platform quota 缓存 + GetUserPlatformQuotaCache(ctx context.Context, userID int64, platform string) (*UserPlatformQuotaCacheEntry, bool, error) + SetUserPlatformQuotaCache(ctx context.Context, userID int64, platform string, entry *UserPlatformQuotaCacheEntry, ttl time.Duration) error + DeleteUserPlatformQuotaCache(ctx context.Context, userID int64, platform string) error + // IncrUserPlatformQuotaUsageCache 在缓存命中时累加用量;缓存未命中(key 不存在)静默返回 nil。 + IncrUserPlatformQuotaUsageCache(ctx context.Context, userID int64, platform string, cost float64, ttl time.Duration) error } // ModelPricing 模型价格配置(per-token价格,与LiteLLM格式一致) diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index c7fa425a..59c34eaa 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -1,6 +1,10 @@ package service -import "github.com/Wei-Shaw/sub2api/internal/domain" +import ( + "fmt" + + "github.com/Wei-Shaw/sub2api/internal/domain" +) // Status constants const ( @@ -39,6 +43,26 @@ const ( PlatformAntigravity = domain.PlatformAntigravity ) +// AllowedQuotaPlatforms 是允许设置 user × platform quota 的平台列表(单一权威来源)。 +// ent/schema/user_platform_quota.go 的 Validate 函数独立维护(构建期约束), +// 若新增平台需同步修改该 schema。 +var AllowedQuotaPlatforms = []string{ + PlatformAnthropic, + PlatformOpenAI, + PlatformGemini, + PlatformAntigravity, +} + +// IsAllowedQuotaPlatform 报告 s 是否为合法的 quota platform 标识。 +func IsAllowedQuotaPlatform(s string) bool { + for _, p := range AllowedQuotaPlatforms { + if p == s { + return true + } + } + return false +} + // Account type constants const ( AccountTypeOAuth = domain.AccountTypeOAuth // OAuth类型账号(full scope: profile + inference) @@ -424,5 +448,15 @@ const ( SettingKeyWebSearchEmulationConfig = "web_search_emulation_config" // JSON 配置 ) +// SettingKeyDefaultPlatformQuotas —— 系统全局:每用户 × 平台日/周/月 USD 上限(JSON)。 +// 值为 map[platform]{daily,weekly,monthly},null/缺省 = 不限制;0 = 禁用;>0 = USD 上限。 +const SettingKeyDefaultPlatformQuotas = "default_platform_quotas" + +// SettingKeyAuthSourcePlatformQuotas 返回某 auth source 的 platform quota JSON key。 +// 形如 auth_source_default_{source}_platform_quotas +func SettingKeyAuthSourcePlatformQuotas(source string) string { + return fmt.Sprintf("auth_source_default_%s_platform_quotas", source) +} + // AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys). const AdminAPIKeyPrefix = "admin-" diff --git a/backend/internal/service/domain_constants_test.go b/backend/internal/service/domain_constants_test.go new file mode 100644 index 00000000..2157c979 --- /dev/null +++ b/backend/internal/service/domain_constants_test.go @@ -0,0 +1,23 @@ +//go:build unit + +package service + +import "testing" + +// TestSettingKeyDefaultPlatformQuotas 验证新的系统层 JSON key 常量值正确。 +func TestSettingKeyDefaultPlatformQuotas(t *testing.T) { + if SettingKeyDefaultPlatformQuotas != "default_platform_quotas" { + t.Errorf("SettingKeyDefaultPlatformQuotas = %q, want %q", + SettingKeyDefaultPlatformQuotas, "default_platform_quotas") + } +} + +// TestSettingKeyAuthSourcePlatformQuotas 验证新的 auth-source JSON key 函数返回值正确。 +func TestSettingKeyAuthSourcePlatformQuotas(t *testing.T) { + if got := SettingKeyAuthSourcePlatformQuotas("email"); got != "auth_source_default_email_platform_quotas" { + t.Fatalf("got %q, want %q", got, "auth_source_default_email_platform_quotas") + } + if got := SettingKeyAuthSourcePlatformQuotas("dingtalk"); got != "auth_source_default_dingtalk_platform_quotas" { + t.Fatalf("got %q, want %q", got, "auth_source_default_dingtalk_platform_quotas") + } +} diff --git a/backend/internal/service/gateway_record_usage_test.go b/backend/internal/service/gateway_record_usage_test.go index 09b07a5e..f8d98e55 100644 --- a/backend/internal/service/gateway_record_usage_test.go +++ b/backend/internal/service/gateway_record_usage_test.go @@ -44,6 +44,7 @@ func newGatewayRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo nil, nil, nil, + nil, // userPlatformQuotaRepo ) } diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 1b1b9c5e..6106391b 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -95,6 +95,16 @@ var ( modelsListCacheHitTotal atomic.Int64 modelsListCacheMissTotal atomic.Int64 modelsListCacheStoreTotal atomic.Int64 + + // userPlatformQuotaDBIncrErrorTotal 统计 finalizePostUsageBilling 异步 goroutine + // 中 IncrementUsageWithReset 失败次数。Redis 已成功累加 + DB 写失败意味着 + // Redis cache TTL 过期或被清后该笔 cost 会丢失(与实际消费偏差)。 + // oncall 通过 GatewayUserPlatformQuotaIncrStats() 暴露给 ops 面板做阈值告警。 + userPlatformQuotaDBIncrErrorTotal atomic.Int64 + // userPlatformQuotaDBIncrLegacyErrorTotal 统计 legacy postUsageBilling + // (applyUsageBilling 在 repo==nil 时 fallback)路径下的失败次数; + // 与 DB Incr 失败分开计数,便于区分"主路径暂时故障"vs"基础设施长期未配齐"。 + userPlatformQuotaDBIncrLegacyErrorTotal atomic.Int64 ) func GatewayWindowCostPrefetchStats() (cacheHit, cacheMiss, batchSQL, fallback, errCount int64) { @@ -117,6 +127,15 @@ func GatewayModelsListCacheStats() (cacheHit, cacheMiss, store int64) { return modelsListCacheHitTotal.Load(), modelsListCacheMissTotal.Load(), modelsListCacheStoreTotal.Load() } +// GatewayUserPlatformQuotaIncrStats 返回 (mainPathErr, legacyPathErr)。 +// mainPathErr:finalizePostUsageBilling 异步 goroutine 写 DB 失败累计次数; +// legacyPathErr:postUsageBilling fallback 路径写 DB 失败累计次数。 +// ops 监控面板可以按"持续上升斜率"做告警阈值。 +func GatewayUserPlatformQuotaIncrStats() (mainPathErr, legacyPathErr int64) { + return userPlatformQuotaDBIncrErrorTotal.Load(), + userPlatformQuotaDBIncrLegacyErrorTotal.Load() +} + func openAIStreamEventIsTerminal(data string) bool { trimmed := strings.TrimSpace(data) if trimmed == "" { @@ -575,6 +594,7 @@ type GatewayService struct { debugGatewayBodyFile atomic.Pointer[os.File] // non-nil when SUB2API_DEBUG_GATEWAY_BODY is set tlsFPProfileService *TLSFingerprintProfileService balanceNotifyService *BalanceNotifyService + userPlatformQuotaRepo UserPlatformQuotaRepository } // NewGatewayService creates a new GatewayService @@ -605,41 +625,43 @@ func NewGatewayService( channelService *ChannelService, resolver *ModelPricingResolver, balanceNotifyService *BalanceNotifyService, + userPlatformQuotaRepo UserPlatformQuotaRepository, ) *GatewayService { userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg) modelsListTTL := resolveModelsListCacheTTL(cfg) svc := &GatewayService{ - accountRepo: accountRepo, - groupRepo: groupRepo, - usageLogRepo: usageLogRepo, - usageBillingRepo: usageBillingRepo, - userRepo: userRepo, - userSubRepo: userSubRepo, - userGroupRateRepo: userGroupRateRepo, - cache: cache, - digestStore: digestStore, - cfg: cfg, - schedulerSnapshot: schedulerSnapshot, - concurrencyService: concurrencyService, - billingService: billingService, - rateLimitService: rateLimitService, - billingCacheService: billingCacheService, - identityService: identityService, - httpUpstream: httpUpstream, - deferredService: deferredService, - claudeTokenProvider: claudeTokenProvider, - sessionLimitCache: sessionLimitCache, - rpmCache: rpmCache, - userGroupRateCache: gocache.New(userGroupRateTTL, time.Minute), - settingService: settingService, - modelsListCache: gocache.New(modelsListTTL, time.Minute), - modelsListCacheTTL: modelsListTTL, - responseHeaderFilter: compileResponseHeaderFilter(cfg), - tlsFPProfileService: tlsFPProfileService, - channelService: channelService, - resolver: resolver, - balanceNotifyService: balanceNotifyService, + accountRepo: accountRepo, + groupRepo: groupRepo, + usageLogRepo: usageLogRepo, + usageBillingRepo: usageBillingRepo, + userRepo: userRepo, + userSubRepo: userSubRepo, + userGroupRateRepo: userGroupRateRepo, + cache: cache, + digestStore: digestStore, + cfg: cfg, + schedulerSnapshot: schedulerSnapshot, + concurrencyService: concurrencyService, + billingService: billingService, + rateLimitService: rateLimitService, + billingCacheService: billingCacheService, + identityService: identityService, + httpUpstream: httpUpstream, + deferredService: deferredService, + claudeTokenProvider: claudeTokenProvider, + sessionLimitCache: sessionLimitCache, + rpmCache: rpmCache, + userGroupRateCache: gocache.New(userGroupRateTTL, time.Minute), + settingService: settingService, + modelsListCache: gocache.New(modelsListTTL, time.Minute), + modelsListCacheTTL: modelsListTTL, + responseHeaderFilter: compileResponseHeaderFilter(cfg), + tlsFPProfileService: tlsFPProfileService, + channelService: channelService, + resolver: resolver, + balanceNotifyService: balanceNotifyService, + userPlatformQuotaRepo: userPlatformQuotaRepo, } svc.userGroupRateResolver = newUserGroupRateResolver( userGroupRateRepo, @@ -7949,6 +7971,7 @@ type RecordUsageInput struct { RequestPayloadHash string // 请求体语义哈希,用于降低 request_id 误复用时的静默误去重风险 ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换) APIKeyService APIKeyQuotaUpdater // 可选:用于更新API Key配额 + QuotaPlatform string // user×platform 配额计量平台:handler 在请求 ctx 内经 QuotaPlatform() 算定后传入(后扣运行在 worker 池 background ctx 上,取不到 ForcePlatform) ChannelUsageFields // 渠道映射信息(由 handler 在 Forward 前解析) } @@ -7978,6 +8001,31 @@ type postUsageBillingParams struct { IsSubscriptionBill bool AccountRateMultiplier float64 APIKeyService APIKeyQuotaUpdater + Platform string // 来自 APIKey 关联 Group 的平台标识 +} + +// PlatformFromAPIKey 从 APIKey 关联的 Group 推导 platform 名称。 +// apiKey 为 nil 或 Group 信息缺失时返回空串(调用方据此 short-circuit quota 累加)。 +// 导出供 handler 层调用。 +func PlatformFromAPIKey(apiKey *APIKey) string { + if apiKey == nil || apiKey.Group == nil { + return "" + } + return apiKey.Group.Platform +} + +// QuotaPlatform 返回 user×platform 配额计量使用的平台标识。 +// 强制平台路由(如 /antigravity)优先按 ctx 中的 ForcePlatform 计量,否则回退到 +// APIKey 关联 Group 的平台。 +// +// 注意:必须用带 ForcePlatform 的请求 context 调用(如 handler 的 c.Request.Context())。 +// 后扣运行在 worker 池的 background ctx 上没有 ForcePlatform,因此后扣平台由 handler +// 预先算定、经 RecordUsageInput.QuotaPlatform 传入,不要在后扣链路用 worker ctx 调用本函数。 +func QuotaPlatform(ctx context.Context, apiKey *APIKey) string { + if fp, ok := ctx.Value(ctxkey.ForcePlatform).(string); ok && fp != "" { + return fp + } + return PlatformFromAPIKey(apiKey) } func (p *postUsageBillingParams) shouldDeductAPIKeyQuota() bool { @@ -8036,6 +8084,21 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill } } + // Platform quota DB-only 累加(与 finalizePostUsageBilling 行为对齐的兜底): + // - 仅对 standard(余额)模式生效;订阅模式豁免 + // - 直接走 DB,不经 Redis Incr 队列:legacy 路径在 repo==nil(仓库未注入) + // 时被触发,此时整套 billing repo 都不可用,没有"双队列"风险 + // - 失败仅记 ALERT log + counter,不阻断主扣费流程;与正常路径一致 + // + // 历史背景:原 legacy path 完全跳过此累加,导致部署中如果 repo 偶然为 nil + // 时用户消费可绕过 platform quota,存在静默资金风险。 + if !p.IsSubscriptionBill && p.Platform != "" && cost.ActualCost > 0 && p.User != nil && deps.userPlatformQuotaRepo != nil { + if err := deps.userPlatformQuotaRepo.IncrementUsageWithReset(billingCtx, p.User.ID, p.Platform, cost.ActualCost, time.Now().UTC()); err != nil { + userPlatformQuotaDBIncrLegacyErrorTotal.Add(1) + logger.LegacyPrintf("service.gateway", "ALERT: legacy incr user platform quota DB failed user=%d platform=%s cost=%f: %v", p.User.ID, p.Platform, cost.ActualCost, err) + } + } + // NOTE: finalizePostUsageBilling is NOT called here to avoid double-queuing // cache updates. The legacy path does DB writes directly; the finalize path // does cache queue + notifications. Notifications are dispatched separately @@ -8159,11 +8222,11 @@ func applyUsageBilling(ctx context.Context, requestID string, usageLog *UsageLog } } - finalizePostUsageBilling(p, deps, result) + finalizePostUsageBilling(billingCtx, p, deps, result) return true, nil } -func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps, result *UsageBillingApplyResult) { +func finalizePostUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *billingDeps, result *UsageBillingApplyResult) { if p == nil || p.Cost == nil || deps == nil { return } @@ -8182,6 +8245,32 @@ func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps, resu deps.deferredService.ScheduleLastUsedUpdate(p.Account.ID) + // Platform quota 累加:仅在 standard(余额)模式生效;订阅模式豁免 + // Redis 同步写 + DB 异步持久化: + // - Redis 同步:确保下次 preflight 立即看到最新 usage,把 TOCTOU 超支窗口 + // 限制在并发 in-flight 请求数量内(旧实现的异步入队会让超支无限累积直到 worker 处理) + // - DB 异步:在独立 goroutine 中走 detached context,失败用 ALERT log 触发 oncall 对账 + if !p.IsSubscriptionBill && p.Platform != "" && p.Cost.ActualCost > 0 && p.User != nil && deps.userPlatformQuotaRepo != nil { + deps.billingCacheService.IncrementUserPlatformQuotaUsage(p.User.ID, p.Platform, p.Cost.ActualCost) + dbCtx, dbCancel := detachUpstreamContext(ctx) + userID, platform, cost := p.User.ID, p.Platform, p.Cost.ActualCost + go func() { + defer func() { + if r := recover(); r != nil { + logger.LegacyPrintf("service.gateway", "ALERT: panic in user platform quota incr goroutine user=%d platform=%s: %v", userID, platform, r) + } + }() + defer dbCancel() + if err := deps.userPlatformQuotaRepo.IncrementUsageWithReset(dbCtx, userID, platform, cost, time.Now().UTC()); err != nil { + // 失败计数器:暴露给 GatewayUserPlatformQuotaIncrStats(),由 ops 面板做斜率告警。 + userPlatformQuotaDBIncrErrorTotal.Add(1) + // ALERT 级别:DB 持久化失败意味着 Redis cache 失效后该笔 cost 永久丢失, + // 用户配额视图与实际消费会偏差,oncall 需要据此对账或人工补录。 + logger.LegacyPrintf("service.gateway", "ALERT: incr user platform quota DB failed user=%d platform=%s cost=%f: %v", userID, platform, cost, err) + } + }() + } + // Notification checks run async — all parameters are already captured, // no dependency on the request context or upstream connection. go notifyBalanceLow(p, deps, result) @@ -8287,22 +8376,24 @@ func detachUpstreamContext(ctx context.Context) (context.Context, context.Cancel // billingDeps 扣费逻辑依赖的服务(由各 gateway service 提供) type billingDeps struct { - accountRepo AccountRepository - userRepo UserRepository - userSubRepo UserSubscriptionRepository - billingCacheService *BillingCacheService - deferredService *DeferredService - balanceNotifyService *BalanceNotifyService + accountRepo AccountRepository + userRepo UserRepository + userSubRepo UserSubscriptionRepository + billingCacheService *BillingCacheService + deferredService *DeferredService + balanceNotifyService *BalanceNotifyService + userPlatformQuotaRepo UserPlatformQuotaRepository } func (s *GatewayService) billingDeps() *billingDeps { return &billingDeps{ - accountRepo: s.accountRepo, - userRepo: s.userRepo, - userSubRepo: s.userSubRepo, - billingCacheService: s.billingCacheService, - deferredService: s.deferredService, - balanceNotifyService: s.balanceNotifyService, + accountRepo: s.accountRepo, + userRepo: s.userRepo, + userSubRepo: s.userSubRepo, + billingCacheService: s.billingCacheService, + deferredService: s.deferredService, + balanceNotifyService: s.balanceNotifyService, + userPlatformQuotaRepo: s.userPlatformQuotaRepo, } } @@ -8360,6 +8451,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu RequestPayloadHash: input.RequestPayloadHash, ForceCacheBilling: input.ForceCacheBilling, APIKeyService: input.APIKeyService, + QuotaPlatform: input.QuotaPlatform, ChannelUsageFields: input.ChannelUsageFields, }, &recordUsageOpts{ EnableClaudePath: true, @@ -8382,6 +8474,7 @@ type RecordUsageLongContextInput struct { LongContextMultiplier float64 // 超出阈值部分的倍率(如 2.0) ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换) APIKeyService APIKeyQuotaUpdater // API Key 配额服务(可选) + QuotaPlatform string // user×platform 配额计量平台:handler 在请求 ctx 内经 QuotaPlatform() 算定后传入(后扣运行在 worker 池 background ctx 上,取不到 ForcePlatform) ChannelUsageFields // 渠道映射信息(由 handler 在 Forward 前解析) } @@ -8401,6 +8494,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * RequestPayloadHash: input.RequestPayloadHash, ForceCacheBilling: input.ForceCacheBilling, APIKeyService: input.APIKeyService, + QuotaPlatform: input.QuotaPlatform, ChannelUsageFields: input.ChannelUsageFields, }, &recordUsageOpts{ LongContextThreshold: input.LongContextThreshold, @@ -8422,6 +8516,7 @@ type recordUsageCoreInput struct { RequestPayloadHash string ForceCacheBilling bool APIKeyService APIKeyQuotaUpdater + QuotaPlatform string ChannelUsageFields } @@ -8519,6 +8614,13 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage return nil } + // 配额平台由 handler 在请求 ctx 内经 QuotaPlatform() 算定并通过 input 传入; + // 后扣运行在 worker 池的 background ctx 上,无法再从 ctx 取 ForcePlatform。 + // 缺省(未设置)时回退到分组平台,保持对其它调用方的兼容。 + quotaPlatform := input.QuotaPlatform + if quotaPlatform == "" { + quotaPlatform = PlatformFromAPIKey(apiKey) + } requestID := usageLog.RequestID _, billingErr := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{ Cost: cost, @@ -8530,6 +8632,7 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage IsSubscriptionBill: isSubscriptionBilling, AccountRateMultiplier: accountRateMultiplier, APIKeyService: input.APIKeyService, + Platform: quotaPlatform, }, s.billingDeps(), s.usageBillingRepo) if billingErr != nil { diff --git a/backend/internal/service/openai_gateway_record_usage_test.go b/backend/internal/service/openai_gateway_record_usage_test.go index 096f5b10..9769a82e 100644 --- a/backend/internal/service/openai_gateway_record_usage_test.go +++ b/backend/internal/service/openai_gateway_record_usage_test.go @@ -155,6 +155,7 @@ func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo U nil, nil, nil, + nil, // userPlatformQuotaRepo ) svc.userGroupRateResolver = newUserGroupRateResolver( rateRepo, diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index f312f50d..2f81457f 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -343,6 +343,7 @@ type OpenAIGatewayService struct { channelService *ChannelService balanceNotifyService *BalanceNotifyService settingService *SettingService + userPlatformQuotaRepo UserPlatformQuotaRepository openaiWSPoolOnce sync.Once openaiWSStateStoreOnce sync.Once @@ -387,6 +388,7 @@ func NewOpenAIGatewayService( channelService *ChannelService, balanceNotifyService *BalanceNotifyService, settingService *SettingService, + userPlatformQuotaRepo UserPlatformQuotaRepository, ) *OpenAIGatewayService { svc := &OpenAIGatewayService{ accountRepo: accountRepo, @@ -418,6 +420,7 @@ func NewOpenAIGatewayService( channelService: channelService, balanceNotifyService: balanceNotifyService, settingService: settingService, + userPlatformQuotaRepo: userPlatformQuotaRepo, responseHeaderFilter: compileResponseHeaderFilter(cfg), codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval), } @@ -523,12 +526,13 @@ func (s *OpenAIGatewayService) getCodexSnapshotThrottle() *accountWriteThrottle func (s *OpenAIGatewayService) billingDeps() *billingDeps { return &billingDeps{ - accountRepo: s.accountRepo, - userRepo: s.userRepo, - userSubRepo: s.userSubRepo, - billingCacheService: s.billingCacheService, - deferredService: s.deferredService, - balanceNotifyService: s.balanceNotifyService, + accountRepo: s.accountRepo, + userRepo: s.userRepo, + userSubRepo: s.userSubRepo, + billingCacheService: s.billingCacheService, + deferredService: s.deferredService, + balanceNotifyService: s.balanceNotifyService, + userPlatformQuotaRepo: s.userPlatformQuotaRepo, } } @@ -5634,6 +5638,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec IsSubscriptionBill: isSubscriptionBilling, AccountRateMultiplier: accountRateMultiplier, APIKeyService: input.APIKeyService, + Platform: PlatformFromAPIKey(apiKey), }, s.billingDeps(), s.usageBillingRepo) return err }() diff --git a/backend/internal/service/openai_ws_protocol_forward_test.go b/backend/internal/service/openai_ws_protocol_forward_test.go index f3936de1..99e92251 100644 --- a/backend/internal/service/openai_ws_protocol_forward_test.go +++ b/backend/internal/service/openai_ws_protocol_forward_test.go @@ -619,6 +619,7 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) { nil, nil, nil, + nil, // userPlatformQuotaRepo ) decision := svc.getOpenAIWSProtocolResolver().Resolve(nil) diff --git a/backend/internal/service/post_billing_platform_test.go b/backend/internal/service/post_billing_platform_test.go new file mode 100644 index 00000000..a3ce0676 --- /dev/null +++ b/backend/internal/service/post_billing_platform_test.go @@ -0,0 +1,81 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" +) + +func TestPlatformFromAPIKey_NilSafe(t *testing.T) { + if got := PlatformFromAPIKey(nil); got != "" { + t.Errorf("nil APIKey should yield empty string, got %q", got) + } +} + +func TestPlatformFromAPIKey_NilGroup(t *testing.T) { + k := &APIKey{Group: nil} + if got := PlatformFromAPIKey(k); got != "" { + t.Errorf("APIKey with nil Group should yield empty string, got %q", got) + } +} + +func TestPlatformFromAPIKey_DerivesFromGroup(t *testing.T) { + tests := []struct { + name string + platform string + }{ + {"anthropic", "anthropic"}, + {"openai", "openai"}, + {"gemini", "gemini"}, + {"antigravity", "antigravity"}, + {"empty", ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + k := &APIKey{ + Group: &Group{Platform: tt.platform}, + } + got := PlatformFromAPIKey(k) + if got != tt.platform { + t.Errorf("PlatformFromAPIKey(%q) = %q, want %q", tt.platform, got, tt.platform) + } + }) + } +} + +// TestQuotaPlatform 锁定配额计量口径:ForcePlatform 路由(如 /antigravity)按 ForcePlatform 计, +// 否则回退到 Group 平台。preflight 与 post-billing 共用此口径,保证一致。 +func TestQuotaPlatform(t *testing.T) { + apiKey := &APIKey{Group: &Group{Platform: PlatformAnthropic}} + + t.Run("no force platform falls back to group platform", func(t *testing.T) { + if got := QuotaPlatform(context.Background(), apiKey); got != PlatformAnthropic { + t.Errorf("QuotaPlatform without force = %q, want %q", got, PlatformAnthropic) + } + }) + + t.Run("force platform overrides group platform", func(t *testing.T) { + ctx := context.WithValue(context.Background(), ctxkey.ForcePlatform, PlatformAntigravity) + if got := QuotaPlatform(ctx, apiKey); got != PlatformAntigravity { + t.Errorf("QuotaPlatform with force = %q, want %q", got, PlatformAntigravity) + } + }) + + t.Run("empty force platform falls back to group platform", func(t *testing.T) { + ctx := context.WithValue(context.Background(), ctxkey.ForcePlatform, "") + if got := QuotaPlatform(ctx, apiKey); got != PlatformAnthropic { + t.Errorf("QuotaPlatform with empty force = %q, want %q", got, PlatformAnthropic) + } + }) + + t.Run("nil api key with force platform returns force platform", func(t *testing.T) { + ctx := context.WithValue(context.Background(), ctxkey.ForcePlatform, PlatformAntigravity) + if got := QuotaPlatform(ctx, nil); got != PlatformAntigravity { + t.Errorf("QuotaPlatform(nil) with force = %q, want %q", got, PlatformAntigravity) + } + }) +} diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index 5eef2c13..e6f0f2bc 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -165,12 +165,20 @@ type SettingService struct { openAICodexUASF singleflight.Group } +// DefaultPlatformQuotaSetting 单 platform 三档限额(nil = 沿用上层;0 = 显式禁用;>0 = 上限) +type DefaultPlatformQuotaSetting struct { + DailyLimitUSD *float64 `json:"daily"` + WeeklyLimitUSD *float64 `json:"weekly"` + MonthlyLimitUSD *float64 `json:"monthly"` +} + type ProviderDefaultGrantSettings struct { Balance float64 Concurrency int Subscriptions []DefaultSubscriptionSetting GrantOnSignup bool GrantOnFirstBind bool + PlatformQuotas map[string]*DefaultPlatformQuotaSetting // key = platform name } type AuthSourceDefaultSettings struct { @@ -185,62 +193,80 @@ type AuthSourceDefaultSettings struct { } type authSourceDefaultKeySet struct { + // source 是 auth source 标识(如 "email"、"github"),仅用于 parse 时 + // slog.Warn 诊断输出,不再参与 key 拼接(platformQuotas 字段已存完整 key)。 + source string balance string concurrency string subscriptions string grantOnSignup string grantOnFirstBind string + platformQuotas string // SettingKeyAuthSourcePlatformQuotas(source) } var ( emailAuthSourceDefaultKeys = authSourceDefaultKeySet{ + source: "email", balance: SettingKeyAuthSourceDefaultEmailBalance, concurrency: SettingKeyAuthSourceDefaultEmailConcurrency, subscriptions: SettingKeyAuthSourceDefaultEmailSubscriptions, grantOnSignup: SettingKeyAuthSourceDefaultEmailGrantOnSignup, grantOnFirstBind: SettingKeyAuthSourceDefaultEmailGrantOnFirstBind, + platformQuotas: SettingKeyAuthSourcePlatformQuotas("email"), } linuxDoAuthSourceDefaultKeys = authSourceDefaultKeySet{ + source: "linuxdo", balance: SettingKeyAuthSourceDefaultLinuxDoBalance, concurrency: SettingKeyAuthSourceDefaultLinuxDoConcurrency, subscriptions: SettingKeyAuthSourceDefaultLinuxDoSubscriptions, grantOnSignup: SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup, grantOnFirstBind: SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind, + platformQuotas: SettingKeyAuthSourcePlatformQuotas("linuxdo"), } oidcAuthSourceDefaultKeys = authSourceDefaultKeySet{ + source: "oidc", balance: SettingKeyAuthSourceDefaultOIDCBalance, concurrency: SettingKeyAuthSourceDefaultOIDCConcurrency, subscriptions: SettingKeyAuthSourceDefaultOIDCSubscriptions, grantOnSignup: SettingKeyAuthSourceDefaultOIDCGrantOnSignup, grantOnFirstBind: SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind, + platformQuotas: SettingKeyAuthSourcePlatformQuotas("oidc"), } weChatAuthSourceDefaultKeys = authSourceDefaultKeySet{ + source: "wechat", balance: SettingKeyAuthSourceDefaultWeChatBalance, concurrency: SettingKeyAuthSourceDefaultWeChatConcurrency, subscriptions: SettingKeyAuthSourceDefaultWeChatSubscriptions, grantOnSignup: SettingKeyAuthSourceDefaultWeChatGrantOnSignup, grantOnFirstBind: SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind, + platformQuotas: SettingKeyAuthSourcePlatformQuotas("wechat"), } gitHubAuthSourceDefaultKeys = authSourceDefaultKeySet{ + source: "github", balance: SettingKeyAuthSourceDefaultGitHubBalance, concurrency: SettingKeyAuthSourceDefaultGitHubConcurrency, subscriptions: SettingKeyAuthSourceDefaultGitHubSubscriptions, grantOnSignup: SettingKeyAuthSourceDefaultGitHubGrantOnSignup, grantOnFirstBind: SettingKeyAuthSourceDefaultGitHubGrantOnFirstBind, + platformQuotas: SettingKeyAuthSourcePlatformQuotas("github"), } googleAuthSourceDefaultKeys = authSourceDefaultKeySet{ + source: "google", balance: SettingKeyAuthSourceDefaultGoogleBalance, concurrency: SettingKeyAuthSourceDefaultGoogleConcurrency, subscriptions: SettingKeyAuthSourceDefaultGoogleSubscriptions, grantOnSignup: SettingKeyAuthSourceDefaultGoogleGrantOnSignup, grantOnFirstBind: SettingKeyAuthSourceDefaultGoogleGrantOnFirstBind, + platformQuotas: SettingKeyAuthSourcePlatformQuotas("google"), } dingTalkAuthSourceDefaultKeys = authSourceDefaultKeySet{ + source: "dingtalk", balance: SettingKeyAuthSourceDefaultDingTalkBalance, concurrency: SettingKeyAuthSourceDefaultDingTalkConcurrency, subscriptions: SettingKeyAuthSourceDefaultDingTalkSubscriptions, grantOnSignup: SettingKeyAuthSourceDefaultDingTalkGrantOnSignup, grantOnFirstBind: SettingKeyAuthSourceDefaultDingTalkGrantOnFirstBind, + platformQuotas: SettingKeyAuthSourcePlatformQuotas("dingtalk"), } ) @@ -1804,9 +1830,41 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting updates[SettingKeyAccountQuotaNotifyEnabled] = strconv.FormatBool(settings.AccountQuotaNotifyEnabled) updates[SettingKeyAccountQuotaNotifyEmails] = MarshalNotifyEmails(settings.AccountQuotaNotifyEmails) + // 系统全局 platform quota:整体替换语义(null/缺省 = 不限制)。 + if settings.DefaultPlatformQuotas != nil { + if err := validateDefaultPlatformQuotaMap(settings.DefaultPlatformQuotas); err != nil { + return nil, err + } + blob, err := json.Marshal(settings.DefaultPlatformQuotas) + if err != nil { + return nil, fmt.Errorf("marshal default platform quotas: %w", err) + } + updates[SettingKeyDefaultPlatformQuotas] = string(blob) + } + return updates, nil } +// validateDefaultPlatformQuotaMap 校验 platform quota map 的合法性: +// 平台名须在 AllowedQuotaPlatforms 白名单内,每个非 nil 上限须 finite 且 >= 0。 +// 系统层和 auth-source 层共用此 helper。 +func validateDefaultPlatformQuotaMap(m map[string]*DefaultPlatformQuotaSetting) error { + for platform, pq := range m { + if !IsAllowedQuotaPlatform(platform) { + return infraerrors.BadRequest("INVALID_DEFAULT_PLATFORM_QUOTA", fmt.Sprintf("unknown platform %q", platform)) + } + if pq == nil { + continue + } + for _, v := range []*float64{pq.DailyLimitUSD, pq.WeeklyLimitUSD, pq.MonthlyLimitUSD} { + if v != nil && (*v < 0 || math.IsNaN(*v) || math.IsInf(*v, 0)) { + return infraerrors.BadRequest("INVALID_DEFAULT_PLATFORM_QUOTA", "platform quota limit must be a finite non-negative number") + } + } + } + return nil +} + func (s *SettingService) buildAuthSourceDefaultUpdates(ctx context.Context, settings *AuthSourceDefaultSettings) (map[string]string, error) { if settings == nil { return nil, nil @@ -1826,6 +1884,26 @@ func (s *SettingService) buildAuthSourceDefaultUpdates(ctx context.Context, sett } } + // 校验各 auth source 的 platform quota map(改动 C:对等系统层校验) + for _, pgs := range []struct { + name string + pq map[string]*DefaultPlatformQuotaSetting + }{ + {"email", settings.Email.PlatformQuotas}, + {"linuxdo", settings.LinuxDo.PlatformQuotas}, + {"oidc", settings.OIDC.PlatformQuotas}, + {"wechat", settings.WeChat.PlatformQuotas}, + {"github", settings.GitHub.PlatformQuotas}, + {"google", settings.Google.PlatformQuotas}, + {"dingtalk", settings.DingTalk.PlatformQuotas}, + } { + if pgs.pq != nil { + if err := validateDefaultPlatformQuotaMap(pgs.pq); err != nil { + return nil, err + } + } + } + updates := make(map[string]string, 36) writeProviderDefaultGrantUpdates(updates, emailAuthSourceDefaultKeys, settings.Email) writeProviderDefaultGrantUpdates(updates, linuxDoAuthSourceDefaultKeys, settings.LinuxDo) @@ -2386,6 +2464,13 @@ func (s *SettingService) GetAuthSourceDefaultSettings(ctx context.Context) (*Aut SettingKeyAuthSourceDefaultDingTalkSubscriptions, SettingKeyAuthSourceDefaultDingTalkGrantOnSignup, SettingKeyAuthSourceDefaultDingTalkGrantOnFirstBind, + SettingKeyAuthSourcePlatformQuotas("email"), + SettingKeyAuthSourcePlatformQuotas("linuxdo"), + SettingKeyAuthSourcePlatformQuotas("oidc"), + SettingKeyAuthSourcePlatformQuotas("wechat"), + SettingKeyAuthSourcePlatformQuotas("github"), + SettingKeyAuthSourcePlatformQuotas("google"), + SettingKeyAuthSourcePlatformQuotas("dingtalk"), SettingKeyForceEmailOnThirdPartySignup, } @@ -3179,6 +3264,16 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin result.AccountQuotaNotifyEmails = []NotifyEmailEntry{} } + // 系统层默认 platform quota(修复 Bug B:parseSettings 不填充导致回显恒为 nil) + if raw := settings[SettingKeyDefaultPlatformQuotas]; raw != "" { + parsed := map[string]*DefaultPlatformQuotaSetting{} + if err := json.Unmarshal([]byte(raw), &parsed); err != nil { + slog.Warn("[Setting] parseSettings: unmarshal default_platform_quotas failed", "error", err) + } else { + result.DefaultPlatformQuotas = parsed + } + } + return result } @@ -3271,6 +3366,15 @@ func parseProviderDefaultGrantSettings(settings map[string]string, keys authSour result.GrantOnFirstBind = raw == "true" } + if raw := settings[keys.platformQuotas]; raw != "" { + parsed := map[string]*DefaultPlatformQuotaSetting{} + if err := json.Unmarshal([]byte(raw), &parsed); err != nil { + slog.Warn("[Setting] parseProviderDefaultGrantSettings: unmarshal auth source platform quotas failed", "source", keys.source, "error", err) + } else { + result.PlatformQuotas = parsed + } + } + return result } @@ -3289,6 +3393,17 @@ func writeProviderDefaultGrantUpdates(updates map[string]string, keys authSource updates[keys.subscriptions] = string(raw) updates[keys.grantOnSignup] = strconv.FormatBool(settings.GrantOnSignup) updates[keys.grantOnFirstBind] = strconv.FormatBool(settings.GrantOnFirstBind) + + // auth source platform quota:整体替换语义。 + // nil = 请求未携带该字段,跳过写入以保留既有配置(与系统层 buildSystemSettingsUpdates 的 + // DefaultPlatformQuotas nil 守卫一致);非 nil(含空 map)才整体替换。二者语义不可混同。 + if keys.platformQuotas != "" && settings.PlatformQuotas != nil { + blob, err := json.Marshal(settings.PlatformQuotas) + if err != nil { + blob = []byte("{}") + } + updates[keys.platformQuotas] = string(blob) + } } func mergeProviderDefaultGrantSettings(globalDefaults ProviderDefaultGrantSettings, providerDefaults ProviderDefaultGrantSettings) ProviderDefaultGrantSettings { @@ -4493,3 +4608,63 @@ func (s *SettingService) SetStreamTimeoutSettings(ctx context.Context, settings return s.settingRepo.Set(ctx, SettingKeyStreamTimeoutSettings, string(data)) } + +// GetDefaultPlatformQuotas 读取系统全局 platform quota JSON key,返回 4 platform x 3 window 的设置。 +// 永远返回包含全部 4 platform key 的 map(值可能为零值/nil 字段,表示"上层未配置 = 不限制")。 +// +// 使用单个 JSON key(default_platform_quotas),一次 DB roundtrip,消除旧 12-KV 格式的 N+1 问题。 +// 容错语义:取值失败或 unmarshal 失败 → 返回补齐 4 key 的空 map(fail-open,注册不被阻断)。 +func (s *SettingService) GetDefaultPlatformQuotas(ctx context.Context) (map[string]*DefaultPlatformQuotaSetting, error) { + out := map[string]*DefaultPlatformQuotaSetting{ + "anthropic": {}, + "openai": {}, + "gemini": {}, + "antigravity": {}, + } + raw, err := s.settingRepo.GetValue(ctx, SettingKeyDefaultPlatformQuotas) + if err != nil || raw == "" { + return out, nil // 无配置 = 全部不限制 + } + parsed := map[string]*DefaultPlatformQuotaSetting{} + if err := json.Unmarshal([]byte(raw), &parsed); err != nil { + slog.Warn("[Setting] unmarshal default_platform_quotas failed (fail-open)", "error", err) + return out, nil + } + for _, platform := range AllowedQuotaPlatforms { + if v := parsed[platform]; v != nil { + out[platform] = v + } + } + return out, nil // 补齐 4 platform key,保持与旧实现一致的下游契约 +} + +// GetAuthSourcePlatformQuotas 读取指定 auth source 的 platform quota 覆盖(仅返回有配置的平台,override 语义)。 +func (s *SettingService) GetAuthSourcePlatformQuotas(ctx context.Context, source string) map[string]*DefaultPlatformQuotaSetting { + out := map[string]*DefaultPlatformQuotaSetting{} + raw, err := s.settingRepo.GetValue(ctx, SettingKeyAuthSourcePlatformQuotas(source)) + if err != nil || raw == "" { + return out // 无 override + } + if err := json.Unmarshal([]byte(raw), &out); err != nil { + slog.Warn("[Setting] unmarshal auth source platform quotas failed (fail-open)", "source", source, "error", err) + return map[string]*DefaultPlatformQuotaSetting{} + } + return out // 仅含已配置平台,保持 override 语义 +} + +// mergePlatformQuotaDefaults 按字段级 patch:src 中非 nil 字段覆盖 dst。 +// 区分 nil("未配置",保留 dst)vs &0.0("显式禁用",覆盖 dst 为 0) +func mergePlatformQuotaDefaults(dst, src *DefaultPlatformQuotaSetting) { + if src == nil || dst == nil { + return + } + if src.DailyLimitUSD != nil { + dst.DailyLimitUSD = src.DailyLimitUSD + } + if src.WeeklyLimitUSD != nil { + dst.WeeklyLimitUSD = src.WeeklyLimitUSD + } + if src.MonthlyLimitUSD != nil { + dst.MonthlyLimitUSD = src.MonthlyLimitUSD + } +} diff --git a/backend/internal/service/setting_service_platform_quota_test.go b/backend/internal/service/setting_service_platform_quota_test.go new file mode 100644 index 00000000..557cc5f1 --- /dev/null +++ b/backend/internal/service/setting_service_platform_quota_test.go @@ -0,0 +1,344 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/stretchr/testify/require" +) + +func TestMergePlatformQuotaDefaults_PatchSemantics(t *testing.T) { + five := 5.0 + base := DefaultPlatformQuotaSetting{ + DailyLimitUSD: &five, + WeeklyLimitUSD: &five, + } + ten := 10.0 + patch := DefaultPlatformQuotaSetting{DailyLimitUSD: &ten} + + mergePlatformQuotaDefaults(&base, &patch) + if base.DailyLimitUSD == nil || *base.DailyLimitUSD != 10.0 { + t.Errorf("daily not patched: %+v", base.DailyLimitUSD) + } + if base.WeeklyLimitUSD == nil || *base.WeeklyLimitUSD != 5.0 { + t.Errorf("weekly should remain 5.0: %+v", base.WeeklyLimitUSD) + } +} + +func TestMergePlatformQuotaDefaults_ZeroIsExplicitDisable(t *testing.T) { + five := 5.0 + base := DefaultPlatformQuotaSetting{DailyLimitUSD: &five} + zero := 0.0 + patch := DefaultPlatformQuotaSetting{DailyLimitUSD: &zero} + + mergePlatformQuotaDefaults(&base, &patch) + if base.DailyLimitUSD == nil || *base.DailyLimitUSD != 0 { + t.Errorf("explicit 0 should patch base, got %+v", base.DailyLimitUSD) + } +} + +func TestMergePlatformQuotaDefaults_NilSrcIsNoop(t *testing.T) { + five := 5.0 + base := DefaultPlatformQuotaSetting{DailyLimitUSD: &five} + mergePlatformQuotaDefaults(&base, nil) + if base.DailyLimitUSD == nil || *base.DailyLimitUSD != 5.0 { + t.Errorf("nil src should be no-op: %+v", base.DailyLimitUSD) + } +} + +func floatPtrPQ(v float64) *float64 { return &v } + +func newSettingServiceForPlatformQuotaTest(seed map[string]string) *SettingService { + repo := newMockSettingRepo() + for k, v := range seed { + repo.data[k] = v + } + return NewSettingService(repo, &config.Config{}) +} + +func TestGetDefaultPlatformQuotas_ReturnsFourPlatforms(t *testing.T) { + zero := 0.0 + svc := newSettingServiceForPlatformQuotaTest(map[string]string{ + // 新 JSON 格式:anthropic daily=10.5, openai monthly=0, gemini/antigravity 无配置 + SettingKeyDefaultPlatformQuotas: `{"anthropic":{"daily":10.5},"openai":{"monthly":0}}`, + }) + got, err := svc.GetDefaultPlatformQuotas(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + // 必须包含全部 4 个 platform key(补齐契约) + for _, platform := range []string{"anthropic", "openai", "gemini", "antigravity"} { + if _, ok := got[platform]; !ok { + t.Errorf("missing platform key: %q", platform) + } + } + // anthropic daily = 10.5 + if v := got["anthropic"].DailyLimitUSD; v == nil || *v != 10.5 { + t.Errorf("anthropic daily want 10.5, got %v", v) + } + // openai monthly = 0(显式禁用) + if v := got["openai"].MonthlyLimitUSD; v == nil || *v != zero { + t.Errorf("openai monthly want 0 (explicit disable), got %v", v) + } + // gemini 无配置 → weekly = nil + if v := got["gemini"].WeeklyLimitUSD; v != nil { + t.Errorf("gemini weekly want nil (not configured), got %v", *v) + } + // antigravity 无配置 → daily = nil + if v := got["antigravity"].DailyLimitUSD; v != nil { + t.Errorf("antigravity daily want nil (not configured), got %v", *v) + } +} + +func TestGetAuthSourcePlatformQuotas_OnlyConfiguredReturned(t *testing.T) { + source := "email" + // 新 JSON 格式:anthropic daily=5, monthly=100;openai weekly=0;gemini/antigravity 无配置 + svc := newSettingServiceForPlatformQuotaTest(map[string]string{ + SettingKeyAuthSourcePlatformQuotas(source): `{"anthropic":{"daily":5,"monthly":100},"openai":{"weekly":0}}`, + }) + got := svc.GetAuthSourcePlatformQuotas(context.Background(), source) + + // anthropic 有配置 → 在结果中 + anthro, ok := got["anthropic"] + if !ok { + t.Fatal("expected anthropic to be present") + } + if anthro.DailyLimitUSD == nil || *anthro.DailyLimitUSD != 5.0 { + t.Errorf("anthropic daily want 5.0, got %v", anthro.DailyLimitUSD) + } + if anthro.MonthlyLimitUSD == nil || *anthro.MonthlyLimitUSD != 100.0 { + t.Errorf("anthropic monthly want 100.0, got %v", anthro.MonthlyLimitUSD) + } + if anthro.WeeklyLimitUSD != nil { + t.Errorf("anthropic weekly not configured, want nil, got %v", *anthro.WeeklyLimitUSD) + } + + // openai weekly=0 → 在结果中 + oai, ok := got["openai"] + if !ok { + t.Fatal("expected openai to be present") + } + if oai.WeeklyLimitUSD == nil || *oai.WeeklyLimitUSD != 0 { + t.Errorf("openai weekly want 0, got %v", oai.WeeklyLimitUSD) + } + + // gemini / antigravity 无配置 → 不在结果中(override 语义) + if _, ok := got["gemini"]; ok { + t.Error("gemini not configured, should be absent from result") + } + if _, ok := got["antigravity"]; ok { + t.Error("antigravity not configured, should be absent from result") + } +} + +func TestGetAuthSourcePlatformQuotas_AllNegativeOrEmpty_NoEntry(t *testing.T) { + source := "linuxdo" + // 新 JSON 格式:未配置任何平台(空 JSON key)→ 返回空 map + svc := newSettingServiceForPlatformQuotaTest(map[string]string{ + SettingKeyAuthSourcePlatformQuotas(source): `{}`, + }) + got := svc.GetAuthSourcePlatformQuotas(context.Background(), source) + // 空 map → override 语义,无 openai 条目 + if _, ok := got["openai"]; ok { + t.Error("empty JSON object should result in no openai entry") + } + if len(got) != 0 { + t.Errorf("expected empty result map, got %v", got) + } +} + +// TestSystemPlatformQuotas_WriteReadRoundTrip 验证系统层 platform quota 经 buildSystemSettingsUpdates(写) +// 再由 GetDefaultPlatformQuotas(读)正确往返——覆盖真实 write→read 路径,锁住 4-key 补齐契约。 +func TestSystemPlatformQuotas_WriteReadRoundTrip(t *testing.T) { + svc := newSettingServiceForPlatformQuotaTest(nil) + ctx := context.Background() + + ten := 10.0 + ss := &SystemSettings{ + DefaultPlatformQuotas: map[string]*DefaultPlatformQuotaSetting{ + "anthropic": {DailyLimitUSD: &ten, WeeklyLimitUSD: nil, MonthlyLimitUSD: nil}, + }, + } + if err := svc.UpdateSettings(ctx, ss); err != nil { + t.Fatalf("UpdateSettings: %v", err) + } + + got, err := svc.GetDefaultPlatformQuotas(ctx) + if err != nil { + t.Fatal(err) + } + // 4-key 补齐契约:无论写了几个 platform,读回必须含全部 4 个 + for _, p := range []string{"anthropic", "openai", "gemini", "antigravity"} { + if _, ok := got[p]; !ok { + t.Errorf("4-key contract violated: missing platform %q", p) + } + } + // 写入值正确往返 + if v := got["anthropic"].DailyLimitUSD; v == nil || *v != ten { + t.Fatalf("anthropic daily round-trip failed: got %v, want 10", v) + } + // 未写入的平台字段为 nil + if got["openai"].DailyLimitUSD != nil { + t.Errorf("openai daily should be nil (not written), got %v", got["openai"].DailyLimitUSD) + } +} + +// TestSystemPlatformQuotas_EmptyMapClearsAll 验证空 map 的整体替换语义: +// 写入 DefaultPlatformQuotas={} 后,GetDefaultPlatformQuotas 返回 4 个平台、所有字段均为 nil, +// 明确文档化"空 map = 清空全部配额"是有意为之的 whole-replace 语义。 +func TestSystemPlatformQuotas_EmptyMapClearsAll(t *testing.T) { + svc := newSettingServiceForPlatformQuotaTest(nil) + ctx := context.Background() + + // 先写入有值的配置 + ten := 10.0 + if err := svc.UpdateSettings(ctx, &SystemSettings{ + DefaultPlatformQuotas: map[string]*DefaultPlatformQuotaSetting{ + "anthropic": {DailyLimitUSD: &ten}, + }, + }); err != nil { + t.Fatalf("initial write: %v", err) + } + + // 再写入空 map(整体替换语义:清空全部) + if err := svc.UpdateSettings(ctx, &SystemSettings{ + DefaultPlatformQuotas: map[string]*DefaultPlatformQuotaSetting{}, + }); err != nil { + t.Fatalf("empty map write: %v", err) + } + + got, err := svc.GetDefaultPlatformQuotas(ctx) + if err != nil { + t.Fatal(err) + } + // 4 个 key 仍然存在(补齐契约) + for _, p := range []string{"anthropic", "openai", "gemini", "antigravity"} { + if _, ok := got[p]; !ok { + t.Errorf("4-key contract violated after empty write: missing %q", p) + } + } + // 所有字段 nil(全部已清空) + for _, p := range AllowedQuotaPlatforms { + pq := got[p] + if pq == nil { + continue + } + if pq.DailyLimitUSD != nil || pq.WeeklyLimitUSD != nil || pq.MonthlyLimitUSD != nil { + t.Errorf("platform %q should have all-nil limits after empty-map write, got %+v", p, pq) + } + } +} + +// TestUpdateSettingsWithAuthSourceDefaults_PlatformQuotaRoundTrip 验证 round-4 fix: +// PUT /admin/settings 携带的 auth source × platform × window 限额能完整写入并被 GetAuthSourcePlatformQuotas 读回。 +// Round-4 之前 writeProviderDefaultGrantUpdates 完全没写 PQ key,前端配置静默丢失。 +func TestUpdateSettingsWithAuthSourceDefaults_PlatformQuotaRoundTrip(t *testing.T) { + svc := newSettingServiceForPlatformQuotaTest(nil) + systemSettings := &SystemSettings{} + authDefaults := &AuthSourceDefaultSettings{ + Email: ProviderDefaultGrantSettings{ + PlatformQuotas: map[string]*DefaultPlatformQuotaSetting{ + "anthropic": { + DailyLimitUSD: floatPtrPQ(5.0), + WeeklyLimitUSD: nil, // 无限额 + MonthlyLimitUSD: floatPtrPQ(100.0), + }, + "openai": { + DailyLimitUSD: floatPtrPQ(0), // 显式禁用 + }, + }, + }, + } + if err := svc.UpdateSettingsWithAuthSourceDefaults(context.Background(), systemSettings, authDefaults); err != nil { + t.Fatalf("UpdateSettingsWithAuthSourceDefaults: %v", err) + } + got := svc.GetAuthSourcePlatformQuotas(context.Background(), "email") + anthro := got["anthropic"] + if anthro == nil || anthro.DailyLimitUSD == nil || *anthro.DailyLimitUSD != 5.0 { + t.Errorf("anthropic daily round-trip failed: %+v", anthro) + } + if anthro != nil && anthro.WeeklyLimitUSD != nil { + t.Errorf("anthropic weekly want nil (无限额), got %v", *anthro.WeeklyLimitUSD) + } + if anthro == nil || anthro.MonthlyLimitUSD == nil || *anthro.MonthlyLimitUSD != 100.0 { + t.Errorf("anthropic monthly round-trip failed: %+v", anthro) + } + oai := got["openai"] + if oai == nil || oai.DailyLimitUSD == nil || *oai.DailyLimitUSD != 0 { + t.Errorf("openai daily=0 (禁用) round-trip failed: %+v", oai) + } + // 其他 source 不应有 quota(authDefaults 只填了 Email) + if linux := svc.GetAuthSourcePlatformQuotas(context.Background(), "linuxdo"); len(linux) != 0 { + t.Errorf("linuxdo should be empty, got %+v", linux) + } +} + +// TestUpdateSettingsWithAuthSourceDefaults_NilPlatformQuotaPreservesExisting 验证 #2 防御: +// 请求未携带某 auth source 的 platform quota(nil)时跳过写入、保留既有配置, +// 而非整体替换为空 map 清空(与系统层 nil 守卫一致)。 +func TestUpdateSettingsWithAuthSourceDefaults_NilPlatformQuotaPreservesExisting(t *testing.T) { + svc := newSettingServiceForPlatformQuotaTest(map[string]string{ + SettingKeyAuthSourcePlatformQuotas("email"): `{"anthropic":{"daily":5,"weekly":null,"monthly":null}}`, + }) + // authDefaults 不携带 Email 的 PlatformQuotas(nil)——应保留既有配置 + authDefaults := &AuthSourceDefaultSettings{ + Email: ProviderDefaultGrantSettings{PlatformQuotas: nil}, + } + if err := svc.UpdateSettingsWithAuthSourceDefaults(context.Background(), &SystemSettings{}, authDefaults); err != nil { + t.Fatalf("UpdateSettingsWithAuthSourceDefaults: %v", err) + } + anthro := svc.GetAuthSourcePlatformQuotas(context.Background(), "email")["anthropic"] + if anthro == nil || anthro.DailyLimitUSD == nil || *anthro.DailyLimitUSD != 5.0 { + t.Errorf("nil PlatformQuotas 应保留既有 anthropic daily=5,got %+v", anthro) + } +} + +// TestGetAuthSourcePlatformQuotas_JSON 验证新 JSON key 读写语义: +// 写入 JSON,断言已配置平台在结果中、未配置平台不在结果中(override 语义)。 +func TestGetAuthSourcePlatformQuotas_JSON(t *testing.T) { + svc := newSettingServiceForPlatformQuotaTest(map[string]string{ + SettingKeyAuthSourcePlatformQuotas("email"): `{"openai":{"daily":null,"weekly":null,"monthly":20}}`, + }) + got := svc.GetAuthSourcePlatformQuotas(context.Background(), "email") + + // openai monthly = 20 + oai, ok := got["openai"] + if !ok { + t.Fatal("expected openai to be present") + } + if oai.MonthlyLimitUSD == nil || *oai.MonthlyLimitUSD != 20 { + t.Errorf("openai monthly want 20, got %v", oai.MonthlyLimitUSD) + } + if oai.DailyLimitUSD != nil { + t.Errorf("openai daily want nil, got %v", *oai.DailyLimitUSD) + } + if oai.WeeklyLimitUSD != nil { + t.Errorf("openai weekly want nil, got %v", *oai.WeeklyLimitUSD) + } + + // anthropic 未配置 → 不在结果中(override 语义) + if _, ok := got["anthropic"]; ok { + t.Error("anthropic not configured, should be absent from result") + } +} + +// TestUpdateSettingsWithAuthSourceDefaults_NegativeQuotaRejected 验证改动 C: +// auth-source platform quota 含负数时,UpdateSettingsWithAuthSourceDefaults 返回 BadRequest 错误。 +func TestUpdateSettingsWithAuthSourceDefaults_NegativeQuotaRejected(t *testing.T) { + svc := newSettingServiceForPlatformQuotaTest(nil) + neg := -1.0 + authDefaults := &AuthSourceDefaultSettings{ + Email: ProviderDefaultGrantSettings{ + PlatformQuotas: map[string]*DefaultPlatformQuotaSetting{ + "anthropic": {DailyLimitUSD: &neg}, + }, + }, + } + err := svc.UpdateSettingsWithAuthSourceDefaults(context.Background(), &SystemSettings{}, authDefaults) + require.Error(t, err, "expected error for negative quota") + require.Equal(t, "INVALID_DEFAULT_PLATFORM_QUOTA", infraerrors.Reason(err)) +} diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index c9bea224..3f961ab2 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -219,6 +219,9 @@ type SystemSettings struct { // 账号限额通知 AccountQuotaNotifyEnabled bool AccountQuotaNotifyEmails []NotifyEmailEntry + + // 系统全局默认平台配额(key = platform,nil/缺省 = 不限制) + DefaultPlatformQuotas map[string]*DefaultPlatformQuotaSetting `json:"default_platform_quotas"` } type DefaultSubscriptionSetting struct { diff --git a/backend/internal/service/user_platform_quota_port.go b/backend/internal/service/user_platform_quota_port.go new file mode 100644 index 00000000..cb09542a --- /dev/null +++ b/backend/internal/service/user_platform_quota_port.go @@ -0,0 +1,50 @@ +package service + +import ( + "context" + "errors" + "time" +) + +// ErrUserPlatformQuotaNotFound service 层 sentinel:quota 记录不存在。 +// adapter 将 repository.ErrUserPlatformQuotaNotFound 包装为此错误, +// handler 只需引用 service 包,无需直接依赖 repository 包。 +var ErrUserPlatformQuotaNotFound = errors.New("user platform quota not found") + +// UserPlatformQuotaRecord service 层传输结构体(与 repository 层解耦)。 +type UserPlatformQuotaRecord struct { + UserID int64 + Platform string + DailyLimitUSD *float64 + WeeklyLimitUSD *float64 + MonthlyLimitUSD *float64 + DailyUsageUSD float64 + WeeklyUsageUSD float64 + MonthlyUsageUSD float64 + // 窗口起始时间(可选,用于未来 reset 校验) + DailyWindowStart *time.Time + WeeklyWindowStart *time.Time + MonthlyWindowStart *time.Time +} + +// UserPlatformQuotaRepository 定义 service 层所需的 user × platform quota 数据访问端口。 +// repository 包的 userPlatformQuotaRepository 实现此接口。 +type UserPlatformQuotaRepository interface { + // GetByUserPlatform 查询单条配额记录,未找到时返回 (nil, nil)。 + GetByUserPlatform(ctx context.Context, userID int64, platform string) (*UserPlatformQuotaRecord, error) + // BulkInsertInitial 幂等批量插入初始配额记录(ON CONFLICT DO NOTHING)。 + BulkInsertInitial(ctx context.Context, records []UserPlatformQuotaRecord) error + // IncrementUsageWithReset 原子地累加用量,若窗口已过期则先重置再累加。 + IncrementUsageWithReset(ctx context.Context, userID int64, platform string, cost float64, now time.Time) error + // ListByUser 查询用户的所有平台配额记录。 + ListByUser(ctx context.Context, userID int64) ([]UserPlatformQuotaRecord, error) + // UpsertForUser 全量替换该用户所有平台限额配置(事务内): + // 1. 软删除未在 records 中出现的所有 active 行 + // 2. 对 records 中每条:UPDATE 已存在的(含重新激活软删行);UPDATE 未命中时 INSERT + // 仅改 *_limit_usd + deleted_at + updated_at,保留 *_usage_usd / *_window_start。 + // records 为空时仅执行步骤 1。 + UpsertForUser(ctx context.Context, userID int64, records []UserPlatformQuotaRecord) error + // ResetExpiredWindow 重置指定窗口("daily"|"weekly"|"monthly")的用量与起始时间。 + // 未命中活跃记录时返回(service-side wrapper of repository.ErrUserPlatformQuotaNotFound)。 + ResetExpiredWindow(ctx context.Context, userID int64, platform string, window string, newStart time.Time) error +} diff --git a/backend/internal/service/user_service_test.go b/backend/internal/service/user_service_test.go index 775dd602..19aec5d3 100644 --- a/backend/internal/service/user_service_test.go +++ b/backend/internal/service/user_service_test.go @@ -202,7 +202,7 @@ func (m *mockUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int func (m *mockUserRepo) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil } func (m *mockUserRepo) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil } -func (m *mockUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil } +func (m *mockUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil } func (m *mockUserRepo) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) { out := make([]UserAuthIdentityRecord, len(m.identities)) copy(out, m.identities) @@ -315,6 +315,22 @@ func (m *mockBillingCache) InvalidateAPIKeyRateLimit(context.Context, int64) err return nil } +func (m *mockBillingCache) GetUserPlatformQuotaCache(context.Context, int64, string) (*UserPlatformQuotaCacheEntry, bool, error) { + return nil, false, nil +} + +func (m *mockBillingCache) SetUserPlatformQuotaCache(context.Context, int64, string, *UserPlatformQuotaCacheEntry, time.Duration) error { + return nil +} + +func (m *mockBillingCache) DeleteUserPlatformQuotaCache(context.Context, int64, string) error { + return nil +} + +func (m *mockBillingCache) IncrUserPlatformQuotaUsageCache(context.Context, int64, string, float64, time.Duration) error { + return nil +} + // --- 测试 --- func TestUpdateBalance_Success(t *testing.T) { diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index 43748aa5..b22e10ae 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -417,8 +417,9 @@ func ProvideBillingCacheService( rpmCache UserRPMCache, rateRepo UserGroupRateRepository, cfg *config.Config, + userPlatformQuotaRepo UserPlatformQuotaRepository, ) *BillingCacheService { - return NewBillingCacheService(cache, userRepo, subRepo, apiKeyRepo, rpmCache, rateRepo, cfg) + return NewBillingCacheService(cache, userRepo, subRepo, apiKeyRepo, rpmCache, rateRepo, cfg, userPlatformQuotaRepo) } // ProvideAPIKeyService wires APIKeyService and connects rate-limit cache invalidation. diff --git a/backend/migrations/142_user_platform_quotas.sql b/backend/migrations/142_user_platform_quotas.sql new file mode 100644 index 00000000..34375edd --- /dev/null +++ b/backend/migrations/142_user_platform_quotas.sql @@ -0,0 +1,36 @@ +-- 用户平台维度配额表。每个 (user_id, platform) 对独立记录日/周/月三层 USD 限额与用量。 +-- nil limit = 不限制(沿用上层默认),0 = 显式禁用,>0 = USD 上限。 +-- 软删除:deleted_at IS NULL 的记录为活跃记录,部分唯一索引保证同用户同平台只有一条活跃配额。 + +CREATE TABLE IF NOT EXISTS user_platform_quotas ( + id BIGSERIAL PRIMARY KEY, + user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE, + platform VARCHAR(32) NOT NULL CHECK (platform IN ('anthropic', 'openai', 'gemini', 'antigravity')), + + -- 日 / 周 / 月 USD 上限:NULL = 不限制,0 = 显式禁用,>0 = 上限 + daily_limit_usd DECIMAL(20,10), + weekly_limit_usd DECIMAL(20,10), + monthly_limit_usd DECIMAL(20,10), + + -- 当前窗口已用量 + daily_usage_usd DECIMAL(20,10) NOT NULL DEFAULT 0, + weekly_usage_usd DECIMAL(20,10) NOT NULL DEFAULT 0, + monthly_usage_usd DECIMAL(20,10) NOT NULL DEFAULT 0, + + -- 窗口起点(NULL = 首次尚未初始化) + daily_window_start TIMESTAMPTZ, + weekly_window_start TIMESTAMPTZ, + monthly_window_start TIMESTAMPTZ, + + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + deleted_at TIMESTAMPTZ +); + +-- 软删除友好唯一索引:同用户同平台只允许一条未删除记录 +CREATE UNIQUE INDEX IF NOT EXISTS userplatformquota_user_id_platform_uq + ON user_platform_quotas (user_id, platform) + WHERE deleted_at IS NULL; + +CREATE INDEX IF NOT EXISTS userplatformquota_user_id + ON user_platform_quotas (user_id); diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index 8e9b0e3b..8d8e1fd7 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -1001,6 +1001,9 @@ billing: # Number of requests to allow in half-open state # 半开状态允许通过的请求数 half_open_requests: 3 + # Cache TTL (seconds) for per-user × per-platform quota records + # 用户 × 平台 quota 缓存 TTL(秒),默认 86400=1天,覆盖典型 daily 窗口 + user_platform_quota_cache_ttl_seconds: 86400 # ============================================================================= # Turnstile Configuration diff --git a/frontend/src/api/__tests__/settings.authSourceDefaults.spec.ts b/frontend/src/api/__tests__/settings.authSourceDefaults.spec.ts index 10f6247a..4dc7db2e 100644 --- a/frontend/src/api/__tests__/settings.authSourceDefaults.spec.ts +++ b/frontend/src/api/__tests__/settings.authSourceDefaults.spec.ts @@ -3,9 +3,20 @@ import { describe, expect, it } from "vitest"; import { appendAuthSourceDefaultsToUpdateRequest, buildAuthSourceDefaultsState, + normalizePlatformQuotasMap, + sanitizePlatformQuotasMap, type UpdateSettingsRequest, + type DefaultPlatformQuotasMap, } from "@/api/admin/settings"; +/** 全 null 的 4 平台 map,用于断言归一化默认值 */ +const allNullQuotas: DefaultPlatformQuotasMap = { + anthropic: { daily: null, weekly: null, monthly: null }, + openai: { daily: null, weekly: null, monthly: null }, + gemini: { daily: null, weekly: null, monthly: null }, + antigravity: { daily: null, weekly: null, monthly: null }, +} + describe("admin settings auth source defaults helpers", () => { it("builds auth source defaults state from flat settings fields", () => { const state = buildAuthSourceDefaultsState({ @@ -31,6 +42,7 @@ describe("admin settings auth source defaults helpers", () => { subscriptions: [{ group_id: 1, validity_days: 30 }], grant_on_signup: false, grant_on_first_bind: true, + platform_quotas: allNullQuotas, }); expect(state.linuxdo).toEqual({ balance: 6, @@ -38,6 +50,7 @@ describe("admin settings auth source defaults helpers", () => { subscriptions: [{ group_id: 2, validity_days: 60 }], grant_on_signup: true, grant_on_first_bind: false, + platform_quotas: allNullQuotas, }); expect(state.oidc).toEqual({ balance: 0, @@ -45,6 +58,7 @@ describe("admin settings auth source defaults helpers", () => { subscriptions: [], grant_on_signup: false, grant_on_first_bind: false, + platform_quotas: allNullQuotas, }); expect(state.wechat).toEqual({ balance: 0, @@ -52,6 +66,7 @@ describe("admin settings auth source defaults helpers", () => { subscriptions: [], grant_on_signup: false, grant_on_first_bind: false, + platform_quotas: allNullQuotas, }); }); @@ -64,6 +79,23 @@ describe("admin settings auth source defaults helpers", () => { expect(state.wechat.grant_on_signup).toBe(false); }); + it("reads nested platform_quotas from settings into auth source state", () => { + const state = buildAuthSourceDefaultsState({ + auth_source_default_email_platform_quotas: { + anthropic: { daily: 10, weekly: 50, monthly: 200 }, + openai: { daily: null, weekly: null, monthly: null }, + } as DefaultPlatformQuotasMap, + }); + + // anthropic 填写的值应被保留 + expect(state.email.platform_quotas.anthropic).toEqual({ daily: 10, weekly: 50, monthly: 200 }); + // openai 全 null 应被保留 + expect(state.email.platform_quotas.openai).toEqual({ daily: null, weekly: null, monthly: null }); + // 未出现的平台(gemini/antigravity)归一化为 null + expect(state.email.platform_quotas.gemini).toEqual({ daily: null, weekly: null, monthly: null }); + expect(state.email.platform_quotas.antigravity).toEqual({ daily: null, weekly: null, monthly: null }); + }); + it("appends auth source defaults back onto update payload", () => { const payload: UpdateSettingsRequest = { site_name: "Sub2API", @@ -76,6 +108,7 @@ describe("admin settings auth source defaults helpers", () => { subscriptions: [{ group_id: 3, validity_days: 7 }], grant_on_signup: true, grant_on_first_bind: false, + platform_quotas: {}, }, linuxdo: { balance: 0, @@ -83,6 +116,7 @@ describe("admin settings auth source defaults helpers", () => { subscriptions: [], grant_on_signup: false, grant_on_first_bind: true, + platform_quotas: {}, }, oidc: { balance: 4, @@ -90,6 +124,7 @@ describe("admin settings auth source defaults helpers", () => { subscriptions: [{ group_id: 9, validity_days: 90 }], grant_on_signup: true, grant_on_first_bind: true, + platform_quotas: {}, }, wechat: { balance: 2, @@ -97,6 +132,31 @@ describe("admin settings auth source defaults helpers", () => { subscriptions: [], grant_on_signup: false, grant_on_first_bind: false, + platform_quotas: {}, + }, + github: { + balance: 0, + concurrency: 5, + subscriptions: [], + grant_on_signup: false, + grant_on_first_bind: false, + platform_quotas: {}, + }, + google: { + balance: 0, + concurrency: 5, + subscriptions: [], + grant_on_signup: false, + grant_on_first_bind: false, + platform_quotas: {}, + }, + dingtalk: { + balance: 0, + concurrency: 5, + subscriptions: [], + grant_on_signup: false, + grant_on_first_bind: false, + platform_quotas: {}, }, }); @@ -126,6 +186,111 @@ describe("admin settings auth source defaults helpers", () => { auth_source_default_wechat_subscriptions: [], auth_source_default_wechat_grant_on_signup: false, auth_source_default_wechat_grant_on_first_bind: false, + // 嵌套 platform_quotas 字段 + auth_source_default_email_platform_quotas: allNullQuotas, + auth_source_default_linuxdo_platform_quotas: allNullQuotas, + auth_source_default_oidc_platform_quotas: allNullQuotas, + auth_source_default_wechat_platform_quotas: allNullQuotas, + auth_source_default_github_platform_quotas: allNullQuotas, + auth_source_default_google_platform_quotas: allNullQuotas, + auth_source_default_dingtalk_platform_quotas: allNullQuotas, }); }); + + it("appends sanitized nested platform_quotas with non-null values in update payload", () => { + const payload: UpdateSettingsRequest = {}; + appendAuthSourceDefaultsToUpdateRequest(payload, { + email: { + balance: 0, + concurrency: 5, + subscriptions: [], + grant_on_signup: false, + grant_on_first_bind: false, + platform_quotas: { + anthropic: { daily: 10, weekly: 50, monthly: 200 }, + openai: { daily: 0, weekly: null, monthly: null }, + }, + }, + linuxdo: { balance: 0, concurrency: 5, subscriptions: [], grant_on_signup: false, grant_on_first_bind: false, platform_quotas: {} }, + oidc: { balance: 0, concurrency: 5, subscriptions: [], grant_on_signup: false, grant_on_first_bind: false, platform_quotas: {} }, + wechat: { balance: 0, concurrency: 5, subscriptions: [], grant_on_signup: false, grant_on_first_bind: false, platform_quotas: {} }, + github: { balance: 0, concurrency: 5, subscriptions: [], grant_on_signup: false, grant_on_first_bind: false, platform_quotas: {} }, + google: { balance: 0, concurrency: 5, subscriptions: [], grant_on_signup: false, grant_on_first_bind: false, platform_quotas: {} }, + dingtalk: { balance: 0, concurrency: 5, subscriptions: [], grant_on_signup: false, grant_on_first_bind: false, platform_quotas: {} }, + }); + + const emailQuotas = (payload as Record)["auth_source_default_email_platform_quotas"] as DefaultPlatformQuotasMap; + expect(emailQuotas.anthropic).toEqual({ daily: 10, weekly: 50, monthly: 200 }); + // 0 是合法值(不限额=0 与"不设"不同,保留) + expect(emailQuotas.openai?.daily).toBe(0); + // 缺失平台归一化为全 null + expect(emailQuotas.gemini).toEqual({ daily: null, weekly: null, monthly: null }); + expect(emailQuotas.antigravity).toEqual({ daily: null, weekly: null, monthly: null }); + }); +}); + +describe("normalizePlatformQuotasMap", () => { + it("填充缺失的平台为全 null 三档", () => { + const result = normalizePlatformQuotasMap({ anthropic: { daily: 5, weekly: null, monthly: null } }); + expect(result.anthropic).toEqual({ daily: 5, weekly: null, monthly: null }); + expect(result.openai).toEqual({ daily: null, weekly: null, monthly: null }); + expect(result.gemini).toEqual({ daily: null, weekly: null, monthly: null }); + expect(result.antigravity).toEqual({ daily: null, weekly: null, monthly: null }); + }); + + it("无参数时返回全 4 平台全 null", () => { + const result = normalizePlatformQuotasMap(); + expect(Object.keys(result)).toHaveLength(4); + for (const v of Object.values(result)) { + expect(v).toEqual({ daily: null, weekly: null, monthly: null }); + } + }); + + it("非 number 类型的值归一化为 null", () => { + const result = normalizePlatformQuotasMap({ + anthropic: { daily: "50" as unknown as number, weekly: undefined as unknown as number, monthly: null }, + }); + expect(result.anthropic).toEqual({ daily: null, weekly: null, monthly: null }); + }); +}); + +describe("sanitizePlatformQuotasMap", () => { + it("保留合法的正数和零值", () => { + const result = sanitizePlatformQuotasMap({ + anthropic: { daily: 10.5, weekly: 0, monthly: null }, + }); + expect(result.anthropic?.daily).toBe(10.5); + expect(result.anthropic?.weekly).toBe(0); + expect(result.anthropic?.monthly).toBe(null); + }); + + it("空字符串(v-model.number 空输入)清洗为 null", () => { + const result = sanitizePlatformQuotasMap({ + anthropic: { daily: "" as unknown as number, weekly: null, monthly: null }, + }); + expect(result.anthropic?.daily).toBe(null); + }); + + it("负数清洗为 null", () => { + const result = sanitizePlatformQuotasMap({ + openai: { daily: -1, weekly: null, monthly: null }, + }); + expect(result.openai?.daily).toBe(null); + }); + + it("NaN/Infinity 清洗为 null", () => { + const result = sanitizePlatformQuotasMap({ + gemini: { daily: NaN, weekly: Infinity, monthly: null }, + }); + expect(result.gemini?.daily).toBe(null); + expect(result.gemini?.weekly).toBe(null); + }); + + it("缺失平台填充为全 null", () => { + const result = sanitizePlatformQuotasMap({}); + expect(Object.keys(result)).toHaveLength(4); + for (const v of Object.values(result)) { + expect(v).toEqual({ daily: null, weekly: null, monthly: null }); + } + }); }); diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts index 3bdf3ee4..d2b878cc 100644 --- a/frontend/src/api/admin/settings.ts +++ b/frontend/src/api/admin/settings.ts @@ -16,6 +16,47 @@ export interface DefaultSubscriptionSetting { validity_days: number; } +// ── 平台限额类型 ────────────────────────────────────────────────── +export type PlatformType = "anthropic" | "openai" | "gemini" | "antigravity" +export type QuotaWindowType = "daily" | "weekly" | "monthly" + +/** 单平台三档限额;null = 不限制,undefined = 未填(等价 null) */ +export interface PlatformQuotaLimits { + daily: number | null + weekly: number | null + monthly: number | null +} + +/** 全平台默认限额 map(key = PlatformType) */ +export type DefaultPlatformQuotasMap = Partial> + +const PLATFORMS: PlatformType[] = ["anthropic", "openai", "gemini", "antigravity"] + +/** 归一化为全 4 平台 × 3 窗口(缺失填 null),供模板非空绑定 */ +export function normalizePlatformQuotasMap(input?: DefaultPlatformQuotasMap | null): DefaultPlatformQuotasMap { + const result: DefaultPlatformQuotasMap = {} + for (const p of PLATFORMS) { + const src = input?.[p] + result[p] = { + daily: typeof src?.daily === "number" ? src.daily : null, + weekly: typeof src?.weekly === "number" ? src.weekly : null, + monthly: typeof src?.monthly === "number" ? src.monthly : null, + } + } + return result +} + +/** 提交前清洗:非有限数/负数/空字符串 → null(保留 0 = 显式禁用),返回全 4 平台嵌套 map */ +export function sanitizePlatformQuotasMap(input?: DefaultPlatformQuotasMap | null): DefaultPlatformQuotasMap { + const clean = (v: unknown): number | null => (typeof v === "number" && Number.isFinite(v) && v >= 0 ? v : null) + const result: DefaultPlatformQuotasMap = {} + for (const p of PLATFORMS) { + const src = input?.[p] + result[p] = { daily: clean(src?.daily), weekly: clean(src?.weekly), monthly: clean(src?.monthly) } + } + return result +} + export type AuthSourceType = | "email" | "linuxdo" @@ -31,6 +72,8 @@ export interface AuthSourceDefaultsValue { subscriptions: DefaultSubscriptionSetting[]; grant_on_signup: boolean; grant_on_first_bind: boolean; + // ★ 新增:平台限额覆盖(key = PlatformType) + platform_quotas: DefaultPlatformQuotasMap; } export type AuthSourceDefaultsState = Record< @@ -193,6 +236,7 @@ export function buildAuthSourceDefaultsState( raw[`auth_source_default_${source}_grant_on_signup`] === true, grant_on_first_bind: raw[`auth_source_default_${source}_grant_on_first_bind`] === true, + platform_quotas: normalizePlatformQuotasMap(raw[`auth_source_default_${source}_platform_quotas`] as DefaultPlatformQuotasMap | undefined), }; return acc; }, {} as AuthSourceDefaultsState); @@ -220,6 +264,7 @@ export function appendAuthSourceDefaultsToUpdateRequest( current.grant_on_signup; target[`auth_source_default_${source}_grant_on_first_bind`] = current.grant_on_first_bind; + target[`auth_source_default_${source}_platform_quotas`] = sanitizePlatformQuotasMap(current.platform_quotas) } return payload; @@ -370,6 +415,15 @@ export interface SystemSettings { auth_source_default_google_grant_on_signup?: boolean; auth_source_default_google_grant_on_first_bind?: boolean; force_email_on_third_party_signup?: boolean; + // ── 平台限额(嵌套 JSON,系统层 + 7 auth-source 层)──────────────────────────────── + default_platform_quotas?: DefaultPlatformQuotasMap; + auth_source_default_email_platform_quotas?: DefaultPlatformQuotasMap; + auth_source_default_linuxdo_platform_quotas?: DefaultPlatformQuotasMap; + auth_source_default_oidc_platform_quotas?: DefaultPlatformQuotasMap; + auth_source_default_wechat_platform_quotas?: DefaultPlatformQuotasMap; + auth_source_default_github_platform_quotas?: DefaultPlatformQuotasMap; + auth_source_default_google_platform_quotas?: DefaultPlatformQuotasMap; + auth_source_default_dingtalk_platform_quotas?: DefaultPlatformQuotasMap; // OEM settings site_name: string; site_logo: string; @@ -616,6 +670,15 @@ export interface UpdateSettingsRequest { auth_source_default_google_grant_on_signup?: boolean; auth_source_default_google_grant_on_first_bind?: boolean; force_email_on_third_party_signup?: boolean; + // ── 平台限额(嵌套 JSON,系统层 + 7 auth-source 层)──────────────────────────────── + default_platform_quotas?: DefaultPlatformQuotasMap; + auth_source_default_email_platform_quotas?: DefaultPlatformQuotasMap; + auth_source_default_linuxdo_platform_quotas?: DefaultPlatformQuotasMap; + auth_source_default_oidc_platform_quotas?: DefaultPlatformQuotasMap; + auth_source_default_wechat_platform_quotas?: DefaultPlatformQuotasMap; + auth_source_default_github_platform_quotas?: DefaultPlatformQuotasMap; + auth_source_default_google_platform_quotas?: DefaultPlatformQuotasMap; + auth_source_default_dingtalk_platform_quotas?: DefaultPlatformQuotasMap; site_name?: string; site_logo?: string; site_subtitle?: string; diff --git a/frontend/src/api/admin/users.ts b/frontend/src/api/admin/users.ts index fabc69bc..bfe5e3ba 100644 --- a/frontend/src/api/admin/users.ts +++ b/frontend/src/api/admin/users.ts @@ -297,6 +297,78 @@ export async function bindUserAuthIdentity( return data } +/** + * Platform quota types + */ +export type PlatformQuotaPlatform = 'anthropic' | 'openai' | 'gemini' | 'antigravity' +export type PlatformQuotaWindow = 'daily' | 'weekly' | 'monthly' + +export interface PlatformQuotaItem { + platform: PlatformQuotaPlatform + daily_limit_usd: number | null + weekly_limit_usd: number | null + monthly_limit_usd: number | null + daily_usage_usd: number + weekly_usage_usd: number + monthly_usage_usd: number + daily_window_start?: string | null + weekly_window_start?: string | null + monthly_window_start?: string | null + daily_window_resets_at?: string | null + weekly_window_resets_at?: string | null + monthly_window_resets_at?: string | null +} + +export interface PlatformQuotaUpdateItem { + platform: PlatformQuotaPlatform + daily_limit_usd: number | null + weekly_limit_usd: number | null + monthly_limit_usd: number | null +} + +export interface PlatformQuotasResponse { + platform_quotas: PlatformQuotaItem[] +} + +/** + * Get user's platform quotas + */ +export async function getPlatformQuotas(id: number): Promise { + const { data } = await apiClient.get( + `/admin/users/${id}/platform-quotas` + ) + return data +} + +/** + * Replace user's platform quotas (全量替换) + */ +export async function updatePlatformQuotas( + id: number, + quotas: PlatformQuotaUpdateItem[] +): Promise { + const { data } = await apiClient.put( + `/admin/users/${id}/platform-quotas`, + { quotas } + ) + return data +} + +/** + * Reset a single (platform, window) usage immediately + */ +export async function resetPlatformQuotaWindow( + id: number, + platform: PlatformQuotaPlatform, + window: PlatformQuotaWindow +): Promise { + const { data } = await apiClient.post( + `/admin/users/${id}/platform-quotas/reset`, + { platform, window } + ) + return data +} + export const usersAPI = { list, getById, @@ -310,7 +382,10 @@ export const usersAPI = { getUserUsageStats, getUserBalanceHistory, replaceGroup, - bindUserAuthIdentity + bindUserAuthIdentity, + getPlatformQuotas, + updatePlatformQuotas, + resetPlatformQuotaWindow, } export default usersAPI diff --git a/frontend/src/api/user.ts b/frontend/src/api/user.ts index da7a91eb..88768530 100644 --- a/frontend/src/api/user.ts +++ b/frontend/src/api/user.ts @@ -15,7 +15,8 @@ import type { NotifyEmailEntry, UserAuthProvider, UserAffiliateDetail, - AffiliateTransferResponse + AffiliateTransferResponse, + PlatformQuotasResponse, } from '@/types' /** @@ -185,6 +186,14 @@ export async function transferAffiliateQuota(): Promise { + const { data } = await apiClient.get('/user/platform-quotas') + return data +} + export const userAPI = { getProfile, updateProfile, @@ -199,7 +208,8 @@ export const userAPI = { buildOAuthBindingStartURL, startOAuthBinding, getAffiliateDetail, - transferAffiliateQuota + transferAffiliateQuota, + getMyPlatformQuotas, } export default userAPI diff --git a/frontend/src/components/admin/user/UserPlatformQuotaModal.vue b/frontend/src/components/admin/user/UserPlatformQuotaModal.vue new file mode 100644 index 00000000..76d23b0c --- /dev/null +++ b/frontend/src/components/admin/user/UserPlatformQuotaModal.vue @@ -0,0 +1,283 @@ + + + diff --git a/frontend/src/components/admin/user/__tests__/UserPlatformQuotaModal.spec.ts b/frontend/src/components/admin/user/__tests__/UserPlatformQuotaModal.spec.ts new file mode 100644 index 00000000..06248133 --- /dev/null +++ b/frontend/src/components/admin/user/__tests__/UserPlatformQuotaModal.spec.ts @@ -0,0 +1,245 @@ +import { describe, it, expect, vi, beforeEach } from 'vitest' +import { mount, flushPromises } from '@vue/test-utils' +import { createPinia, setActivePinia } from 'pinia' + +const apiMocks = vi.hoisted(() => ({ + getPlatformQuotas: vi.fn(), + updatePlatformQuotas: vi.fn(), + resetPlatformQuotaWindow: vi.fn(), +})) + +vi.mock('@/api/admin', () => ({ + adminAPI: { + users: { + getPlatformQuotas: apiMocks.getPlatformQuotas, + updatePlatformQuotas: apiMocks.updatePlatformQuotas, + resetPlatformQuotaWindow: apiMocks.resetPlatformQuotaWindow, + }, + }, +})) + +vi.mock('@/stores/app', () => ({ + useAppStore: () => ({ + showError: vi.fn(), + showSuccess: vi.fn(), + }), +})) + +vi.mock('vue-i18n', async () => { + const actual = await vi.importActual('vue-i18n') + return { + ...actual, + useI18n: () => ({ + t: (key: string, params?: Record) => { + if (params) { + return key.replace(/\{(\w+)\}/g, (_, k) => params[k] ?? '') + } + return key + }, + }), + } +}) + +vi.mock('@/components/common/BaseDialog.vue', () => ({ + default: { + name: 'BaseDialog', + props: ['show', 'title', 'width'], + template: '
', + }, +})) + +import UserPlatformQuotaModal from '../UserPlatformQuotaModal.vue' +import type { UserSubscription } from '@/types' + +function makeUser(overrides: { subscriptions?: UserSubscription[] } = {}) { + return { id: 99, email: 'u@example.com', ...overrides } as any +} + +/** 挂载并触发 show:false → true,确保 watch 被激活 */ +async function mountAndOpen(extraProps: Record = {}) { + const w = mount(UserPlatformQuotaModal, { + props: { show: false, user: makeUser(), ...extraProps }, + }) + await w.setProps({ show: true }) + await flushPromises() + return w +} + +beforeEach(() => { + setActivePinia(createPinia()) + vi.clearAllMocks() + apiMocks.getPlatformQuotas.mockResolvedValue({ platform_quotas: [] }) + apiMocks.updatePlatformQuotas.mockResolvedValue({ platform_quotas: [] }) + apiMocks.resetPlatformQuotaWindow.mockResolvedValue({ platform_quotas: [] }) +}) + +describe('UserPlatformQuotaModal', () => { + it('挂载并 show=true 时调用 getPlatformQuotas', async () => { + await mountAndOpen() + expect(apiMocks.getPlatformQuotas).toHaveBeenCalledWith(99) + }) + + it('空数据渲染 4 个 platform 行', async () => { + const w = await mountAndOpen() + const html = w.html() + expect(html).toContain('anthropic') + expect(html).toContain('openai') + expect(html).toContain('gemini') + expect(html).toContain('antigravity') + }) + + it('已有数据正确填充 limit input', async () => { + apiMocks.getPlatformQuotas.mockResolvedValueOnce({ + platform_quotas: [ + { platform: 'anthropic', daily_limit_usd: 10, weekly_limit_usd: null, monthly_limit_usd: null, + daily_usage_usd: 3.2, weekly_usage_usd: 0, monthly_usage_usd: 0 }, + ], + }) + const w = await mountAndOpen() + const inputs = w.findAll('input[type=number]') + // 4 platforms × 3 windows = 12 inputs + expect(inputs.length).toBe(12) + // 第一个 input 是 anthropic.daily = 10 + expect((inputs[0].element as HTMLInputElement).value).toBe('10') + }) + + it('保存提交完整 4 platform payload', async () => { + apiMocks.getPlatformQuotas.mockResolvedValueOnce({ + platform_quotas: [ + { platform: 'openai', daily_limit_usd: null, weekly_limit_usd: 20, monthly_limit_usd: null, + daily_usage_usd: 0, weekly_usage_usd: 0, monthly_usage_usd: 0 }, + ], + }) + const w = await mountAndOpen() + // 找到「保存」按钮(包含中文「保存」字样的按钮) + const buttons = w.findAll('button') + const saveBtn = buttons.find((b) => b.text() === 'admin.users.platformQuota.save') + expect(saveBtn).toBeTruthy() + await saveBtn!.trigger('click') + await flushPromises() + expect(apiMocks.updatePlatformQuotas).toHaveBeenCalledTimes(1) + const [uid, payload] = apiMocks.updatePlatformQuotas.mock.calls[0] + expect(uid).toBe(99) + expect(payload).toHaveLength(4) // 4 platforms always submitted + const openai = payload.find((p: any) => p.platform === 'openai') + expect(openai.weekly_limit_usd).toBe(20) + }) + + it('全部清空把所有 limit 置 null(确认通过)', async () => { + const confirmSpy = vi.spyOn(window, 'confirm').mockReturnValue(true) + apiMocks.getPlatformQuotas.mockResolvedValueOnce({ + platform_quotas: [ + { platform: 'anthropic', daily_limit_usd: 10, weekly_limit_usd: 50, monthly_limit_usd: 100, + daily_usage_usd: 0, weekly_usage_usd: 0, monthly_usage_usd: 0 }, + ], + }) + const w = await mountAndOpen() + const buttons = w.findAll('button') + const clearBtn = buttons.find((b) => b.text() === 'admin.users.platformQuota.clearAll') + expect(clearBtn).toBeTruthy() + await clearBtn!.trigger('click') + await flushPromises() + expect(confirmSpy).toHaveBeenCalledTimes(1) + const inputs = w.findAll('input[type=number]') + for (const inp of inputs) { + expect((inp.element as HTMLInputElement).value).toBe('') + } + confirmSpy.mockRestore() + }) + + it('全部清空 confirm 取消则保持原值', async () => { + const confirmSpy = vi.spyOn(window, 'confirm').mockReturnValue(false) + apiMocks.getPlatformQuotas.mockResolvedValueOnce({ + platform_quotas: [ + { platform: 'anthropic', daily_limit_usd: 10, weekly_limit_usd: 50, monthly_limit_usd: 100, + daily_usage_usd: 0, weekly_usage_usd: 0, monthly_usage_usd: 0 }, + ], + }) + const w = await mountAndOpen() + const clearBtn = w.findAll('button').find((b) => b.text() === 'admin.users.platformQuota.clearAll') + await clearBtn!.trigger('click') + await flushPromises() + expect(confirmSpy).toHaveBeenCalledTimes(1) + // anthropic daily 应保持 10(未被清空) + const inputs = w.findAll('input[type=number]') + const dailyVal = (inputs[0].element as HTMLInputElement).value + expect(dailyVal).toBe('10') + confirmSpy.mockRestore() + }) + + it('重置按钮 confirm 取消则不调用 API', async () => { + const confirmSpy = vi.spyOn(window, 'confirm').mockReturnValue(false) + const w = await mountAndOpen() + const resetBtns = w.findAll('button').filter((b) => b.text() === '↻') + expect(resetBtns.length).toBeGreaterThan(0) + await resetBtns[0].trigger('click') + await flushPromises() + expect(apiMocks.resetPlatformQuotaWindow).not.toHaveBeenCalled() + confirmSpy.mockRestore() + }) + + it('重置按钮 confirm 确认则调用 API', async () => { + const confirmSpy = vi.spyOn(window, 'confirm').mockReturnValue(true) + const w = await mountAndOpen() + const resetBtns = w.findAll('button').filter((b) => b.text() === '↻') + await resetBtns[0].trigger('click') // 第一个是 anthropic.daily + await flushPromises() + expect(apiMocks.resetPlatformQuotaWindow).toHaveBeenCalledWith(99, 'anthropic', 'daily') + confirmSpy.mockRestore() + }) + + describe('subscription warning banner', () => { + it('displays subscription warning when user has active subscription', async () => { + const w = mount(UserPlatformQuotaModal, { + props: { + show: true, + user: makeUser({ + subscriptions: [ + { + id: 1, user_id: 99, group_id: 1, status: 'active', + starts_at: '2026-01-01T00:00:00Z', expires_at: null, + daily_usage_usd: 0, weekly_usage_usd: 0, monthly_usage_usd: 0, + daily_window_start: null, weekly_window_start: null, monthly_window_start: null, + created_at: '2026-01-01T00:00:00Z', updated_at: '2026-01-01T00:00:00Z', + } as UserSubscription, + ], + }), + }, + }) + await flushPromises() + expect(w.html()).toContain('admin.users.platformQuota.subscriptionWarning') + }) + + it('hides subscription warning when user has only expired subscriptions', async () => { + const w = mount(UserPlatformQuotaModal, { + props: { + show: true, + user: makeUser({ + subscriptions: [ + { + id: 2, user_id: 99, group_id: 1, status: 'expired', + starts_at: '2025-01-01T00:00:00Z', expires_at: '2025-12-31T00:00:00Z', + daily_usage_usd: 0, weekly_usage_usd: 0, monthly_usage_usd: 0, + daily_window_start: null, weekly_window_start: null, monthly_window_start: null, + created_at: '2025-01-01T00:00:00Z', updated_at: '2025-12-31T00:00:00Z', + } as UserSubscription, + ], + }), + }, + }) + await flushPromises() + expect(w.html()).not.toContain('admin.users.platformQuota.subscriptionWarning') + }) + + it('hides subscription warning when subscriptions is empty array', async () => { + const w = mount(UserPlatformQuotaModal, { + props: { + show: true, + user: makeUser({ subscriptions: [] }), + }, + }) + await flushPromises() + expect(w.html()).not.toContain('admin.users.platformQuota.subscriptionWarning') + }) + }) +}) diff --git a/frontend/src/components/user/UserPlatformQuotaCell.vue b/frontend/src/components/user/UserPlatformQuotaCell.vue new file mode 100644 index 00000000..376b61c8 --- /dev/null +++ b/frontend/src/components/user/UserPlatformQuotaCell.vue @@ -0,0 +1,61 @@ + + + diff --git a/frontend/src/components/user/__tests__/UserPlatformQuotaCell.spec.ts b/frontend/src/components/user/__tests__/UserPlatformQuotaCell.spec.ts new file mode 100644 index 00000000..533d1a32 --- /dev/null +++ b/frontend/src/components/user/__tests__/UserPlatformQuotaCell.spec.ts @@ -0,0 +1,74 @@ +import { describe, it, expect, vi } from 'vitest' +import { mount } from '@vue/test-utils' + +// t() 回显 key,便于断言文案键 +vi.mock('vue-i18n', async () => { + const actual = await vi.importActual('vue-i18n') + return { + ...actual, + useI18n: () => ({ t: (key: string) => key }), + } +}) + +import UserPlatformQuotaCell from '../UserPlatformQuotaCell.vue' +import type { PlatformQuotaItem } from '@/api/admin/users' + +function item(over: Partial & { platform: PlatformQuotaItem['platform'] }): PlatformQuotaItem { + return { + daily_limit_usd: null, weekly_limit_usd: null, monthly_limit_usd: null, + daily_usage_usd: 0, weekly_usage_usd: 0, monthly_usage_usd: 0, + ...over, + } as PlatformQuotaItem +} + +describe('UserPlatformQuotaCell', () => { + it('quotas 为 undefined 时渲染加载占位 …', () => { + const w = mount(UserPlatformQuotaCell, { props: { quotas: undefined } }) + expect(w.text()).toContain('…') + expect(w.html()).not.toContain('admin.users.platformQuota.cellNotConfigured') + }) + + it('空数组渲染「未配置」', () => { + const w = mount(UserPlatformQuotaCell, { props: { quotas: [] } }) + expect(w.html()).toContain('admin.users.platformQuota.cellNotConfigured') + }) + + it('平台有记录但全部 limit 为 null 时视为未配置', () => { + const w = mount(UserPlatformQuotaCell, { + props: { quotas: [item({ platform: 'openai', daily_usage_usd: 5 })] }, + }) + expect(w.html()).toContain('admin.users.platformQuota.cellNotConfigured') + }) + + it('已配置平台渲染 已用/限额,null 档显示 —,金额去尾零', () => { + const w = mount(UserPlatformQuotaCell, { + props: { + quotas: [ + item({ platform: 'anthropic', daily_limit_usd: 100, daily_usage_usd: 30, + weekly_limit_usd: null, weekly_usage_usd: 0, + monthly_limit_usd: 2000, monthly_usage_usd: 90.5 }), + ], + }, + }) + const html = w.html() + expect(html).toContain('anthropic') + expect(html).toContain('30/100') + expect(html).toContain('0/—') + expect(html).toContain('90.5/2000') + }) + + it('多平台按 anthropic→openai→gemini→antigravity 顺序,仅展示有限额的', () => { + const w = mount(UserPlatformQuotaCell, { + props: { + quotas: [ + item({ platform: 'gemini', monthly_limit_usd: 50 }), + item({ platform: 'anthropic', daily_limit_usd: 10 }), + item({ platform: 'openai', daily_usage_usd: 9 }), + ], + }, + }) + const text = w.text() + expect(text.indexOf('anthropic')).toBeLessThan(text.indexOf('gemini')) + expect(text).not.toContain('openai') + }) +}) diff --git a/frontend/src/components/user/dashboard/UserDashboardStats.vue b/frontend/src/components/user/dashboard/UserDashboardStats.vue index 97d2da3d..e7f3449d 100644 --- a/frontend/src/components/user/dashboard/UserDashboardStats.vue +++ b/frontend/src/components/user/dashboard/UserDashboardStats.vue @@ -177,6 +177,46 @@ + + +
+

+ {{ t('dashboard.platformQuota.title') }} +

+ +
@@ -187,11 +227,23 @@ import { computed } from 'vue' import { useI18n } from 'vue-i18n' import Icon from '@/components/icons/Icon.vue' import type { UserDashboardStats as UserStatsType } from '@/api/usage' +import type { PlatformQuotaItem } from '@/types' + +interface FusedPlatformCard { + platform: string + total_actual_cost: number + today_actual_cost: number + total_requests: number + total_tokens: number + isOther?: boolean + quota?: PlatformQuotaItem +} const props = defineProps<{ stats: UserStatsType balance: number isSimple: boolean + platformQuotas?: PlatformQuotaItem[] | null }>() const { t } = useI18n() @@ -213,16 +265,45 @@ const sortedPlatforms = computed(() => { // (group 与 account 都缺 platform)。这里把差值作为"其他"卡片显式展示, // 避免 Row 1 总值与 Row 3 平台拆分加总对不上、用户困惑。 const OTHER_THRESHOLD = 0.0001 -const platformCards = computed(() => { - const cards: Array<{ - platform: string - total_actual_cost: number - today_actual_cost: number - total_requests: number - total_tokens: number - isOther?: boolean - }> = sortedPlatforms.value.map((p) => ({ ...p })) +const platformCards = computed(() => { + // 建立 by_platform Map + const byPlat = new Map() + for (const item of props.stats?.by_platform ?? []) byPlat.set(item.platform, item) + // 建立 quota Map + const byQuota = new Map() + for (const q of props.platformQuotas ?? []) byQuota.set(q.platform, q) + + // union 平台集合。后端 by_platform / quota 接口均不会返回 platform='__other__', + // 无需显式排除;__other__ 由下方差值补差逻辑单独追加。 + const platforms = new Set([...byPlat.keys(), ...byQuota.keys()]) + + const PLATFORM_ORDER = ['anthropic', 'openai', 'gemini', 'antigravity'] + const cards: FusedPlatformCard[] = [] + + for (const p of platforms) { + const stat = byPlat.get(p) + cards.push({ + platform: p, + total_actual_cost: stat?.total_actual_cost ?? 0, + today_actual_cost: stat?.today_actual_cost ?? 0, + total_requests: stat?.total_requests ?? 0, + total_tokens: stat?.total_tokens ?? 0, + quota: byQuota.get(p), + }) + } + + // 排序:按 PLATFORM_ORDER,未知平台按名称排序 + cards.sort((a, b) => { + const ai = PLATFORM_ORDER.indexOf(a.platform) + const bi = PLATFORM_ORDER.indexOf(b.platform) + if (ai === -1 && bi === -1) return a.platform.localeCompare(b.platform) + if (ai === -1) return 1 + if (bi === -1) return -1 + return ai - bi + }) + + // __other__ 补差逻辑:只对 by_platform 有 usage 数据的总和计算 const total = props.stats?.total_actual_cost ?? 0 const today = props.stats?.today_actual_cost ?? 0 const sumTotal = cards.reduce((s, c) => s + c.total_actual_cost, 0) @@ -237,12 +318,62 @@ const platformCards = computed(() => { today_actual_cost: diffToday, total_requests: 0, total_tokens: 0, - isOther: true + isOther: true, }) } + return cards }) +// Quota helpers + +type QuotaWindow = 'daily' | 'weekly' | 'monthly' +type QuotaField = `${QuotaWindow}_limit_usd` | `${QuotaWindow}_usage_usd` | `${QuotaWindow}_window_resets_at` + +function quotaVal(q: PlatformQuotaItem | undefined, key: QuotaField): PlatformQuotaItem[QuotaField] { + return q?.[key] +} + +function hasAnyLimit(q: PlatformQuotaItem | undefined): boolean { + if (!q) return false + return q.daily_limit_usd != null || q.weekly_limit_usd != null || q.monthly_limit_usd != null +} + +function calcPercent(usage: number, limit: number): number { + if (!limit || limit <= 0) return 0 + return Math.min(100, Math.max(0, Math.round((usage / limit) * 100))) +} + +function quotaBarClass(p: number): string { + if (p >= 95) return 'bg-red-500' + if (p >= 75) return 'bg-amber-500' + return 'bg-green-500' +} + +// 与 formatBalance 一致使用 Intl.NumberFormat 做半偶舍入,避免 toFixed 在不同 JS 引擎 +// 下偶发截断而非四舍五入(与后端展示精度不一致)。 +const usdFormatter = new Intl.NumberFormat('en-US', { + minimumFractionDigits: 2, + maximumFractionDigits: 2, +}) +function formatUsd(n: number): string { + if (!Number.isFinite(n)) return '0.00' + return usdFormatter.format(n) +} + +function formatResetTime(iso: string | null | undefined): string { + if (!iso) return '' + const d = new Date(iso) + if (Number.isNaN(d.getTime())) return iso + return d.toLocaleString(undefined, { + month: 'numeric', + day: 'numeric', + hour: '2-digit', + minute: '2-digit', + hour12: false, + }) +} + const formatBalance = (b: number) => new Intl.NumberFormat('en-US', { minimumFractionDigits: 2, diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 53176a93..bd7ac310 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -634,6 +634,15 @@ export default { platformBreakdownEmpty: 'No platform usage yet', platformCount: '{count} platforms', platformOther: 'Other', + platformQuota: { + title: 'Quota Usage', + daily: 'Daily', + weekly: 'Weekly', + monthly: 'Monthly (30-day rolling)', + resetsAt: 'Resets {time}', + noLimit: 'unlimited', + disabled: 'Disabled', + }, tokenUsageTrend: 'Token Usage Trend', noDataAvailable: 'No data available', model: 'Model', @@ -1793,6 +1802,7 @@ export default { groups: 'Groups', subscriptions: 'Subscriptions', balance: 'Balance', + balancePlatformQuota: 'Balance (Platform Quota)', usage: 'Usage', usageAnthropic: 'Usage (Claude)', usageOpenAI: 'Usage (OpenAI)', @@ -1985,6 +1995,41 @@ export default { failedToReorder: 'Failed to update order', keyExists: 'Attribute key already exists', dragToReorder: 'Drag to reorder' + }, + platformQuota: { + menuItem: 'Platform Quotas', + title: 'Platform Quotas', + subtitle: 'Configure daily / weekly / monthly USD usage limits for each upstream platform for user {email}', + columns: { + platform: 'Platform', + daily: 'Daily (USD)', + weekly: 'Weekly (USD)', + monthly: 'Monthly (USD, 30-day rolling)', + usage: 'Current Usage', + }, + placeholder: 'unlimited', + save: 'Save', + saving: 'Saving...', + cancel: 'Cancel', + clearAll: 'Clear All (remove all limits)', + clearAllConfirm: 'Clear daily / weekly / monthly limits for ALL platforms? All platforms will become "unlimited" with no local undo — you must manually re-enter values before saving.', + reset: { + button: 'Reset window', + confirm: 'Reset the {window} usage for {platform} for this user? This is effective immediately.', + success: 'Reset {platform} {window} usage', + failed: 'Reset failed', + }, + updateSuccess: 'Platform quotas updated', + updateFailed: 'Save failed', + loadFailed: 'Load failed', + hint: 'Empty = no limit for that window.', + windowDaily: 'daily', + windowWeekly: 'weekly', + windowMonthly: 'monthly', + cellNotConfigured: 'Not configured', + cellColumnTooltip: 'Only platforms with a limit are shown', + subscriptionWarning: 'This user has an active subscription. Platform quotas only apply to balance (standard) mode requests; subscription mode requests are not subject to these limits.', + invalidNumber: 'The following fields contain invalid numbers. Please fix them before saving: {fields}', } }, @@ -5467,7 +5512,17 @@ export default { defaultSubscriptionsDuplicate: 'Duplicate subscription group: {groupId}. Each group can only appear once.', subscriptionGroup: 'Subscription Group', - subscriptionValidityDays: 'Validity (days)' + subscriptionValidityDays: 'Validity (days)', + defaultPlatformQuotas: 'Default Platform Quotas (on signup)', + defaultPlatformQuotasHint: 'Automatically assigned to new users on signup; existing users are not affected. Leave blank = unlimited.', + platformQuotaNotice: 'Monthly quota uses a 30-day rolling window, not a calendar month.', + }, + platformQuota: { + platform: 'Platform', + daily: 'Daily (USD)', + weekly: 'Weekly (USD)', + monthly: 'Monthly (USD, 30d rolling)', + placeholder: 'Unlimited', }, claudeCode: { title: 'Claude Code Settings', @@ -6188,7 +6243,9 @@ export default { grantOnFirstBindHint: 'Grant default entitlements when an existing user first binds this source.', defaultSubscriptionsLabel: 'Default subscriptions', defaultSubscriptionsHint: 'Applies only to this auth source. Leave empty to skip source-specific subscriptions.', - noSourceSubscriptions: 'No source-specific default subscriptions configured.' + noSourceSubscriptions: 'No source-specific default subscriptions configured.', + platformQuotasOverride: 'Platform Quota Overrides', + platformQuotasOverrideHint: 'Blank fields inherit the system default. Set to 0 to fully block that window for this auth source.', }, paymentVisibleMethods: { methodLabel: '{title} visible method', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index b293ec67..90f168b4 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -633,6 +633,15 @@ export default { platformBreakdownEmpty: '暂无平台用量', platformCount: '{count} 个平台', platformOther: '其他', + platformQuota: { + title: '配额用量', + daily: '日', + weekly: '周', + monthly: '月(近30天)', + resetsAt: '{time} 重置', + noLimit: '不限制', + disabled: '已禁用', + }, tokenUsageTrend: 'Token 使用趋势', noDataAvailable: '暂无数据', model: '模型', @@ -1814,6 +1823,7 @@ export default { groups: '分组', subscriptions: '订阅分组', balance: '余额', + balancePlatformQuota: '余额(平台配额)', usage: '用量', usageAnthropic: '用量 (Claude)', usageOpenAI: '用量 (OpenAI)', @@ -2038,6 +2048,41 @@ export default { failedToReorder: '更新排序失败', keyExists: '属性键已存在', dragToReorder: '拖拽排序' + }, + platformQuota: { + menuItem: '平台限额', + title: '平台限额', + subtitle: '为用户 {email} 配置各上游平台的日 / 周 / 月用量上限', + columns: { + platform: '平台', + daily: '日 (USD)', + weekly: '周 (USD)', + monthly: '月 (USD, 30天滚动)', + usage: '当前用量', + }, + placeholder: '不限制', + save: '保存', + saving: '保存中...', + cancel: '取消', + clearAll: '全部清空(取消所有限额)', + clearAllConfirm: '确认清空全部平台的日 / 周 / 月限额?所有平台将变为"无限额",本地无法撤销,需要在保存前手动重填。', + reset: { + button: '重置该窗口', + confirm: '确认重置该用户 {platform} 平台的 {window} 用量?此操作立即生效。', + success: '已重置 {platform} {window} 用量', + failed: '重置失败', + }, + updateSuccess: '平台限额已更新', + updateFailed: '保存失败', + loadFailed: '加载失败', + hint: '留空 = 不限制该窗口。', + windowDaily: '日', + windowWeekly: '周', + windowMonthly: '月', + cellNotConfigured: '未配置', + cellColumnTooltip: '仅展示已设限额的平台', + subscriptionWarning: '此用户有活跃订阅,平台限额仅在余额(标准)模式下生效,订阅模式请求不受此限额约束。', + invalidNumber: '以下字段填写不是合法数字,请修正后再保存:{fields}', } }, @@ -5626,7 +5671,17 @@ export default { defaultSubscriptionsEmpty: '未配置默认订阅。新用户不会自动获得订阅套餐。', defaultSubscriptionsDuplicate: '默认订阅存在重复分组:{groupId}。每个分组只能出现一次。', subscriptionGroup: '订阅分组', - subscriptionValidityDays: '有效期(天)' + subscriptionValidityDays: '有效期(天)', + defaultPlatformQuotas: '默认平台限额(注册时分配)', + defaultPlatformQuotasHint: '新用户注册时自动写入平台限额记录;已有用户不受影响。留空 = 该平台该窗口不限制。', + platformQuotaNotice: '月限额为 30 天滚动窗口,非自然月', + }, + platformQuota: { + platform: '平台', + daily: '日限额 (USD)', + weekly: '周限额 (USD)', + monthly: '月限额 (USD, 30天滚动)', + placeholder: '不限', }, claudeCode: { title: 'Claude Code 设置', @@ -6346,7 +6401,9 @@ export default { grantOnFirstBindHint: '已有账号首次绑定该来源时发放默认权益。', defaultSubscriptionsLabel: '默认订阅', defaultSubscriptionsHint: '仅对当前认证来源生效,未配置时不追加来源专属订阅。', - noSourceSubscriptions: '当前来源未配置专属默认订阅。' + noSourceSubscriptions: '当前来源未配置专属默认订阅。', + platformQuotasOverride: '平台限额覆盖', + platformQuotasOverrideHint: '留空的字段继承「系统默认平台限额」;填 0 表示禁止该窗口使用。', }, paymentVisibleMethods: { methodLabel: '{title} 可见方式', diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 63b9b14f..632e5108 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -1854,3 +1854,11 @@ export interface UpdateScheduledTestPlanRequest { // Payment types export type { SubscriptionPlan, PaymentOrder, CheckoutInfoResponse } from './payment' + +export type { + PlatformQuotaItem, + PlatformQuotaUpdateItem, + PlatformQuotaPlatform, + PlatformQuotaWindow, + PlatformQuotasResponse, +} from '@/api/admin/users' diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue index d3fac4d5..68eb4849 100644 --- a/frontend/src/views/admin/SettingsView.vue +++ b/frontend/src/views/admin/SettingsView.vue @@ -3262,6 +3262,71 @@ + + +
+
+ +

+ {{ t("admin.settings.defaults.defaultPlatformQuotasHint") }} +

+

+ {{ t("admin.settings.defaults.platformQuotaNotice") }} +

+
+
+ + + + + + + + + + + + + + + + + +
{{ t("admin.settings.platformQuota.platform") }}{{ t("admin.settings.platformQuota.daily") }}{{ t("admin.settings.platformQuota.weekly") }}{{ t("admin.settings.platformQuota.monthly") }}
+ {{ p }} + + + + + + +
+
+
+ @@ -3535,6 +3600,68 @@ + + +
+
+ +

+ {{ t("admin.settings.authSourceDefaults.platformQuotasOverrideHint") }} +

+
+
+ + + + + + + + + + + + + + + + + +
{{ t("admin.settings.platformQuota.platform") }}{{ t("admin.settings.platformQuota.daily") }}{{ t("admin.settings.platformQuota.weekly") }}{{ t("admin.settings.platformQuota.monthly") }}
+ {{ p }} + + + + + + +
+
+
+ @@ -6530,6 +6657,8 @@ import { adminAPI } from "@/api"; import { appendAuthSourceDefaultsToUpdateRequest, buildAuthSourceDefaultsState, + normalizePlatformQuotasMap, + sanitizePlatformQuotasMap, defaultWeChatConnectScopesForMode, deriveWeChatConnectStoredMode, normalizeDefaultSubscriptionSettings, @@ -6541,6 +6670,7 @@ import type { SystemSettings, UpdateSettingsRequest, DefaultSubscriptionSetting, + DefaultPlatformQuotasMap, OpenAIFastPolicyRule, WeChatConnectMode, WebSearchEmulationConfig, @@ -6835,6 +6965,8 @@ type SettingsForm = Omit< google_oauth_client_secret: string; force_email_on_third_party_signup: boolean; openai_advanced_scheduler_enabled: boolean; + // 系统全局平台限额 map;form 内始终归一化为全 4 平台对象(模板非空绑定依赖此不变量) + default_platform_quotas: DefaultPlatformQuotasMap; }; const form = reactive({ @@ -6851,6 +6983,7 @@ const form = reactive({ login_agreement_updated_at: "2026-03-31", login_agreement_documents: defaultLoginAgreementDocuments(), default_balance: 0, + default_platform_quotas: normalizePlatformQuotasMap() as DefaultPlatformQuotasMap, affiliate_rebate_rate: 20, affiliate_rebate_freeze_hours: 0, affiliate_rebate_duration_days: 0, @@ -7659,6 +7792,7 @@ async function loadSettings() { })) : defaultLoginAgreementDocuments(); Object.assign(authSourceDefaults, buildAuthSourceDefaultsState(settings)); + form.default_platform_quotas = normalizePlatformQuotasMap(settings.default_platform_quotas); form.backend_mode_enabled = settings.backend_mode_enabled; form.default_subscriptions = normalizeDefaultSubscriptionSettings( settings.default_subscriptions, @@ -8212,6 +8346,7 @@ async function saveSettings() { }; } + payload.default_platform_quotas = sanitizePlatformQuotasMap(form.default_platform_quotas); appendAuthSourceDefaultsToUpdateRequest(payload, authSourceDefaults); const updated = await adminAPI.settings.updateSettings(payload); @@ -8222,6 +8357,7 @@ async function saveSettings() { } } Object.assign(authSourceDefaults, buildAuthSourceDefaultsState(updated)); + form.default_platform_quotas = normalizePlatformQuotasMap(updated.default_platform_quotas); registrationEmailSuffixWhitelistTags.value = normalizeRegistrationEmailSuffixDomains( updated.registration_email_suffix_whitelist, diff --git a/frontend/src/views/admin/UsersView.vue b/frontend/src/views/admin/UsersView.vue index 512bae67..93c70c6b 100644 --- a/frontend/src/views/admin/UsersView.vue +++ b/frontend/src/views/admin/UsersView.vue @@ -420,6 +420,17 @@ + + @@ -673,6 +684,15 @@ {{ t('admin.users.withdraw') }} + + +