feat(quota): 用户 × 平台 USD 配额
为用户在 anthropic/openai/gemini/antigravity 四个平台上提供日/周/月 三个窗口的 USD 配额管控。配额语义:未设置=不限制,0=禁用,>0=美元上限。 两层模型: - 配置层:系统默认配额,以及 email/linuxdo/oidc/wechat/github/google/ dingtalk 七个鉴权来源的默认配额,存于 settings,以嵌套 JSON 整体读写 (系统 1 个 key + 每个来源 1 个 key),整体替换语义。 - 运行时层:user_platform_quota 表按用户记录实际配额,与配置层解耦。 后端:新增 ent schema 与 140_user_platform_quotas.sql 迁移、repository 与 service 端口、计费链路集成、管理端与用户端读写接口。 前端:管理端设置页配额编辑、用户配额管理 Modal、用户 Dashboard 展示、 中英文案。 Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
2f70d965bf
commit
6b39b344d8
@ -33,7 +33,7 @@ func main() {
|
||||
}()
|
||||
|
||||
userRepo := repository.NewUserRepository(client, sqlDB)
|
||||
authService := service.NewAuthService(client, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil)
|
||||
authService := service.NewAuthService(client, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
@ -63,7 +63,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
apiKeyRepository := repository.NewAPIKeyRepository(client, db)
|
||||
userRPMCache := repository.NewUserRPMCache(redisClient)
|
||||
userGroupRateRepository := repository.NewUserGroupRateRepository(db)
|
||||
billingCacheService := service.ProvideBillingCacheService(billingCache, userRepository, userSubscriptionRepository, apiKeyRepository, userRPMCache, userGroupRateRepository, configConfig)
|
||||
userPlatformQuotaRepository := repository.NewUserPlatformQuotaRepository(client)
|
||||
serviceUserPlatformQuotaRepository := repository.NewUserPlatformQuotaServiceAdapter(userPlatformQuotaRepository)
|
||||
billingCacheService := service.ProvideBillingCacheService(billingCache, userRepository, userSubscriptionRepository, apiKeyRepository, userRPMCache, userGroupRateRepository, configConfig, serviceUserPlatformQuotaRepository)
|
||||
apiKeyCache := repository.NewAPIKeyCache(redisClient)
|
||||
apiKeyService := service.ProvideAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, userGroupRateRepository, apiKeyCache, configConfig, billingCacheService)
|
||||
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
|
||||
@ -71,7 +73,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig)
|
||||
affiliateRepository := repository.NewAffiliateRepository(client, db)
|
||||
affiliateService := service.NewAffiliateService(affiliateRepository, settingService, apiKeyAuthCacheInvalidator, billingCacheService)
|
||||
authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService, affiliateService)
|
||||
authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService, affiliateService, serviceUserPlatformQuotaRepository)
|
||||
userService := service.NewUserService(userRepository, settingRepository, apiKeyAuthCacheInvalidator, billingCache)
|
||||
redeemCache := repository.NewRedeemCache(redisClient)
|
||||
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator, affiliateService)
|
||||
@ -85,7 +87,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
userAttributeValueRepository := repository.NewUserAttributeValueRepository(client)
|
||||
userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository)
|
||||
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, redeemService, totpService, userAttributeService)
|
||||
userHandler := handler.NewUserHandler(userService, authService, emailService, emailCache, affiliateService)
|
||||
userHandler := handler.NewUserHandler(userService, authService, emailService, emailCache, affiliateService, serviceUserPlatformQuotaRepository)
|
||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||
usageLogRepository := repository.NewUsageLogRepository(client, db)
|
||||
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
|
||||
@ -143,9 +145,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
modelPricingResolver := service.NewModelPricingResolver(channelService, billingService)
|
||||
notificationEmailService := service.NewNotificationEmailService(settingRepository, emailService)
|
||||
balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository, notificationEmailService)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService, settingService)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService, settingService, serviceUserPlatformQuotaRepository)
|
||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, userRPMCache, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository, privacyClientFactory, openAIGatewayService)
|
||||
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
|
||||
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService, serviceUserPlatformQuotaRepository, billingCache)
|
||||
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
|
||||
rpmCache := repository.NewRPMCache(redisClient)
|
||||
groupCapacityService := service.NewGroupCapacityService(accountRepository, groupRepository, concurrencyService, sessionLimitCache, rpmCache)
|
||||
@ -190,7 +192,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
opsRepository := repository.NewOpsRepository(db)
|
||||
identityService := service.NewIdentityService(identityCache)
|
||||
digestSessionStore := service.NewDigestSessionStore()
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService)
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService, serviceUserPlatformQuotaRepository)
|
||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
|
||||
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
|
||||
|
||||
@ -43,7 +43,7 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
|
||||
subscriptionExpirySvc := service.NewSubscriptionExpiryService(nil, time.Second)
|
||||
pricingSvc := service.NewPricingService(cfg, nil)
|
||||
emailQueueSvc := service.NewEmailQueueService(nil, 1)
|
||||
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg)
|
||||
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg, nil)
|
||||
idempotencyCleanupSvc := service.NewIdempotencyCleanupService(nil, cfg)
|
||||
schedulerSnapshotSvc := service.NewSchedulerSnapshotService(nil, nil, nil, nil, cfg)
|
||||
opsSystemLogSinkSvc := service.NewOpsSystemLogSink(nil)
|
||||
|
||||
@ -48,6 +48,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userplatformquota"
|
||||
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
|
||||
|
||||
stdsql "database/sql"
|
||||
@ -124,6 +125,8 @@ type Client struct {
|
||||
UserAttributeDefinition *UserAttributeDefinitionClient
|
||||
// UserAttributeValue is the client for interacting with the UserAttributeValue builders.
|
||||
UserAttributeValue *UserAttributeValueClient
|
||||
// UserPlatformQuota is the client for interacting with the UserPlatformQuota builders.
|
||||
UserPlatformQuota *UserPlatformQuotaClient
|
||||
// UserSubscription is the client for interacting with the UserSubscription builders.
|
||||
UserSubscription *UserSubscriptionClient
|
||||
}
|
||||
@ -170,6 +173,7 @@ func (c *Client) init() {
|
||||
c.UserAllowedGroup = NewUserAllowedGroupClient(c.config)
|
||||
c.UserAttributeDefinition = NewUserAttributeDefinitionClient(c.config)
|
||||
c.UserAttributeValue = NewUserAttributeValueClient(c.config)
|
||||
c.UserPlatformQuota = NewUserPlatformQuotaClient(c.config)
|
||||
c.UserSubscription = NewUserSubscriptionClient(c.config)
|
||||
}
|
||||
|
||||
@ -296,6 +300,7 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) {
|
||||
UserAllowedGroup: NewUserAllowedGroupClient(cfg),
|
||||
UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg),
|
||||
UserAttributeValue: NewUserAttributeValueClient(cfg),
|
||||
UserPlatformQuota: NewUserPlatformQuotaClient(cfg),
|
||||
UserSubscription: NewUserSubscriptionClient(cfg),
|
||||
}, nil
|
||||
}
|
||||
@ -349,6 +354,7 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error)
|
||||
UserAllowedGroup: NewUserAllowedGroupClient(cfg),
|
||||
UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg),
|
||||
UserAttributeValue: NewUserAttributeValueClient(cfg),
|
||||
UserPlatformQuota: NewUserPlatformQuotaClient(cfg),
|
||||
UserSubscription: NewUserSubscriptionClient(cfg),
|
||||
}, nil
|
||||
}
|
||||
@ -388,7 +394,7 @@ func (c *Client) Use(hooks ...Hook) {
|
||||
c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting,
|
||||
c.SubscriptionPlan, c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog,
|
||||
c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
|
||||
c.UserSubscription,
|
||||
c.UserPlatformQuota, c.UserSubscription,
|
||||
} {
|
||||
n.Use(hooks...)
|
||||
}
|
||||
@ -407,7 +413,7 @@ func (c *Client) Intercept(interceptors ...Interceptor) {
|
||||
c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting,
|
||||
c.SubscriptionPlan, c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog,
|
||||
c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
|
||||
c.UserSubscription,
|
||||
c.UserPlatformQuota, c.UserSubscription,
|
||||
} {
|
||||
n.Intercept(interceptors...)
|
||||
}
|
||||
@ -482,6 +488,8 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) {
|
||||
return c.UserAttributeDefinition.mutate(ctx, m)
|
||||
case *UserAttributeValueMutation:
|
||||
return c.UserAttributeValue.mutate(ctx, m)
|
||||
case *UserPlatformQuotaMutation:
|
||||
return c.UserPlatformQuota.mutate(ctx, m)
|
||||
case *UserSubscriptionMutation:
|
||||
return c.UserSubscription.mutate(ctx, m)
|
||||
default:
|
||||
@ -5341,6 +5349,22 @@ func (c *UserClient) QueryPendingAuthSessions(_m *User) *PendingAuthSessionQuery
|
||||
return query
|
||||
}
|
||||
|
||||
// QueryPlatformQuotas queries the platform_quotas edge of a User.
|
||||
func (c *UserClient) QueryPlatformQuotas(_m *User) *UserPlatformQuotaQuery {
|
||||
query := (&UserPlatformQuotaClient{config: c.config}).Query()
|
||||
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
|
||||
id := _m.ID
|
||||
step := sqlgraph.NewStep(
|
||||
sqlgraph.From(user.Table, user.FieldID, id),
|
||||
sqlgraph.To(userplatformquota.Table, userplatformquota.FieldID),
|
||||
sqlgraph.Edge(sqlgraph.O2M, false, user.PlatformQuotasTable, user.PlatformQuotasColumn),
|
||||
)
|
||||
fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
|
||||
return fromV, nil
|
||||
}
|
||||
return query
|
||||
}
|
||||
|
||||
// QueryUserAllowedGroups queries the user_allowed_groups edge of a User.
|
||||
func (c *UserClient) QueryUserAllowedGroups(_m *User) *UserAllowedGroupQuery {
|
||||
query := (&UserAllowedGroupClient{config: c.config}).Query()
|
||||
@ -5816,6 +5840,157 @@ func (c *UserAttributeValueClient) mutate(ctx context.Context, m *UserAttributeV
|
||||
}
|
||||
}
|
||||
|
||||
// UserPlatformQuotaClient is a client for the UserPlatformQuota schema.
|
||||
type UserPlatformQuotaClient struct {
|
||||
config
|
||||
}
|
||||
|
||||
// NewUserPlatformQuotaClient returns a client for the UserPlatformQuota from the given config.
|
||||
func NewUserPlatformQuotaClient(c config) *UserPlatformQuotaClient {
|
||||
return &UserPlatformQuotaClient{config: c}
|
||||
}
|
||||
|
||||
// Use adds a list of mutation hooks to the hooks stack.
|
||||
// A call to `Use(f, g, h)` equals to `userplatformquota.Hooks(f(g(h())))`.
|
||||
func (c *UserPlatformQuotaClient) Use(hooks ...Hook) {
|
||||
c.hooks.UserPlatformQuota = append(c.hooks.UserPlatformQuota, hooks...)
|
||||
}
|
||||
|
||||
// Intercept adds a list of query interceptors to the interceptors stack.
|
||||
// A call to `Intercept(f, g, h)` equals to `userplatformquota.Intercept(f(g(h())))`.
|
||||
func (c *UserPlatformQuotaClient) Intercept(interceptors ...Interceptor) {
|
||||
c.inters.UserPlatformQuota = append(c.inters.UserPlatformQuota, interceptors...)
|
||||
}
|
||||
|
||||
// Create returns a builder for creating a UserPlatformQuota entity.
|
||||
func (c *UserPlatformQuotaClient) Create() *UserPlatformQuotaCreate {
|
||||
mutation := newUserPlatformQuotaMutation(c.config, OpCreate)
|
||||
return &UserPlatformQuotaCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
|
||||
}
|
||||
|
||||
// CreateBulk returns a builder for creating a bulk of UserPlatformQuota entities.
|
||||
func (c *UserPlatformQuotaClient) CreateBulk(builders ...*UserPlatformQuotaCreate) *UserPlatformQuotaCreateBulk {
|
||||
return &UserPlatformQuotaCreateBulk{config: c.config, builders: builders}
|
||||
}
|
||||
|
||||
// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
|
||||
// a builder and applies setFunc on it.
|
||||
func (c *UserPlatformQuotaClient) MapCreateBulk(slice any, setFunc func(*UserPlatformQuotaCreate, int)) *UserPlatformQuotaCreateBulk {
|
||||
rv := reflect.ValueOf(slice)
|
||||
if rv.Kind() != reflect.Slice {
|
||||
return &UserPlatformQuotaCreateBulk{err: fmt.Errorf("calling to UserPlatformQuotaClient.MapCreateBulk with wrong type %T, need slice", slice)}
|
||||
}
|
||||
builders := make([]*UserPlatformQuotaCreate, rv.Len())
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
builders[i] = c.Create()
|
||||
setFunc(builders[i], i)
|
||||
}
|
||||
return &UserPlatformQuotaCreateBulk{config: c.config, builders: builders}
|
||||
}
|
||||
|
||||
// Update returns an update builder for UserPlatformQuota.
|
||||
func (c *UserPlatformQuotaClient) Update() *UserPlatformQuotaUpdate {
|
||||
mutation := newUserPlatformQuotaMutation(c.config, OpUpdate)
|
||||
return &UserPlatformQuotaUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
|
||||
}
|
||||
|
||||
// UpdateOne returns an update builder for the given entity.
|
||||
func (c *UserPlatformQuotaClient) UpdateOne(_m *UserPlatformQuota) *UserPlatformQuotaUpdateOne {
|
||||
mutation := newUserPlatformQuotaMutation(c.config, OpUpdateOne, withUserPlatformQuota(_m))
|
||||
return &UserPlatformQuotaUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
|
||||
}
|
||||
|
||||
// UpdateOneID returns an update builder for the given id.
|
||||
func (c *UserPlatformQuotaClient) UpdateOneID(id int64) *UserPlatformQuotaUpdateOne {
|
||||
mutation := newUserPlatformQuotaMutation(c.config, OpUpdateOne, withUserPlatformQuotaID(id))
|
||||
return &UserPlatformQuotaUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
|
||||
}
|
||||
|
||||
// Delete returns a delete builder for UserPlatformQuota.
|
||||
func (c *UserPlatformQuotaClient) Delete() *UserPlatformQuotaDelete {
|
||||
mutation := newUserPlatformQuotaMutation(c.config, OpDelete)
|
||||
return &UserPlatformQuotaDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
|
||||
}
|
||||
|
||||
// DeleteOne returns a builder for deleting the given entity.
|
||||
func (c *UserPlatformQuotaClient) DeleteOne(_m *UserPlatformQuota) *UserPlatformQuotaDeleteOne {
|
||||
return c.DeleteOneID(_m.ID)
|
||||
}
|
||||
|
||||
// DeleteOneID returns a builder for deleting the given entity by its id.
|
||||
func (c *UserPlatformQuotaClient) DeleteOneID(id int64) *UserPlatformQuotaDeleteOne {
|
||||
builder := c.Delete().Where(userplatformquota.ID(id))
|
||||
builder.mutation.id = &id
|
||||
builder.mutation.op = OpDeleteOne
|
||||
return &UserPlatformQuotaDeleteOne{builder}
|
||||
}
|
||||
|
||||
// Query returns a query builder for UserPlatformQuota.
|
||||
func (c *UserPlatformQuotaClient) Query() *UserPlatformQuotaQuery {
|
||||
return &UserPlatformQuotaQuery{
|
||||
config: c.config,
|
||||
ctx: &QueryContext{Type: TypeUserPlatformQuota},
|
||||
inters: c.Interceptors(),
|
||||
}
|
||||
}
|
||||
|
||||
// Get returns a UserPlatformQuota entity by its id.
|
||||
func (c *UserPlatformQuotaClient) Get(ctx context.Context, id int64) (*UserPlatformQuota, error) {
|
||||
return c.Query().Where(userplatformquota.ID(id)).Only(ctx)
|
||||
}
|
||||
|
||||
// GetX is like Get, but panics if an error occurs.
|
||||
func (c *UserPlatformQuotaClient) GetX(ctx context.Context, id int64) *UserPlatformQuota {
|
||||
obj, err := c.Get(ctx, id)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return obj
|
||||
}
|
||||
|
||||
// QueryUser queries the user edge of a UserPlatformQuota.
|
||||
func (c *UserPlatformQuotaClient) QueryUser(_m *UserPlatformQuota) *UserQuery {
|
||||
query := (&UserClient{config: c.config}).Query()
|
||||
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
|
||||
id := _m.ID
|
||||
step := sqlgraph.NewStep(
|
||||
sqlgraph.From(userplatformquota.Table, userplatformquota.FieldID, id),
|
||||
sqlgraph.To(user.Table, user.FieldID),
|
||||
sqlgraph.Edge(sqlgraph.M2O, true, userplatformquota.UserTable, userplatformquota.UserColumn),
|
||||
)
|
||||
fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
|
||||
return fromV, nil
|
||||
}
|
||||
return query
|
||||
}
|
||||
|
||||
// Hooks returns the client hooks.
|
||||
func (c *UserPlatformQuotaClient) Hooks() []Hook {
|
||||
hooks := c.hooks.UserPlatformQuota
|
||||
return append(hooks[:len(hooks):len(hooks)], userplatformquota.Hooks[:]...)
|
||||
}
|
||||
|
||||
// Interceptors returns the client interceptors.
|
||||
func (c *UserPlatformQuotaClient) Interceptors() []Interceptor {
|
||||
inters := c.inters.UserPlatformQuota
|
||||
return append(inters[:len(inters):len(inters)], userplatformquota.Interceptors[:]...)
|
||||
}
|
||||
|
||||
func (c *UserPlatformQuotaClient) mutate(ctx context.Context, m *UserPlatformQuotaMutation) (Value, error) {
|
||||
switch m.Op() {
|
||||
case OpCreate:
|
||||
return (&UserPlatformQuotaCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
|
||||
case OpUpdate:
|
||||
return (&UserPlatformQuotaUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
|
||||
case OpUpdateOne:
|
||||
return (&UserPlatformQuotaUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
|
||||
case OpDelete, OpDeleteOne:
|
||||
return (&UserPlatformQuotaDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
|
||||
default:
|
||||
return nil, fmt.Errorf("ent: unknown UserPlatformQuota mutation op: %q", m.Op())
|
||||
}
|
||||
}
|
||||
|
||||
// UserSubscriptionClient is a client for the UserSubscription schema.
|
||||
type UserSubscriptionClient struct {
|
||||
config
|
||||
@ -6025,7 +6200,8 @@ type (
|
||||
PaymentOrder, PaymentProviderInstance, PendingAuthSession, PromoCode,
|
||||
PromoCodeUsage, Proxy, RedeemCode, SecuritySecret, Setting, SubscriptionPlan,
|
||||
TLSFingerprintProfile, UsageCleanupTask, UsageLog, User, UserAllowedGroup,
|
||||
UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Hook
|
||||
UserAttributeDefinition, UserAttributeValue, UserPlatformQuota,
|
||||
UserSubscription []ent.Hook
|
||||
}
|
||||
inters struct {
|
||||
APIKey, Account, AccountGroup, Announcement, AnnouncementRead, AuthIdentity,
|
||||
@ -6035,7 +6211,8 @@ type (
|
||||
PaymentOrder, PaymentProviderInstance, PendingAuthSession, PromoCode,
|
||||
PromoCodeUsage, Proxy, RedeemCode, SecuritySecret, Setting, SubscriptionPlan,
|
||||
TLSFingerprintProfile, UsageCleanupTask, UsageLog, User, UserAllowedGroup,
|
||||
UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Interceptor
|
||||
UserAttributeDefinition, UserAttributeValue, UserPlatformQuota,
|
||||
UserSubscription []ent.Interceptor
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@ -45,6 +45,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userplatformquota"
|
||||
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
|
||||
)
|
||||
|
||||
@ -139,6 +140,7 @@ func checkColumn(t, c string) error {
|
||||
userallowedgroup.Table: userallowedgroup.ValidColumn,
|
||||
userattributedefinition.Table: userattributedefinition.ValidColumn,
|
||||
userattributevalue.Table: userattributevalue.ValidColumn,
|
||||
userplatformquota.Table: userplatformquota.ValidColumn,
|
||||
usersubscription.Table: usersubscription.ValidColumn,
|
||||
})
|
||||
})
|
||||
|
||||
@ -405,6 +405,18 @@ func (f UserAttributeValueFunc) Mutate(ctx context.Context, m ent.Mutation) (ent
|
||||
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.UserAttributeValueMutation", m)
|
||||
}
|
||||
|
||||
// The UserPlatformQuotaFunc type is an adapter to allow the use of ordinary
|
||||
// function as UserPlatformQuota mutator.
|
||||
type UserPlatformQuotaFunc func(context.Context, *ent.UserPlatformQuotaMutation) (ent.Value, error)
|
||||
|
||||
// Mutate calls f(ctx, m).
|
||||
func (f UserPlatformQuotaFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
|
||||
if mv, ok := m.(*ent.UserPlatformQuotaMutation); ok {
|
||||
return f(ctx, mv)
|
||||
}
|
||||
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.UserPlatformQuotaMutation", m)
|
||||
}
|
||||
|
||||
// The UserSubscriptionFunc type is an adapter to allow the use of ordinary
|
||||
// function as UserSubscription mutator.
|
||||
type UserSubscriptionFunc func(context.Context, *ent.UserSubscriptionMutation) (ent.Value, error)
|
||||
|
||||
@ -42,6 +42,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userplatformquota"
|
||||
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
|
||||
)
|
||||
|
||||
@ -992,6 +993,33 @@ func (f TraverseUserAttributeValue) Traverse(ctx context.Context, q ent.Query) e
|
||||
return fmt.Errorf("unexpected query type %T. expect *ent.UserAttributeValueQuery", q)
|
||||
}
|
||||
|
||||
// The UserPlatformQuotaFunc type is an adapter to allow the use of ordinary function as a Querier.
|
||||
type UserPlatformQuotaFunc func(context.Context, *ent.UserPlatformQuotaQuery) (ent.Value, error)
|
||||
|
||||
// Query calls f(ctx, q).
|
||||
func (f UserPlatformQuotaFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
|
||||
if q, ok := q.(*ent.UserPlatformQuotaQuery); ok {
|
||||
return f(ctx, q)
|
||||
}
|
||||
return nil, fmt.Errorf("unexpected query type %T. expect *ent.UserPlatformQuotaQuery", q)
|
||||
}
|
||||
|
||||
// The TraverseUserPlatformQuota type is an adapter to allow the use of ordinary function as Traverser.
|
||||
type TraverseUserPlatformQuota func(context.Context, *ent.UserPlatformQuotaQuery) error
|
||||
|
||||
// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
|
||||
func (f TraverseUserPlatformQuota) Intercept(next ent.Querier) ent.Querier {
|
||||
return next
|
||||
}
|
||||
|
||||
// Traverse calls f(ctx, q).
|
||||
func (f TraverseUserPlatformQuota) Traverse(ctx context.Context, q ent.Query) error {
|
||||
if q, ok := q.(*ent.UserPlatformQuotaQuery); ok {
|
||||
return f(ctx, q)
|
||||
}
|
||||
return fmt.Errorf("unexpected query type %T. expect *ent.UserPlatformQuotaQuery", q)
|
||||
}
|
||||
|
||||
// The UserSubscriptionFunc type is an adapter to allow the use of ordinary function as a Querier.
|
||||
type UserSubscriptionFunc func(context.Context, *ent.UserSubscriptionQuery) (ent.Value, error)
|
||||
|
||||
@ -1088,6 +1116,8 @@ func NewQuery(q ent.Query) (Query, error) {
|
||||
return &query[*ent.UserAttributeDefinitionQuery, predicate.UserAttributeDefinition, userattributedefinition.OrderOption]{typ: ent.TypeUserAttributeDefinition, tq: q}, nil
|
||||
case *ent.UserAttributeValueQuery:
|
||||
return &query[*ent.UserAttributeValueQuery, predicate.UserAttributeValue, userattributevalue.OrderOption]{typ: ent.TypeUserAttributeValue, tq: q}, nil
|
||||
case *ent.UserPlatformQuotaQuery:
|
||||
return &query[*ent.UserPlatformQuotaQuery, predicate.UserPlatformQuota, userplatformquota.OrderOption]{typ: ent.TypeUserPlatformQuota, tq: q}, nil
|
||||
case *ent.UserSubscriptionQuery:
|
||||
return &query[*ent.UserSubscriptionQuery, predicate.UserSubscription, usersubscription.OrderOption]{typ: ent.TypeUserSubscription, tq: q}, nil
|
||||
default:
|
||||
|
||||
@ -1612,6 +1612,53 @@ var (
|
||||
},
|
||||
},
|
||||
}
|
||||
// UserPlatformQuotasColumns holds the columns for the "user_platform_quotas" table.
|
||||
UserPlatformQuotasColumns = []*schema.Column{
|
||||
{Name: "id", Type: field.TypeInt64, Increment: true},
|
||||
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||
{Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||
{Name: "deleted_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||
{Name: "platform", Type: field.TypeString, Size: 32},
|
||||
{Name: "daily_limit_usd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,10)"}},
|
||||
{Name: "weekly_limit_usd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,10)"}},
|
||||
{Name: "monthly_limit_usd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,10)"}},
|
||||
{Name: "daily_usage_usd", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,10)"}},
|
||||
{Name: "weekly_usage_usd", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,10)"}},
|
||||
{Name: "monthly_usage_usd", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,10)"}},
|
||||
{Name: "daily_window_start", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||
{Name: "weekly_window_start", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||
{Name: "monthly_window_start", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||
{Name: "user_id", Type: field.TypeInt64},
|
||||
}
|
||||
// UserPlatformQuotasTable holds the schema information for the "user_platform_quotas" table.
|
||||
UserPlatformQuotasTable = &schema.Table{
|
||||
Name: "user_platform_quotas",
|
||||
Columns: UserPlatformQuotasColumns,
|
||||
PrimaryKey: []*schema.Column{UserPlatformQuotasColumns[0]},
|
||||
ForeignKeys: []*schema.ForeignKey{
|
||||
{
|
||||
Symbol: "user_platform_quotas_users_platform_quotas",
|
||||
Columns: []*schema.Column{UserPlatformQuotasColumns[14]},
|
||||
RefColumns: []*schema.Column{UsersColumns[0]},
|
||||
OnDelete: schema.NoAction,
|
||||
},
|
||||
},
|
||||
Indexes: []*schema.Index{
|
||||
{
|
||||
Name: "userplatformquota_user_id_platform",
|
||||
Unique: true,
|
||||
Columns: []*schema.Column{UserPlatformQuotasColumns[14], UserPlatformQuotasColumns[4]},
|
||||
Annotation: &entsql.IndexAnnotation{
|
||||
Where: "deleted_at IS NULL",
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "userplatformquota_user_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UserPlatformQuotasColumns[14]},
|
||||
},
|
||||
},
|
||||
}
|
||||
// UserSubscriptionsColumns holds the columns for the "user_subscriptions" table.
|
||||
UserSubscriptionsColumns = []*schema.Column{
|
||||
{Name: "id", Type: field.TypeInt64, Increment: true},
|
||||
@ -1736,6 +1783,7 @@ var (
|
||||
UserAllowedGroupsTable,
|
||||
UserAttributeDefinitionsTable,
|
||||
UserAttributeValuesTable,
|
||||
UserPlatformQuotasTable,
|
||||
UserSubscriptionsTable,
|
||||
}
|
||||
)
|
||||
@ -1869,6 +1917,10 @@ func init() {
|
||||
UserAttributeValuesTable.Annotation = &entsql.Annotation{
|
||||
Table: "user_attribute_values",
|
||||
}
|
||||
UserPlatformQuotasTable.ForeignKeys[0].RefTable = UsersTable
|
||||
UserPlatformQuotasTable.Annotation = &entsql.Annotation{
|
||||
Table: "user_platform_quotas",
|
||||
}
|
||||
UserSubscriptionsTable.ForeignKeys[0].RefTable = GroupsTable
|
||||
UserSubscriptionsTable.ForeignKeys[1].RefTable = UsersTable
|
||||
UserSubscriptionsTable.ForeignKeys[2].RefTable = UsersTable
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -105,5 +105,8 @@ type UserAttributeDefinition func(*sql.Selector)
|
||||
// UserAttributeValue is the predicate function for userattributevalue builders.
|
||||
type UserAttributeValue func(*sql.Selector)
|
||||
|
||||
// UserPlatformQuota is the predicate function for userplatformquota builders.
|
||||
type UserPlatformQuota func(*sql.Selector)
|
||||
|
||||
// UserSubscription is the predicate function for usersubscription builders.
|
||||
type UserSubscription func(*sql.Selector)
|
||||
|
||||
@ -39,6 +39,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userplatformquota"
|
||||
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
|
||||
"github.com/Wei-Shaw/sub2api/internal/domain"
|
||||
)
|
||||
@ -1997,6 +1998,56 @@ func init() {
|
||||
userattributevalueDescValue := userattributevalueFields[2].Descriptor()
|
||||
// userattributevalue.DefaultValue holds the default value on creation for the value field.
|
||||
userattributevalue.DefaultValue = userattributevalueDescValue.Default.(string)
|
||||
userplatformquotaMixin := schema.UserPlatformQuota{}.Mixin()
|
||||
userplatformquotaMixinHooks1 := userplatformquotaMixin[1].Hooks()
|
||||
userplatformquota.Hooks[0] = userplatformquotaMixinHooks1[0]
|
||||
userplatformquotaMixinInters1 := userplatformquotaMixin[1].Interceptors()
|
||||
userplatformquota.Interceptors[0] = userplatformquotaMixinInters1[0]
|
||||
userplatformquotaMixinFields0 := userplatformquotaMixin[0].Fields()
|
||||
_ = userplatformquotaMixinFields0
|
||||
userplatformquotaFields := schema.UserPlatformQuota{}.Fields()
|
||||
_ = userplatformquotaFields
|
||||
// userplatformquotaDescCreatedAt is the schema descriptor for created_at field.
|
||||
userplatformquotaDescCreatedAt := userplatformquotaMixinFields0[0].Descriptor()
|
||||
// userplatformquota.DefaultCreatedAt holds the default value on creation for the created_at field.
|
||||
userplatformquota.DefaultCreatedAt = userplatformquotaDescCreatedAt.Default.(func() time.Time)
|
||||
// userplatformquotaDescUpdatedAt is the schema descriptor for updated_at field.
|
||||
userplatformquotaDescUpdatedAt := userplatformquotaMixinFields0[1].Descriptor()
|
||||
// userplatformquota.DefaultUpdatedAt holds the default value on creation for the updated_at field.
|
||||
userplatformquota.DefaultUpdatedAt = userplatformquotaDescUpdatedAt.Default.(func() time.Time)
|
||||
// userplatformquota.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
|
||||
userplatformquota.UpdateDefaultUpdatedAt = userplatformquotaDescUpdatedAt.UpdateDefault.(func() time.Time)
|
||||
// userplatformquotaDescPlatform is the schema descriptor for platform field.
|
||||
userplatformquotaDescPlatform := userplatformquotaFields[1].Descriptor()
|
||||
// userplatformquota.PlatformValidator is a validator for the "platform" field. It is called by the builders before save.
|
||||
userplatformquota.PlatformValidator = func() func(string) error {
|
||||
validators := userplatformquotaDescPlatform.Validators
|
||||
fns := [...]func(string) error{
|
||||
validators[0].(func(string) error),
|
||||
validators[1].(func(string) error),
|
||||
validators[2].(func(string) error),
|
||||
}
|
||||
return func(platform string) error {
|
||||
for _, fn := range fns {
|
||||
if err := fn(platform); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}()
|
||||
// userplatformquotaDescDailyUsageUsd is the schema descriptor for daily_usage_usd field.
|
||||
userplatformquotaDescDailyUsageUsd := userplatformquotaFields[5].Descriptor()
|
||||
// userplatformquota.DefaultDailyUsageUsd holds the default value on creation for the daily_usage_usd field.
|
||||
userplatformquota.DefaultDailyUsageUsd = userplatformquotaDescDailyUsageUsd.Default.(float64)
|
||||
// userplatformquotaDescWeeklyUsageUsd is the schema descriptor for weekly_usage_usd field.
|
||||
userplatformquotaDescWeeklyUsageUsd := userplatformquotaFields[6].Descriptor()
|
||||
// userplatformquota.DefaultWeeklyUsageUsd holds the default value on creation for the weekly_usage_usd field.
|
||||
userplatformquota.DefaultWeeklyUsageUsd = userplatformquotaDescWeeklyUsageUsd.Default.(float64)
|
||||
// userplatformquotaDescMonthlyUsageUsd is the schema descriptor for monthly_usage_usd field.
|
||||
userplatformquotaDescMonthlyUsageUsd := userplatformquotaFields[7].Descriptor()
|
||||
// userplatformquota.DefaultMonthlyUsageUsd holds the default value on creation for the monthly_usage_usd field.
|
||||
userplatformquota.DefaultMonthlyUsageUsd = userplatformquotaDescMonthlyUsageUsd.Default.(float64)
|
||||
usersubscriptionMixin := schema.UserSubscription{}.Mixin()
|
||||
usersubscriptionMixinHooks1 := usersubscriptionMixin[1].Hooks()
|
||||
usersubscription.Hooks[0] = usersubscriptionMixinHooks1[0]
|
||||
|
||||
@ -131,6 +131,7 @@ func (User) Edges() []ent.Edge {
|
||||
edge.To("auth_identities", AuthIdentity.Type).
|
||||
Annotations(entsql.OnDelete(entsql.Cascade)),
|
||||
edge.To("pending_auth_sessions", PendingAuthSession.Type),
|
||||
edge.To("platform_quotas", UserPlatformQuota.Type),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
113
backend/ent/schema/user_platform_quota.go
Normal file
113
backend/ent/schema/user_platform_quota.go
Normal file
@ -0,0 +1,113 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"entgo.io/ent"
|
||||
"entgo.io/ent/dialect"
|
||||
"entgo.io/ent/dialect/entsql"
|
||||
"entgo.io/ent/schema"
|
||||
"entgo.io/ent/schema/edge"
|
||||
"entgo.io/ent/schema/field"
|
||||
"entgo.io/ent/schema/index"
|
||||
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
|
||||
)
|
||||
|
||||
// UserPlatformQuota holds the schema definition for per-user per-platform quota.
|
||||
type UserPlatformQuota struct {
|
||||
ent.Schema
|
||||
}
|
||||
|
||||
func (UserPlatformQuota) Annotations() []schema.Annotation {
|
||||
return []schema.Annotation{
|
||||
entsql.Annotation{Table: "user_platform_quotas"},
|
||||
}
|
||||
}
|
||||
|
||||
func (UserPlatformQuota) Mixin() []ent.Mixin {
|
||||
return []ent.Mixin{
|
||||
mixins.TimeMixin{},
|
||||
mixins.SoftDeleteMixin{},
|
||||
}
|
||||
}
|
||||
|
||||
func (UserPlatformQuota) Fields() []ent.Field {
|
||||
return []ent.Field{
|
||||
field.Int64("user_id"),
|
||||
field.String("platform").
|
||||
MaxLen(32).
|
||||
NotEmpty().
|
||||
Validate(func(s string) error {
|
||||
// 注意:平台列表的单一权威源为 service.AllowedQuotaPlatforms;
|
||||
// 此处为 ent 构建期约束,需与 service.AllowedQuotaPlatforms 保持同步。
|
||||
switch s {
|
||||
case "anthropic", "openai", "gemini", "antigravity":
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("platform %q is not allowed", s)
|
||||
}
|
||||
}),
|
||||
|
||||
// 日 / 周 / 月 USD 上限:
|
||||
// nil / not set → 无限额(完全放行)
|
||||
// 0 → 完全禁用(任何请求都会被拒绝,因为 usage >= 0 恒成立)
|
||||
// > 0 → USD 限额上限
|
||||
field.Float("daily_limit_usd").
|
||||
Optional().
|
||||
Nillable().
|
||||
SchemaType(map[string]string{dialect.Postgres: "decimal(20,10)"}),
|
||||
field.Float("weekly_limit_usd").
|
||||
Optional().
|
||||
Nillable().
|
||||
SchemaType(map[string]string{dialect.Postgres: "decimal(20,10)"}),
|
||||
field.Float("monthly_limit_usd").
|
||||
Optional().
|
||||
Nillable().
|
||||
SchemaType(map[string]string{dialect.Postgres: "decimal(20,10)"}),
|
||||
|
||||
// 当前窗口已用量(USD,preflight 时与 limit 比较)
|
||||
field.Float("daily_usage_usd").
|
||||
Default(0).
|
||||
SchemaType(map[string]string{dialect.Postgres: "decimal(20,10)"}),
|
||||
field.Float("weekly_usage_usd").
|
||||
Default(0).
|
||||
SchemaType(map[string]string{dialect.Postgres: "decimal(20,10)"}),
|
||||
field.Float("monthly_usage_usd").
|
||||
Default(0).
|
||||
SchemaType(map[string]string{dialect.Postgres: "decimal(20,10)"}),
|
||||
|
||||
// 窗口起点(NULL = 首次还未初始化,由 InitWindowStarts 用 COALESCE 兜底)
|
||||
field.Time("daily_window_start").
|
||||
Optional().
|
||||
Nillable().
|
||||
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
|
||||
field.Time("weekly_window_start").
|
||||
Optional().
|
||||
Nillable().
|
||||
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
|
||||
field.Time("monthly_window_start").
|
||||
Optional().
|
||||
Nillable().
|
||||
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
|
||||
}
|
||||
}
|
||||
|
||||
func (UserPlatformQuota) Edges() []ent.Edge {
|
||||
return []ent.Edge{
|
||||
edge.From("user", User.Type).
|
||||
Ref("platform_quotas").
|
||||
Field("user_id").
|
||||
Unique().
|
||||
Required(),
|
||||
}
|
||||
}
|
||||
|
||||
func (UserPlatformQuota) Indexes() []ent.Index {
|
||||
return []ent.Index{
|
||||
// 软删除友好:只对未删记录唯一
|
||||
index.Fields("user_id", "platform").
|
||||
Unique().
|
||||
Annotations(entsql.IndexWhere("deleted_at IS NULL")),
|
||||
index.Fields("user_id"),
|
||||
}
|
||||
}
|
||||
@ -80,6 +80,8 @@ type Tx struct {
|
||||
UserAttributeDefinition *UserAttributeDefinitionClient
|
||||
// UserAttributeValue is the client for interacting with the UserAttributeValue builders.
|
||||
UserAttributeValue *UserAttributeValueClient
|
||||
// UserPlatformQuota is the client for interacting with the UserPlatformQuota builders.
|
||||
UserPlatformQuota *UserPlatformQuotaClient
|
||||
// UserSubscription is the client for interacting with the UserSubscription builders.
|
||||
UserSubscription *UserSubscriptionClient
|
||||
|
||||
@ -246,6 +248,7 @@ func (tx *Tx) init() {
|
||||
tx.UserAllowedGroup = NewUserAllowedGroupClient(tx.config)
|
||||
tx.UserAttributeDefinition = NewUserAttributeDefinitionClient(tx.config)
|
||||
tx.UserAttributeValue = NewUserAttributeValueClient(tx.config)
|
||||
tx.UserPlatformQuota = NewUserPlatformQuotaClient(tx.config)
|
||||
tx.UserSubscription = NewUserSubscriptionClient(tx.config)
|
||||
}
|
||||
|
||||
|
||||
@ -95,11 +95,13 @@ type UserEdges struct {
|
||||
AuthIdentities []*AuthIdentity `json:"auth_identities,omitempty"`
|
||||
// PendingAuthSessions holds the value of the pending_auth_sessions edge.
|
||||
PendingAuthSessions []*PendingAuthSession `json:"pending_auth_sessions,omitempty"`
|
||||
// PlatformQuotas holds the value of the platform_quotas edge.
|
||||
PlatformQuotas []*UserPlatformQuota `json:"platform_quotas,omitempty"`
|
||||
// UserAllowedGroups holds the value of the user_allowed_groups edge.
|
||||
UserAllowedGroups []*UserAllowedGroup `json:"user_allowed_groups,omitempty"`
|
||||
// loadedTypes holds the information for reporting if a
|
||||
// type was loaded (or requested) in eager-loading or not.
|
||||
loadedTypes [13]bool
|
||||
loadedTypes [14]bool
|
||||
}
|
||||
|
||||
// APIKeysOrErr returns the APIKeys value or an error if the edge
|
||||
@ -210,10 +212,19 @@ func (e UserEdges) PendingAuthSessionsOrErr() ([]*PendingAuthSession, error) {
|
||||
return nil, &NotLoadedError{edge: "pending_auth_sessions"}
|
||||
}
|
||||
|
||||
// PlatformQuotasOrErr returns the PlatformQuotas value or an error if the edge
|
||||
// was not loaded in eager-loading.
|
||||
func (e UserEdges) PlatformQuotasOrErr() ([]*UserPlatformQuota, error) {
|
||||
if e.loadedTypes[12] {
|
||||
return e.PlatformQuotas, nil
|
||||
}
|
||||
return nil, &NotLoadedError{edge: "platform_quotas"}
|
||||
}
|
||||
|
||||
// UserAllowedGroupsOrErr returns the UserAllowedGroups value or an error if the edge
|
||||
// was not loaded in eager-loading.
|
||||
func (e UserEdges) UserAllowedGroupsOrErr() ([]*UserAllowedGroup, error) {
|
||||
if e.loadedTypes[12] {
|
||||
if e.loadedTypes[13] {
|
||||
return e.UserAllowedGroups, nil
|
||||
}
|
||||
return nil, &NotLoadedError{edge: "user_allowed_groups"}
|
||||
@ -472,6 +483,11 @@ func (_m *User) QueryPendingAuthSessions() *PendingAuthSessionQuery {
|
||||
return NewUserClient(_m.config).QueryPendingAuthSessions(_m)
|
||||
}
|
||||
|
||||
// QueryPlatformQuotas queries the "platform_quotas" edge of the User entity.
|
||||
func (_m *User) QueryPlatformQuotas() *UserPlatformQuotaQuery {
|
||||
return NewUserClient(_m.config).QueryPlatformQuotas(_m)
|
||||
}
|
||||
|
||||
// QueryUserAllowedGroups queries the "user_allowed_groups" edge of the User entity.
|
||||
func (_m *User) QueryUserAllowedGroups() *UserAllowedGroupQuery {
|
||||
return NewUserClient(_m.config).QueryUserAllowedGroups(_m)
|
||||
|
||||
@ -85,6 +85,8 @@ const (
|
||||
EdgeAuthIdentities = "auth_identities"
|
||||
// EdgePendingAuthSessions holds the string denoting the pending_auth_sessions edge name in mutations.
|
||||
EdgePendingAuthSessions = "pending_auth_sessions"
|
||||
// EdgePlatformQuotas holds the string denoting the platform_quotas edge name in mutations.
|
||||
EdgePlatformQuotas = "platform_quotas"
|
||||
// EdgeUserAllowedGroups holds the string denoting the user_allowed_groups edge name in mutations.
|
||||
EdgeUserAllowedGroups = "user_allowed_groups"
|
||||
// Table holds the table name of the user in the database.
|
||||
@ -171,6 +173,13 @@ const (
|
||||
PendingAuthSessionsInverseTable = "pending_auth_sessions"
|
||||
// PendingAuthSessionsColumn is the table column denoting the pending_auth_sessions relation/edge.
|
||||
PendingAuthSessionsColumn = "target_user_id"
|
||||
// PlatformQuotasTable is the table that holds the platform_quotas relation/edge.
|
||||
PlatformQuotasTable = "user_platform_quotas"
|
||||
// PlatformQuotasInverseTable is the table name for the UserPlatformQuota entity.
|
||||
// It exists in this package in order to avoid circular dependency with the "userplatformquota" package.
|
||||
PlatformQuotasInverseTable = "user_platform_quotas"
|
||||
// PlatformQuotasColumn is the table column denoting the platform_quotas relation/edge.
|
||||
PlatformQuotasColumn = "user_id"
|
||||
// UserAllowedGroupsTable is the table that holds the user_allowed_groups relation/edge.
|
||||
UserAllowedGroupsTable = "user_allowed_groups"
|
||||
// UserAllowedGroupsInverseTable is the table name for the UserAllowedGroup entity.
|
||||
@ -569,6 +578,20 @@ func ByPendingAuthSessions(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOpti
|
||||
}
|
||||
}
|
||||
|
||||
// ByPlatformQuotasCount orders the results by platform_quotas count.
|
||||
func ByPlatformQuotasCount(opts ...sql.OrderTermOption) OrderOption {
|
||||
return func(s *sql.Selector) {
|
||||
sqlgraph.OrderByNeighborsCount(s, newPlatformQuotasStep(), opts...)
|
||||
}
|
||||
}
|
||||
|
||||
// ByPlatformQuotas orders the results by platform_quotas terms.
|
||||
func ByPlatformQuotas(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
|
||||
return func(s *sql.Selector) {
|
||||
sqlgraph.OrderByNeighborTerms(s, newPlatformQuotasStep(), append([]sql.OrderTerm{term}, terms...)...)
|
||||
}
|
||||
}
|
||||
|
||||
// ByUserAllowedGroupsCount orders the results by user_allowed_groups count.
|
||||
func ByUserAllowedGroupsCount(opts ...sql.OrderTermOption) OrderOption {
|
||||
return func(s *sql.Selector) {
|
||||
@ -666,6 +689,13 @@ func newPendingAuthSessionsStep() *sqlgraph.Step {
|
||||
sqlgraph.Edge(sqlgraph.O2M, false, PendingAuthSessionsTable, PendingAuthSessionsColumn),
|
||||
)
|
||||
}
|
||||
func newPlatformQuotasStep() *sqlgraph.Step {
|
||||
return sqlgraph.NewStep(
|
||||
sqlgraph.From(Table, FieldID),
|
||||
sqlgraph.To(PlatformQuotasInverseTable, FieldID),
|
||||
sqlgraph.Edge(sqlgraph.O2M, false, PlatformQuotasTable, PlatformQuotasColumn),
|
||||
)
|
||||
}
|
||||
func newUserAllowedGroupsStep() *sqlgraph.Step {
|
||||
return sqlgraph.NewStep(
|
||||
sqlgraph.From(Table, FieldID),
|
||||
|
||||
@ -1616,6 +1616,29 @@ func HasPendingAuthSessionsWith(preds ...predicate.PendingAuthSession) predicate
|
||||
})
|
||||
}
|
||||
|
||||
// HasPlatformQuotas applies the HasEdge predicate on the "platform_quotas" edge.
|
||||
func HasPlatformQuotas() predicate.User {
|
||||
return predicate.User(func(s *sql.Selector) {
|
||||
step := sqlgraph.NewStep(
|
||||
sqlgraph.From(Table, FieldID),
|
||||
sqlgraph.Edge(sqlgraph.O2M, false, PlatformQuotasTable, PlatformQuotasColumn),
|
||||
)
|
||||
sqlgraph.HasNeighbors(s, step)
|
||||
})
|
||||
}
|
||||
|
||||
// HasPlatformQuotasWith applies the HasEdge predicate on the "platform_quotas" edge with a given conditions (other predicates).
|
||||
func HasPlatformQuotasWith(preds ...predicate.UserPlatformQuota) predicate.User {
|
||||
return predicate.User(func(s *sql.Selector) {
|
||||
step := newPlatformQuotasStep()
|
||||
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
|
||||
for _, p := range preds {
|
||||
p(s)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// HasUserAllowedGroups applies the HasEdge predicate on the "user_allowed_groups" edge.
|
||||
func HasUserAllowedGroups() predicate.User {
|
||||
return predicate.User(func(s *sql.Selector) {
|
||||
|
||||
@ -22,6 +22,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/ent/usagelog"
|
||||
"github.com/Wei-Shaw/sub2api/ent/user"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userplatformquota"
|
||||
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
|
||||
)
|
||||
|
||||
@ -519,6 +520,21 @@ func (_c *UserCreate) AddPendingAuthSessions(v ...*PendingAuthSession) *UserCrea
|
||||
return _c.AddPendingAuthSessionIDs(ids...)
|
||||
}
|
||||
|
||||
// AddPlatformQuotaIDs adds the "platform_quotas" edge to the UserPlatformQuota entity by IDs.
|
||||
func (_c *UserCreate) AddPlatformQuotaIDs(ids ...int64) *UserCreate {
|
||||
_c.mutation.AddPlatformQuotaIDs(ids...)
|
||||
return _c
|
||||
}
|
||||
|
||||
// AddPlatformQuotas adds the "platform_quotas" edges to the UserPlatformQuota entity.
|
||||
func (_c *UserCreate) AddPlatformQuotas(v ...*UserPlatformQuota) *UserCreate {
|
||||
ids := make([]int64, len(v))
|
||||
for i := range v {
|
||||
ids[i] = v[i].ID
|
||||
}
|
||||
return _c.AddPlatformQuotaIDs(ids...)
|
||||
}
|
||||
|
||||
// Mutation returns the UserMutation object of the builder.
|
||||
func (_c *UserCreate) Mutation() *UserMutation {
|
||||
return _c.mutation
|
||||
@ -1023,6 +1039,22 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
|
||||
}
|
||||
_spec.Edges = append(_spec.Edges, edge)
|
||||
}
|
||||
if nodes := _c.mutation.PlatformQuotasIDs(); len(nodes) > 0 {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
Inverse: false,
|
||||
Table: user.PlatformQuotasTable,
|
||||
Columns: []string{user.PlatformQuotasColumn},
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64),
|
||||
},
|
||||
}
|
||||
for _, k := range nodes {
|
||||
edge.Target.Nodes = append(edge.Target.Nodes, k)
|
||||
}
|
||||
_spec.Edges = append(_spec.Edges, edge)
|
||||
}
|
||||
return _node, _spec
|
||||
}
|
||||
|
||||
|
||||
@ -26,6 +26,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/ent/user"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userplatformquota"
|
||||
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
|
||||
)
|
||||
|
||||
@ -48,6 +49,7 @@ type UserQuery struct {
|
||||
withPaymentOrders *PaymentOrderQuery
|
||||
withAuthIdentities *AuthIdentityQuery
|
||||
withPendingAuthSessions *PendingAuthSessionQuery
|
||||
withPlatformQuotas *UserPlatformQuotaQuery
|
||||
withUserAllowedGroups *UserAllowedGroupQuery
|
||||
modifiers []func(*sql.Selector)
|
||||
// intermediate query (i.e. traversal path).
|
||||
@ -350,6 +352,28 @@ func (_q *UserQuery) QueryPendingAuthSessions() *PendingAuthSessionQuery {
|
||||
return query
|
||||
}
|
||||
|
||||
// QueryPlatformQuotas chains the current query on the "platform_quotas" edge.
|
||||
func (_q *UserQuery) QueryPlatformQuotas() *UserPlatformQuotaQuery {
|
||||
query := (&UserPlatformQuotaClient{config: _q.config}).Query()
|
||||
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
|
||||
if err := _q.prepareQuery(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
selector := _q.sqlQuery(ctx)
|
||||
if err := selector.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
step := sqlgraph.NewStep(
|
||||
sqlgraph.From(user.Table, user.FieldID, selector),
|
||||
sqlgraph.To(userplatformquota.Table, userplatformquota.FieldID),
|
||||
sqlgraph.Edge(sqlgraph.O2M, false, user.PlatformQuotasTable, user.PlatformQuotasColumn),
|
||||
)
|
||||
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
|
||||
return fromU, nil
|
||||
}
|
||||
return query
|
||||
}
|
||||
|
||||
// QueryUserAllowedGroups chains the current query on the "user_allowed_groups" edge.
|
||||
func (_q *UserQuery) QueryUserAllowedGroups() *UserAllowedGroupQuery {
|
||||
query := (&UserAllowedGroupClient{config: _q.config}).Query()
|
||||
@ -576,6 +600,7 @@ func (_q *UserQuery) Clone() *UserQuery {
|
||||
withPaymentOrders: _q.withPaymentOrders.Clone(),
|
||||
withAuthIdentities: _q.withAuthIdentities.Clone(),
|
||||
withPendingAuthSessions: _q.withPendingAuthSessions.Clone(),
|
||||
withPlatformQuotas: _q.withPlatformQuotas.Clone(),
|
||||
withUserAllowedGroups: _q.withUserAllowedGroups.Clone(),
|
||||
// clone intermediate query.
|
||||
sql: _q.sql.Clone(),
|
||||
@ -715,6 +740,17 @@ func (_q *UserQuery) WithPendingAuthSessions(opts ...func(*PendingAuthSessionQue
|
||||
return _q
|
||||
}
|
||||
|
||||
// WithPlatformQuotas tells the query-builder to eager-load the nodes that are connected to
|
||||
// the "platform_quotas" edge. The optional arguments are used to configure the query builder of the edge.
|
||||
func (_q *UserQuery) WithPlatformQuotas(opts ...func(*UserPlatformQuotaQuery)) *UserQuery {
|
||||
query := (&UserPlatformQuotaClient{config: _q.config}).Query()
|
||||
for _, opt := range opts {
|
||||
opt(query)
|
||||
}
|
||||
_q.withPlatformQuotas = query
|
||||
return _q
|
||||
}
|
||||
|
||||
// WithUserAllowedGroups tells the query-builder to eager-load the nodes that are connected to
|
||||
// the "user_allowed_groups" edge. The optional arguments are used to configure the query builder of the edge.
|
||||
func (_q *UserQuery) WithUserAllowedGroups(opts ...func(*UserAllowedGroupQuery)) *UserQuery {
|
||||
@ -804,7 +840,7 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
|
||||
var (
|
||||
nodes = []*User{}
|
||||
_spec = _q.querySpec()
|
||||
loadedTypes = [13]bool{
|
||||
loadedTypes = [14]bool{
|
||||
_q.withAPIKeys != nil,
|
||||
_q.withRedeemCodes != nil,
|
||||
_q.withSubscriptions != nil,
|
||||
@ -817,6 +853,7 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
|
||||
_q.withPaymentOrders != nil,
|
||||
_q.withAuthIdentities != nil,
|
||||
_q.withPendingAuthSessions != nil,
|
||||
_q.withPlatformQuotas != nil,
|
||||
_q.withUserAllowedGroups != nil,
|
||||
}
|
||||
)
|
||||
@ -929,6 +966,13 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if query := _q.withPlatformQuotas; query != nil {
|
||||
if err := _q.loadPlatformQuotas(ctx, query, nodes,
|
||||
func(n *User) { n.Edges.PlatformQuotas = []*UserPlatformQuota{} },
|
||||
func(n *User, e *UserPlatformQuota) { n.Edges.PlatformQuotas = append(n.Edges.PlatformQuotas, e) }); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if query := _q.withUserAllowedGroups; query != nil {
|
||||
if err := _q.loadUserAllowedGroups(ctx, query, nodes,
|
||||
func(n *User) { n.Edges.UserAllowedGroups = []*UserAllowedGroup{} },
|
||||
@ -1339,6 +1383,36 @@ func (_q *UserQuery) loadPendingAuthSessions(ctx context.Context, query *Pending
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (_q *UserQuery) loadPlatformQuotas(ctx context.Context, query *UserPlatformQuotaQuery, nodes []*User, init func(*User), assign func(*User, *UserPlatformQuota)) error {
|
||||
fks := make([]driver.Value, 0, len(nodes))
|
||||
nodeids := make(map[int64]*User)
|
||||
for i := range nodes {
|
||||
fks = append(fks, nodes[i].ID)
|
||||
nodeids[nodes[i].ID] = nodes[i]
|
||||
if init != nil {
|
||||
init(nodes[i])
|
||||
}
|
||||
}
|
||||
if len(query.ctx.Fields) > 0 {
|
||||
query.ctx.AppendFieldOnce(userplatformquota.FieldUserID)
|
||||
}
|
||||
query.Where(predicate.UserPlatformQuota(func(s *sql.Selector) {
|
||||
s.Where(sql.InValues(s.C(user.PlatformQuotasColumn), fks...))
|
||||
}))
|
||||
neighbors, err := query.All(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, n := range neighbors {
|
||||
fk := n.UserID
|
||||
node, ok := nodeids[fk]
|
||||
if !ok {
|
||||
return fmt.Errorf(`unexpected referenced foreign-key "user_id" returned %v for node %v`, fk, n.ID)
|
||||
}
|
||||
assign(node, n)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (_q *UserQuery) loadUserAllowedGroups(ctx context.Context, query *UserAllowedGroupQuery, nodes []*User, init func(*User), assign func(*User, *UserAllowedGroup)) error {
|
||||
fks := make([]driver.Value, 0, len(nodes))
|
||||
nodeids := make(map[int64]*User)
|
||||
|
||||
@ -23,6 +23,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/ent/usagelog"
|
||||
"github.com/Wei-Shaw/sub2api/ent/user"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userplatformquota"
|
||||
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
|
||||
)
|
||||
|
||||
@ -590,6 +591,21 @@ func (_u *UserUpdate) AddPendingAuthSessions(v ...*PendingAuthSession) *UserUpda
|
||||
return _u.AddPendingAuthSessionIDs(ids...)
|
||||
}
|
||||
|
||||
// AddPlatformQuotaIDs adds the "platform_quotas" edge to the UserPlatformQuota entity by IDs.
|
||||
func (_u *UserUpdate) AddPlatformQuotaIDs(ids ...int64) *UserUpdate {
|
||||
_u.mutation.AddPlatformQuotaIDs(ids...)
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddPlatformQuotas adds the "platform_quotas" edges to the UserPlatformQuota entity.
|
||||
func (_u *UserUpdate) AddPlatformQuotas(v ...*UserPlatformQuota) *UserUpdate {
|
||||
ids := make([]int64, len(v))
|
||||
for i := range v {
|
||||
ids[i] = v[i].ID
|
||||
}
|
||||
return _u.AddPlatformQuotaIDs(ids...)
|
||||
}
|
||||
|
||||
// Mutation returns the UserMutation object of the builder.
|
||||
func (_u *UserUpdate) Mutation() *UserMutation {
|
||||
return _u.mutation
|
||||
@ -847,6 +863,27 @@ func (_u *UserUpdate) RemovePendingAuthSessions(v ...*PendingAuthSession) *UserU
|
||||
return _u.RemovePendingAuthSessionIDs(ids...)
|
||||
}
|
||||
|
||||
// ClearPlatformQuotas clears all "platform_quotas" edges to the UserPlatformQuota entity.
|
||||
func (_u *UserUpdate) ClearPlatformQuotas() *UserUpdate {
|
||||
_u.mutation.ClearPlatformQuotas()
|
||||
return _u
|
||||
}
|
||||
|
||||
// RemovePlatformQuotaIDs removes the "platform_quotas" edge to UserPlatformQuota entities by IDs.
|
||||
func (_u *UserUpdate) RemovePlatformQuotaIDs(ids ...int64) *UserUpdate {
|
||||
_u.mutation.RemovePlatformQuotaIDs(ids...)
|
||||
return _u
|
||||
}
|
||||
|
||||
// RemovePlatformQuotas removes "platform_quotas" edges to UserPlatformQuota entities.
|
||||
func (_u *UserUpdate) RemovePlatformQuotas(v ...*UserPlatformQuota) *UserUpdate {
|
||||
ids := make([]int64, len(v))
|
||||
for i := range v {
|
||||
ids[i] = v[i].ID
|
||||
}
|
||||
return _u.RemovePlatformQuotaIDs(ids...)
|
||||
}
|
||||
|
||||
// Save executes the query and returns the number of nodes affected by the update operation.
|
||||
func (_u *UserUpdate) Save(ctx context.Context) (int, error) {
|
||||
if err := _u.defaults(); err != nil {
|
||||
@ -1587,6 +1624,51 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
||||
}
|
||||
_spec.Edges.Add = append(_spec.Edges.Add, edge)
|
||||
}
|
||||
if _u.mutation.PlatformQuotasCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
Inverse: false,
|
||||
Table: user.PlatformQuotasTable,
|
||||
Columns: []string{user.PlatformQuotasColumn},
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64),
|
||||
},
|
||||
}
|
||||
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
|
||||
}
|
||||
if nodes := _u.mutation.RemovedPlatformQuotasIDs(); len(nodes) > 0 && !_u.mutation.PlatformQuotasCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
Inverse: false,
|
||||
Table: user.PlatformQuotasTable,
|
||||
Columns: []string{user.PlatformQuotasColumn},
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64),
|
||||
},
|
||||
}
|
||||
for _, k := range nodes {
|
||||
edge.Target.Nodes = append(edge.Target.Nodes, k)
|
||||
}
|
||||
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
|
||||
}
|
||||
if nodes := _u.mutation.PlatformQuotasIDs(); len(nodes) > 0 {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
Inverse: false,
|
||||
Table: user.PlatformQuotasTable,
|
||||
Columns: []string{user.PlatformQuotasColumn},
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64),
|
||||
},
|
||||
}
|
||||
for _, k := range nodes {
|
||||
edge.Target.Nodes = append(edge.Target.Nodes, k)
|
||||
}
|
||||
_spec.Edges.Add = append(_spec.Edges.Add, edge)
|
||||
}
|
||||
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
|
||||
if _, ok := err.(*sqlgraph.NotFoundError); ok {
|
||||
err = &NotFoundError{user.Label}
|
||||
@ -2158,6 +2240,21 @@ func (_u *UserUpdateOne) AddPendingAuthSessions(v ...*PendingAuthSession) *UserU
|
||||
return _u.AddPendingAuthSessionIDs(ids...)
|
||||
}
|
||||
|
||||
// AddPlatformQuotaIDs adds the "platform_quotas" edge to the UserPlatformQuota entity by IDs.
|
||||
func (_u *UserUpdateOne) AddPlatformQuotaIDs(ids ...int64) *UserUpdateOne {
|
||||
_u.mutation.AddPlatformQuotaIDs(ids...)
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddPlatformQuotas adds the "platform_quotas" edges to the UserPlatformQuota entity.
|
||||
func (_u *UserUpdateOne) AddPlatformQuotas(v ...*UserPlatformQuota) *UserUpdateOne {
|
||||
ids := make([]int64, len(v))
|
||||
for i := range v {
|
||||
ids[i] = v[i].ID
|
||||
}
|
||||
return _u.AddPlatformQuotaIDs(ids...)
|
||||
}
|
||||
|
||||
// Mutation returns the UserMutation object of the builder.
|
||||
func (_u *UserUpdateOne) Mutation() *UserMutation {
|
||||
return _u.mutation
|
||||
@ -2415,6 +2512,27 @@ func (_u *UserUpdateOne) RemovePendingAuthSessions(v ...*PendingAuthSession) *Us
|
||||
return _u.RemovePendingAuthSessionIDs(ids...)
|
||||
}
|
||||
|
||||
// ClearPlatformQuotas clears all "platform_quotas" edges to the UserPlatformQuota entity.
|
||||
func (_u *UserUpdateOne) ClearPlatformQuotas() *UserUpdateOne {
|
||||
_u.mutation.ClearPlatformQuotas()
|
||||
return _u
|
||||
}
|
||||
|
||||
// RemovePlatformQuotaIDs removes the "platform_quotas" edge to UserPlatformQuota entities by IDs.
|
||||
func (_u *UserUpdateOne) RemovePlatformQuotaIDs(ids ...int64) *UserUpdateOne {
|
||||
_u.mutation.RemovePlatformQuotaIDs(ids...)
|
||||
return _u
|
||||
}
|
||||
|
||||
// RemovePlatformQuotas removes "platform_quotas" edges to UserPlatformQuota entities.
|
||||
func (_u *UserUpdateOne) RemovePlatformQuotas(v ...*UserPlatformQuota) *UserUpdateOne {
|
||||
ids := make([]int64, len(v))
|
||||
for i := range v {
|
||||
ids[i] = v[i].ID
|
||||
}
|
||||
return _u.RemovePlatformQuotaIDs(ids...)
|
||||
}
|
||||
|
||||
// Where appends a list predicates to the UserUpdate builder.
|
||||
func (_u *UserUpdateOne) Where(ps ...predicate.User) *UserUpdateOne {
|
||||
_u.mutation.Where(ps...)
|
||||
@ -3185,6 +3303,51 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
|
||||
}
|
||||
_spec.Edges.Add = append(_spec.Edges.Add, edge)
|
||||
}
|
||||
if _u.mutation.PlatformQuotasCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
Inverse: false,
|
||||
Table: user.PlatformQuotasTable,
|
||||
Columns: []string{user.PlatformQuotasColumn},
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64),
|
||||
},
|
||||
}
|
||||
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
|
||||
}
|
||||
if nodes := _u.mutation.RemovedPlatformQuotasIDs(); len(nodes) > 0 && !_u.mutation.PlatformQuotasCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
Inverse: false,
|
||||
Table: user.PlatformQuotasTable,
|
||||
Columns: []string{user.PlatformQuotasColumn},
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64),
|
||||
},
|
||||
}
|
||||
for _, k := range nodes {
|
||||
edge.Target.Nodes = append(edge.Target.Nodes, k)
|
||||
}
|
||||
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
|
||||
}
|
||||
if nodes := _u.mutation.PlatformQuotasIDs(); len(nodes) > 0 {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
Inverse: false,
|
||||
Table: user.PlatformQuotasTable,
|
||||
Columns: []string{user.PlatformQuotasColumn},
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64),
|
||||
},
|
||||
}
|
||||
for _, k := range nodes {
|
||||
edge.Target.Nodes = append(edge.Target.Nodes, k)
|
||||
}
|
||||
_spec.Edges.Add = append(_spec.Edges.Add, edge)
|
||||
}
|
||||
_node = &User{config: _u.config}
|
||||
_spec.Assign = _node.assignValues
|
||||
_spec.ScanValues = _node.scanValues
|
||||
|
||||
301
backend/ent/userplatformquota.go
Normal file
301
backend/ent/userplatformquota.go
Normal file
@ -0,0 +1,301 @@
|
||||
// Code generated by ent, DO NOT EDIT.
|
||||
|
||||
package ent
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"entgo.io/ent"
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"github.com/Wei-Shaw/sub2api/ent/user"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userplatformquota"
|
||||
)
|
||||
|
||||
// UserPlatformQuota is the model entity for the UserPlatformQuota schema.
|
||||
type UserPlatformQuota struct {
|
||||
config `json:"-"`
|
||||
// ID of the ent.
|
||||
ID int64 `json:"id,omitempty"`
|
||||
// CreatedAt holds the value of the "created_at" field.
|
||||
CreatedAt time.Time `json:"created_at,omitempty"`
|
||||
// UpdatedAt holds the value of the "updated_at" field.
|
||||
UpdatedAt time.Time `json:"updated_at,omitempty"`
|
||||
// DeletedAt holds the value of the "deleted_at" field.
|
||||
DeletedAt *time.Time `json:"deleted_at,omitempty"`
|
||||
// UserID holds the value of the "user_id" field.
|
||||
UserID int64 `json:"user_id,omitempty"`
|
||||
// Platform holds the value of the "platform" field.
|
||||
Platform string `json:"platform,omitempty"`
|
||||
// DailyLimitUsd holds the value of the "daily_limit_usd" field.
|
||||
DailyLimitUsd *float64 `json:"daily_limit_usd,omitempty"`
|
||||
// WeeklyLimitUsd holds the value of the "weekly_limit_usd" field.
|
||||
WeeklyLimitUsd *float64 `json:"weekly_limit_usd,omitempty"`
|
||||
// MonthlyLimitUsd holds the value of the "monthly_limit_usd" field.
|
||||
MonthlyLimitUsd *float64 `json:"monthly_limit_usd,omitempty"`
|
||||
// DailyUsageUsd holds the value of the "daily_usage_usd" field.
|
||||
DailyUsageUsd float64 `json:"daily_usage_usd,omitempty"`
|
||||
// WeeklyUsageUsd holds the value of the "weekly_usage_usd" field.
|
||||
WeeklyUsageUsd float64 `json:"weekly_usage_usd,omitempty"`
|
||||
// MonthlyUsageUsd holds the value of the "monthly_usage_usd" field.
|
||||
MonthlyUsageUsd float64 `json:"monthly_usage_usd,omitempty"`
|
||||
// DailyWindowStart holds the value of the "daily_window_start" field.
|
||||
DailyWindowStart *time.Time `json:"daily_window_start,omitempty"`
|
||||
// WeeklyWindowStart holds the value of the "weekly_window_start" field.
|
||||
WeeklyWindowStart *time.Time `json:"weekly_window_start,omitempty"`
|
||||
// MonthlyWindowStart holds the value of the "monthly_window_start" field.
|
||||
MonthlyWindowStart *time.Time `json:"monthly_window_start,omitempty"`
|
||||
// Edges holds the relations/edges for other nodes in the graph.
|
||||
// The values are being populated by the UserPlatformQuotaQuery when eager-loading is set.
|
||||
Edges UserPlatformQuotaEdges `json:"edges"`
|
||||
selectValues sql.SelectValues
|
||||
}
|
||||
|
||||
// UserPlatformQuotaEdges holds the relations/edges for other nodes in the graph.
|
||||
type UserPlatformQuotaEdges struct {
|
||||
// User holds the value of the user edge.
|
||||
User *User `json:"user,omitempty"`
|
||||
// loadedTypes holds the information for reporting if a
|
||||
// type was loaded (or requested) in eager-loading or not.
|
||||
loadedTypes [1]bool
|
||||
}
|
||||
|
||||
// UserOrErr returns the User value or an error if the edge
|
||||
// was not loaded in eager-loading, or loaded but was not found.
|
||||
func (e UserPlatformQuotaEdges) UserOrErr() (*User, error) {
|
||||
if e.User != nil {
|
||||
return e.User, nil
|
||||
} else if e.loadedTypes[0] {
|
||||
return nil, &NotFoundError{label: user.Label}
|
||||
}
|
||||
return nil, &NotLoadedError{edge: "user"}
|
||||
}
|
||||
|
||||
// scanValues returns the types for scanning values from sql.Rows.
|
||||
func (*UserPlatformQuota) scanValues(columns []string) ([]any, error) {
|
||||
values := make([]any, len(columns))
|
||||
for i := range columns {
|
||||
switch columns[i] {
|
||||
case userplatformquota.FieldDailyLimitUsd, userplatformquota.FieldWeeklyLimitUsd, userplatformquota.FieldMonthlyLimitUsd, userplatformquota.FieldDailyUsageUsd, userplatformquota.FieldWeeklyUsageUsd, userplatformquota.FieldMonthlyUsageUsd:
|
||||
values[i] = new(sql.NullFloat64)
|
||||
case userplatformquota.FieldID, userplatformquota.FieldUserID:
|
||||
values[i] = new(sql.NullInt64)
|
||||
case userplatformquota.FieldPlatform:
|
||||
values[i] = new(sql.NullString)
|
||||
case userplatformquota.FieldCreatedAt, userplatformquota.FieldUpdatedAt, userplatformquota.FieldDeletedAt, userplatformquota.FieldDailyWindowStart, userplatformquota.FieldWeeklyWindowStart, userplatformquota.FieldMonthlyWindowStart:
|
||||
values[i] = new(sql.NullTime)
|
||||
default:
|
||||
values[i] = new(sql.UnknownType)
|
||||
}
|
||||
}
|
||||
return values, nil
|
||||
}
|
||||
|
||||
// assignValues assigns the values that were returned from sql.Rows (after scanning)
|
||||
// to the UserPlatformQuota fields.
|
||||
func (_m *UserPlatformQuota) assignValues(columns []string, values []any) error {
|
||||
if m, n := len(values), len(columns); m < n {
|
||||
return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
|
||||
}
|
||||
for i := range columns {
|
||||
switch columns[i] {
|
||||
case userplatformquota.FieldID:
|
||||
value, ok := values[i].(*sql.NullInt64)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field id", value)
|
||||
}
|
||||
_m.ID = int64(value.Int64)
|
||||
case userplatformquota.FieldCreatedAt:
|
||||
if value, ok := values[i].(*sql.NullTime); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field created_at", values[i])
|
||||
} else if value.Valid {
|
||||
_m.CreatedAt = value.Time
|
||||
}
|
||||
case userplatformquota.FieldUpdatedAt:
|
||||
if value, ok := values[i].(*sql.NullTime); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field updated_at", values[i])
|
||||
} else if value.Valid {
|
||||
_m.UpdatedAt = value.Time
|
||||
}
|
||||
case userplatformquota.FieldDeletedAt:
|
||||
if value, ok := values[i].(*sql.NullTime); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field deleted_at", values[i])
|
||||
} else if value.Valid {
|
||||
_m.DeletedAt = new(time.Time)
|
||||
*_m.DeletedAt = value.Time
|
||||
}
|
||||
case userplatformquota.FieldUserID:
|
||||
if value, ok := values[i].(*sql.NullInt64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field user_id", values[i])
|
||||
} else if value.Valid {
|
||||
_m.UserID = value.Int64
|
||||
}
|
||||
case userplatformquota.FieldPlatform:
|
||||
if value, ok := values[i].(*sql.NullString); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field platform", values[i])
|
||||
} else if value.Valid {
|
||||
_m.Platform = value.String
|
||||
}
|
||||
case userplatformquota.FieldDailyLimitUsd:
|
||||
if value, ok := values[i].(*sql.NullFloat64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field daily_limit_usd", values[i])
|
||||
} else if value.Valid {
|
||||
_m.DailyLimitUsd = new(float64)
|
||||
*_m.DailyLimitUsd = value.Float64
|
||||
}
|
||||
case userplatformquota.FieldWeeklyLimitUsd:
|
||||
if value, ok := values[i].(*sql.NullFloat64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field weekly_limit_usd", values[i])
|
||||
} else if value.Valid {
|
||||
_m.WeeklyLimitUsd = new(float64)
|
||||
*_m.WeeklyLimitUsd = value.Float64
|
||||
}
|
||||
case userplatformquota.FieldMonthlyLimitUsd:
|
||||
if value, ok := values[i].(*sql.NullFloat64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field monthly_limit_usd", values[i])
|
||||
} else if value.Valid {
|
||||
_m.MonthlyLimitUsd = new(float64)
|
||||
*_m.MonthlyLimitUsd = value.Float64
|
||||
}
|
||||
case userplatformquota.FieldDailyUsageUsd:
|
||||
if value, ok := values[i].(*sql.NullFloat64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field daily_usage_usd", values[i])
|
||||
} else if value.Valid {
|
||||
_m.DailyUsageUsd = value.Float64
|
||||
}
|
||||
case userplatformquota.FieldWeeklyUsageUsd:
|
||||
if value, ok := values[i].(*sql.NullFloat64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field weekly_usage_usd", values[i])
|
||||
} else if value.Valid {
|
||||
_m.WeeklyUsageUsd = value.Float64
|
||||
}
|
||||
case userplatformquota.FieldMonthlyUsageUsd:
|
||||
if value, ok := values[i].(*sql.NullFloat64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field monthly_usage_usd", values[i])
|
||||
} else if value.Valid {
|
||||
_m.MonthlyUsageUsd = value.Float64
|
||||
}
|
||||
case userplatformquota.FieldDailyWindowStart:
|
||||
if value, ok := values[i].(*sql.NullTime); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field daily_window_start", values[i])
|
||||
} else if value.Valid {
|
||||
_m.DailyWindowStart = new(time.Time)
|
||||
*_m.DailyWindowStart = value.Time
|
||||
}
|
||||
case userplatformquota.FieldWeeklyWindowStart:
|
||||
if value, ok := values[i].(*sql.NullTime); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field weekly_window_start", values[i])
|
||||
} else if value.Valid {
|
||||
_m.WeeklyWindowStart = new(time.Time)
|
||||
*_m.WeeklyWindowStart = value.Time
|
||||
}
|
||||
case userplatformquota.FieldMonthlyWindowStart:
|
||||
if value, ok := values[i].(*sql.NullTime); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field monthly_window_start", values[i])
|
||||
} else if value.Valid {
|
||||
_m.MonthlyWindowStart = new(time.Time)
|
||||
*_m.MonthlyWindowStart = value.Time
|
||||
}
|
||||
default:
|
||||
_m.selectValues.Set(columns[i], values[i])
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value returns the ent.Value that was dynamically selected and assigned to the UserPlatformQuota.
|
||||
// This includes values selected through modifiers, order, etc.
|
||||
func (_m *UserPlatformQuota) Value(name string) (ent.Value, error) {
|
||||
return _m.selectValues.Get(name)
|
||||
}
|
||||
|
||||
// QueryUser queries the "user" edge of the UserPlatformQuota entity.
|
||||
func (_m *UserPlatformQuota) QueryUser() *UserQuery {
|
||||
return NewUserPlatformQuotaClient(_m.config).QueryUser(_m)
|
||||
}
|
||||
|
||||
// Update returns a builder for updating this UserPlatformQuota.
|
||||
// Note that you need to call UserPlatformQuota.Unwrap() before calling this method if this UserPlatformQuota
|
||||
// was returned from a transaction, and the transaction was committed or rolled back.
|
||||
func (_m *UserPlatformQuota) Update() *UserPlatformQuotaUpdateOne {
|
||||
return NewUserPlatformQuotaClient(_m.config).UpdateOne(_m)
|
||||
}
|
||||
|
||||
// Unwrap unwraps the UserPlatformQuota entity that was returned from a transaction after it was closed,
|
||||
// so that all future queries will be executed through the driver which created the transaction.
|
||||
func (_m *UserPlatformQuota) Unwrap() *UserPlatformQuota {
|
||||
_tx, ok := _m.config.driver.(*txDriver)
|
||||
if !ok {
|
||||
panic("ent: UserPlatformQuota is not a transactional entity")
|
||||
}
|
||||
_m.config.driver = _tx.drv
|
||||
return _m
|
||||
}
|
||||
|
||||
// String implements the fmt.Stringer.
|
||||
func (_m *UserPlatformQuota) String() string {
|
||||
var builder strings.Builder
|
||||
builder.WriteString("UserPlatformQuota(")
|
||||
builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
|
||||
builder.WriteString("created_at=")
|
||||
builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("updated_at=")
|
||||
builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
|
||||
builder.WriteString(", ")
|
||||
if v := _m.DeletedAt; v != nil {
|
||||
builder.WriteString("deleted_at=")
|
||||
builder.WriteString(v.Format(time.ANSIC))
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("user_id=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.UserID))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("platform=")
|
||||
builder.WriteString(_m.Platform)
|
||||
builder.WriteString(", ")
|
||||
if v := _m.DailyLimitUsd; v != nil {
|
||||
builder.WriteString("daily_limit_usd=")
|
||||
builder.WriteString(fmt.Sprintf("%v", *v))
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
if v := _m.WeeklyLimitUsd; v != nil {
|
||||
builder.WriteString("weekly_limit_usd=")
|
||||
builder.WriteString(fmt.Sprintf("%v", *v))
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
if v := _m.MonthlyLimitUsd; v != nil {
|
||||
builder.WriteString("monthly_limit_usd=")
|
||||
builder.WriteString(fmt.Sprintf("%v", *v))
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("daily_usage_usd=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.DailyUsageUsd))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("weekly_usage_usd=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.WeeklyUsageUsd))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("monthly_usage_usd=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.MonthlyUsageUsd))
|
||||
builder.WriteString(", ")
|
||||
if v := _m.DailyWindowStart; v != nil {
|
||||
builder.WriteString("daily_window_start=")
|
||||
builder.WriteString(v.Format(time.ANSIC))
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
if v := _m.WeeklyWindowStart; v != nil {
|
||||
builder.WriteString("weekly_window_start=")
|
||||
builder.WriteString(v.Format(time.ANSIC))
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
if v := _m.MonthlyWindowStart; v != nil {
|
||||
builder.WriteString("monthly_window_start=")
|
||||
builder.WriteString(v.Format(time.ANSIC))
|
||||
}
|
||||
builder.WriteByte(')')
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
// UserPlatformQuotaSlice is a parsable slice of UserPlatformQuota.
|
||||
type UserPlatformQuotaSlice []*UserPlatformQuota
|
||||
202
backend/ent/userplatformquota/userplatformquota.go
Normal file
202
backend/ent/userplatformquota/userplatformquota.go
Normal file
@ -0,0 +1,202 @@
|
||||
// Code generated by ent, DO NOT EDIT.
|
||||
|
||||
package userplatformquota
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"entgo.io/ent"
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||
)
|
||||
|
||||
const (
|
||||
// Label holds the string label denoting the userplatformquota type in the database.
|
||||
Label = "user_platform_quota"
|
||||
// FieldID holds the string denoting the id field in the database.
|
||||
FieldID = "id"
|
||||
// FieldCreatedAt holds the string denoting the created_at field in the database.
|
||||
FieldCreatedAt = "created_at"
|
||||
// FieldUpdatedAt holds the string denoting the updated_at field in the database.
|
||||
FieldUpdatedAt = "updated_at"
|
||||
// FieldDeletedAt holds the string denoting the deleted_at field in the database.
|
||||
FieldDeletedAt = "deleted_at"
|
||||
// FieldUserID holds the string denoting the user_id field in the database.
|
||||
FieldUserID = "user_id"
|
||||
// FieldPlatform holds the string denoting the platform field in the database.
|
||||
FieldPlatform = "platform"
|
||||
// FieldDailyLimitUsd holds the string denoting the daily_limit_usd field in the database.
|
||||
FieldDailyLimitUsd = "daily_limit_usd"
|
||||
// FieldWeeklyLimitUsd holds the string denoting the weekly_limit_usd field in the database.
|
||||
FieldWeeklyLimitUsd = "weekly_limit_usd"
|
||||
// FieldMonthlyLimitUsd holds the string denoting the monthly_limit_usd field in the database.
|
||||
FieldMonthlyLimitUsd = "monthly_limit_usd"
|
||||
// FieldDailyUsageUsd holds the string denoting the daily_usage_usd field in the database.
|
||||
FieldDailyUsageUsd = "daily_usage_usd"
|
||||
// FieldWeeklyUsageUsd holds the string denoting the weekly_usage_usd field in the database.
|
||||
FieldWeeklyUsageUsd = "weekly_usage_usd"
|
||||
// FieldMonthlyUsageUsd holds the string denoting the monthly_usage_usd field in the database.
|
||||
FieldMonthlyUsageUsd = "monthly_usage_usd"
|
||||
// FieldDailyWindowStart holds the string denoting the daily_window_start field in the database.
|
||||
FieldDailyWindowStart = "daily_window_start"
|
||||
// FieldWeeklyWindowStart holds the string denoting the weekly_window_start field in the database.
|
||||
FieldWeeklyWindowStart = "weekly_window_start"
|
||||
// FieldMonthlyWindowStart holds the string denoting the monthly_window_start field in the database.
|
||||
FieldMonthlyWindowStart = "monthly_window_start"
|
||||
// EdgeUser holds the string denoting the user edge name in mutations.
|
||||
EdgeUser = "user"
|
||||
// Table holds the table name of the userplatformquota in the database.
|
||||
Table = "user_platform_quotas"
|
||||
// UserTable is the table that holds the user relation/edge.
|
||||
UserTable = "user_platform_quotas"
|
||||
// UserInverseTable is the table name for the User entity.
|
||||
// It exists in this package in order to avoid circular dependency with the "user" package.
|
||||
UserInverseTable = "users"
|
||||
// UserColumn is the table column denoting the user relation/edge.
|
||||
UserColumn = "user_id"
|
||||
)
|
||||
|
||||
// Columns holds all SQL columns for userplatformquota fields.
|
||||
var Columns = []string{
|
||||
FieldID,
|
||||
FieldCreatedAt,
|
||||
FieldUpdatedAt,
|
||||
FieldDeletedAt,
|
||||
FieldUserID,
|
||||
FieldPlatform,
|
||||
FieldDailyLimitUsd,
|
||||
FieldWeeklyLimitUsd,
|
||||
FieldMonthlyLimitUsd,
|
||||
FieldDailyUsageUsd,
|
||||
FieldWeeklyUsageUsd,
|
||||
FieldMonthlyUsageUsd,
|
||||
FieldDailyWindowStart,
|
||||
FieldWeeklyWindowStart,
|
||||
FieldMonthlyWindowStart,
|
||||
}
|
||||
|
||||
// ValidColumn reports if the column name is valid (part of the table columns).
|
||||
func ValidColumn(column string) bool {
|
||||
for i := range Columns {
|
||||
if column == Columns[i] {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Note that the variables below are initialized by the runtime
|
||||
// package on the initialization of the application. Therefore,
|
||||
// it should be imported in the main as follows:
|
||||
//
|
||||
// import _ "github.com/Wei-Shaw/sub2api/ent/runtime"
|
||||
var (
|
||||
Hooks [1]ent.Hook
|
||||
Interceptors [1]ent.Interceptor
|
||||
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
|
||||
DefaultCreatedAt func() time.Time
|
||||
// DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
|
||||
DefaultUpdatedAt func() time.Time
|
||||
// UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
|
||||
UpdateDefaultUpdatedAt func() time.Time
|
||||
// PlatformValidator is a validator for the "platform" field. It is called by the builders before save.
|
||||
PlatformValidator func(string) error
|
||||
// DefaultDailyUsageUsd holds the default value on creation for the "daily_usage_usd" field.
|
||||
DefaultDailyUsageUsd float64
|
||||
// DefaultWeeklyUsageUsd holds the default value on creation for the "weekly_usage_usd" field.
|
||||
DefaultWeeklyUsageUsd float64
|
||||
// DefaultMonthlyUsageUsd holds the default value on creation for the "monthly_usage_usd" field.
|
||||
DefaultMonthlyUsageUsd float64
|
||||
)
|
||||
|
||||
// OrderOption defines the ordering options for the UserPlatformQuota queries.
|
||||
type OrderOption func(*sql.Selector)
|
||||
|
||||
// ByID orders the results by the id field.
|
||||
func ByID(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldID, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByCreatedAt orders the results by the created_at field.
|
||||
func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByUpdatedAt orders the results by the updated_at field.
|
||||
func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByDeletedAt orders the results by the deleted_at field.
|
||||
func ByDeletedAt(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldDeletedAt, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByUserID orders the results by the user_id field.
|
||||
func ByUserID(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldUserID, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByPlatform orders the results by the platform field.
|
||||
func ByPlatform(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldPlatform, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByDailyLimitUsd orders the results by the daily_limit_usd field.
|
||||
func ByDailyLimitUsd(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldDailyLimitUsd, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByWeeklyLimitUsd orders the results by the weekly_limit_usd field.
|
||||
func ByWeeklyLimitUsd(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldWeeklyLimitUsd, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByMonthlyLimitUsd orders the results by the monthly_limit_usd field.
|
||||
func ByMonthlyLimitUsd(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldMonthlyLimitUsd, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByDailyUsageUsd orders the results by the daily_usage_usd field.
|
||||
func ByDailyUsageUsd(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldDailyUsageUsd, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByWeeklyUsageUsd orders the results by the weekly_usage_usd field.
|
||||
func ByWeeklyUsageUsd(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldWeeklyUsageUsd, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByMonthlyUsageUsd orders the results by the monthly_usage_usd field.
|
||||
func ByMonthlyUsageUsd(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldMonthlyUsageUsd, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByDailyWindowStart orders the results by the daily_window_start field.
|
||||
func ByDailyWindowStart(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldDailyWindowStart, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByWeeklyWindowStart orders the results by the weekly_window_start field.
|
||||
func ByWeeklyWindowStart(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldWeeklyWindowStart, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByMonthlyWindowStart orders the results by the monthly_window_start field.
|
||||
func ByMonthlyWindowStart(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldMonthlyWindowStart, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByUserField orders the results by user field.
|
||||
func ByUserField(field string, opts ...sql.OrderTermOption) OrderOption {
|
||||
return func(s *sql.Selector) {
|
||||
sqlgraph.OrderByNeighborTerms(s, newUserStep(), sql.OrderByField(field, opts...))
|
||||
}
|
||||
}
|
||||
func newUserStep() *sqlgraph.Step {
|
||||
return sqlgraph.NewStep(
|
||||
sqlgraph.From(Table, FieldID),
|
||||
sqlgraph.To(UserInverseTable, FieldID),
|
||||
sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn),
|
||||
)
|
||||
}
|
||||
799
backend/ent/userplatformquota/where.go
Normal file
799
backend/ent/userplatformquota/where.go
Normal file
@ -0,0 +1,799 @@
|
||||
// Code generated by ent, DO NOT EDIT.
|
||||
|
||||
package userplatformquota
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||
"github.com/Wei-Shaw/sub2api/ent/predicate"
|
||||
)
|
||||
|
||||
// ID filters vertices based on their ID field.
|
||||
func ID(id int64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldID, id))
|
||||
}
|
||||
|
||||
// IDEQ applies the EQ predicate on the ID field.
|
||||
func IDEQ(id int64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldID, id))
|
||||
}
|
||||
|
||||
// IDNEQ applies the NEQ predicate on the ID field.
|
||||
func IDNEQ(id int64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldID, id))
|
||||
}
|
||||
|
||||
// IDIn applies the In predicate on the ID field.
|
||||
func IDIn(ids ...int64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldIn(FieldID, ids...))
|
||||
}
|
||||
|
||||
// IDNotIn applies the NotIn predicate on the ID field.
|
||||
func IDNotIn(ids ...int64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldID, ids...))
|
||||
}
|
||||
|
||||
// IDGT applies the GT predicate on the ID field.
|
||||
func IDGT(id int64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldGT(FieldID, id))
|
||||
}
|
||||
|
||||
// IDGTE applies the GTE predicate on the ID field.
|
||||
func IDGTE(id int64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldGTE(FieldID, id))
|
||||
}
|
||||
|
||||
// IDLT applies the LT predicate on the ID field.
|
||||
func IDLT(id int64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldLT(FieldID, id))
|
||||
}
|
||||
|
||||
// IDLTE applies the LTE predicate on the ID field.
|
||||
func IDLTE(id int64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldLTE(FieldID, id))
|
||||
}
|
||||
|
||||
// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
|
||||
func CreatedAt(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldCreatedAt, v))
|
||||
}
|
||||
|
||||
// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
|
||||
func UpdatedAt(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldUpdatedAt, v))
|
||||
}
|
||||
|
||||
// DeletedAt applies equality check predicate on the "deleted_at" field. It's identical to DeletedAtEQ.
|
||||
func DeletedAt(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldDeletedAt, v))
|
||||
}
|
||||
|
||||
// UserID applies equality check predicate on the "user_id" field. It's identical to UserIDEQ.
|
||||
func UserID(v int64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldUserID, v))
|
||||
}
|
||||
|
||||
// Platform applies equality check predicate on the "platform" field. It's identical to PlatformEQ.
|
||||
func Platform(v string) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldPlatform, v))
|
||||
}
|
||||
|
||||
// DailyLimitUsd applies equality check predicate on the "daily_limit_usd" field. It's identical to DailyLimitUsdEQ.
|
||||
func DailyLimitUsd(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldDailyLimitUsd, v))
|
||||
}
|
||||
|
||||
// WeeklyLimitUsd applies equality check predicate on the "weekly_limit_usd" field. It's identical to WeeklyLimitUsdEQ.
|
||||
func WeeklyLimitUsd(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldWeeklyLimitUsd, v))
|
||||
}
|
||||
|
||||
// MonthlyLimitUsd applies equality check predicate on the "monthly_limit_usd" field. It's identical to MonthlyLimitUsdEQ.
|
||||
func MonthlyLimitUsd(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldMonthlyLimitUsd, v))
|
||||
}
|
||||
|
||||
// DailyUsageUsd applies equality check predicate on the "daily_usage_usd" field. It's identical to DailyUsageUsdEQ.
|
||||
func DailyUsageUsd(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldDailyUsageUsd, v))
|
||||
}
|
||||
|
||||
// WeeklyUsageUsd applies equality check predicate on the "weekly_usage_usd" field. It's identical to WeeklyUsageUsdEQ.
|
||||
func WeeklyUsageUsd(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldWeeklyUsageUsd, v))
|
||||
}
|
||||
|
||||
// MonthlyUsageUsd applies equality check predicate on the "monthly_usage_usd" field. It's identical to MonthlyUsageUsdEQ.
|
||||
func MonthlyUsageUsd(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldMonthlyUsageUsd, v))
|
||||
}
|
||||
|
||||
// DailyWindowStart applies equality check predicate on the "daily_window_start" field. It's identical to DailyWindowStartEQ.
|
||||
func DailyWindowStart(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldDailyWindowStart, v))
|
||||
}
|
||||
|
||||
// WeeklyWindowStart applies equality check predicate on the "weekly_window_start" field. It's identical to WeeklyWindowStartEQ.
|
||||
func WeeklyWindowStart(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldWeeklyWindowStart, v))
|
||||
}
|
||||
|
||||
// MonthlyWindowStart applies equality check predicate on the "monthly_window_start" field. It's identical to MonthlyWindowStartEQ.
|
||||
func MonthlyWindowStart(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldMonthlyWindowStart, v))
|
||||
}
|
||||
|
||||
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
|
||||
func CreatedAtEQ(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldCreatedAt, v))
|
||||
}
|
||||
|
||||
// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
|
||||
func CreatedAtNEQ(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldCreatedAt, v))
|
||||
}
|
||||
|
||||
// CreatedAtIn applies the In predicate on the "created_at" field.
|
||||
func CreatedAtIn(vs ...time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldIn(FieldCreatedAt, vs...))
|
||||
}
|
||||
|
||||
// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
|
||||
func CreatedAtNotIn(vs ...time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldCreatedAt, vs...))
|
||||
}
|
||||
|
||||
// CreatedAtGT applies the GT predicate on the "created_at" field.
|
||||
func CreatedAtGT(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldGT(FieldCreatedAt, v))
|
||||
}
|
||||
|
||||
// CreatedAtGTE applies the GTE predicate on the "created_at" field.
|
||||
func CreatedAtGTE(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldGTE(FieldCreatedAt, v))
|
||||
}
|
||||
|
||||
// CreatedAtLT applies the LT predicate on the "created_at" field.
|
||||
func CreatedAtLT(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldLT(FieldCreatedAt, v))
|
||||
}
|
||||
|
||||
// CreatedAtLTE applies the LTE predicate on the "created_at" field.
|
||||
func CreatedAtLTE(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldLTE(FieldCreatedAt, v))
|
||||
}
|
||||
|
||||
// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
|
||||
func UpdatedAtEQ(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldUpdatedAt, v))
|
||||
}
|
||||
|
||||
// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
|
||||
func UpdatedAtNEQ(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldUpdatedAt, v))
|
||||
}
|
||||
|
||||
// UpdatedAtIn applies the In predicate on the "updated_at" field.
|
||||
func UpdatedAtIn(vs ...time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldIn(FieldUpdatedAt, vs...))
|
||||
}
|
||||
|
||||
// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
|
||||
func UpdatedAtNotIn(vs ...time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldUpdatedAt, vs...))
|
||||
}
|
||||
|
||||
// UpdatedAtGT applies the GT predicate on the "updated_at" field.
|
||||
func UpdatedAtGT(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldGT(FieldUpdatedAt, v))
|
||||
}
|
||||
|
||||
// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
|
||||
func UpdatedAtGTE(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldGTE(FieldUpdatedAt, v))
|
||||
}
|
||||
|
||||
// UpdatedAtLT applies the LT predicate on the "updated_at" field.
|
||||
func UpdatedAtLT(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldLT(FieldUpdatedAt, v))
|
||||
}
|
||||
|
||||
// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
|
||||
func UpdatedAtLTE(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldLTE(FieldUpdatedAt, v))
|
||||
}
|
||||
|
||||
// DeletedAtEQ applies the EQ predicate on the "deleted_at" field.
|
||||
func DeletedAtEQ(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldDeletedAt, v))
|
||||
}
|
||||
|
||||
// DeletedAtNEQ applies the NEQ predicate on the "deleted_at" field.
|
||||
func DeletedAtNEQ(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldDeletedAt, v))
|
||||
}
|
||||
|
||||
// DeletedAtIn applies the In predicate on the "deleted_at" field.
|
||||
func DeletedAtIn(vs ...time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldIn(FieldDeletedAt, vs...))
|
||||
}
|
||||
|
||||
// DeletedAtNotIn applies the NotIn predicate on the "deleted_at" field.
|
||||
func DeletedAtNotIn(vs ...time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldDeletedAt, vs...))
|
||||
}
|
||||
|
||||
// DeletedAtGT applies the GT predicate on the "deleted_at" field.
|
||||
func DeletedAtGT(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldGT(FieldDeletedAt, v))
|
||||
}
|
||||
|
||||
// DeletedAtGTE applies the GTE predicate on the "deleted_at" field.
|
||||
func DeletedAtGTE(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldGTE(FieldDeletedAt, v))
|
||||
}
|
||||
|
||||
// DeletedAtLT applies the LT predicate on the "deleted_at" field.
|
||||
func DeletedAtLT(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldLT(FieldDeletedAt, v))
|
||||
}
|
||||
|
||||
// DeletedAtLTE applies the LTE predicate on the "deleted_at" field.
|
||||
func DeletedAtLTE(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldLTE(FieldDeletedAt, v))
|
||||
}
|
||||
|
||||
// DeletedAtIsNil applies the IsNil predicate on the "deleted_at" field.
|
||||
func DeletedAtIsNil() predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldIsNull(FieldDeletedAt))
|
||||
}
|
||||
|
||||
// DeletedAtNotNil applies the NotNil predicate on the "deleted_at" field.
|
||||
func DeletedAtNotNil() predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNotNull(FieldDeletedAt))
|
||||
}
|
||||
|
||||
// UserIDEQ applies the EQ predicate on the "user_id" field.
|
||||
func UserIDEQ(v int64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldUserID, v))
|
||||
}
|
||||
|
||||
// UserIDNEQ applies the NEQ predicate on the "user_id" field.
|
||||
func UserIDNEQ(v int64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldUserID, v))
|
||||
}
|
||||
|
||||
// UserIDIn applies the In predicate on the "user_id" field.
|
||||
func UserIDIn(vs ...int64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldIn(FieldUserID, vs...))
|
||||
}
|
||||
|
||||
// UserIDNotIn applies the NotIn predicate on the "user_id" field.
|
||||
func UserIDNotIn(vs ...int64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldUserID, vs...))
|
||||
}
|
||||
|
||||
// PlatformEQ applies the EQ predicate on the "platform" field.
|
||||
func PlatformEQ(v string) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldPlatform, v))
|
||||
}
|
||||
|
||||
// PlatformNEQ applies the NEQ predicate on the "platform" field.
|
||||
func PlatformNEQ(v string) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldPlatform, v))
|
||||
}
|
||||
|
||||
// PlatformIn applies the In predicate on the "platform" field.
|
||||
func PlatformIn(vs ...string) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldIn(FieldPlatform, vs...))
|
||||
}
|
||||
|
||||
// PlatformNotIn applies the NotIn predicate on the "platform" field.
|
||||
func PlatformNotIn(vs ...string) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldPlatform, vs...))
|
||||
}
|
||||
|
||||
// PlatformGT applies the GT predicate on the "platform" field.
|
||||
func PlatformGT(v string) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldGT(FieldPlatform, v))
|
||||
}
|
||||
|
||||
// PlatformGTE applies the GTE predicate on the "platform" field.
|
||||
func PlatformGTE(v string) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldGTE(FieldPlatform, v))
|
||||
}
|
||||
|
||||
// PlatformLT applies the LT predicate on the "platform" field.
|
||||
func PlatformLT(v string) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldLT(FieldPlatform, v))
|
||||
}
|
||||
|
||||
// PlatformLTE applies the LTE predicate on the "platform" field.
|
||||
func PlatformLTE(v string) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldLTE(FieldPlatform, v))
|
||||
}
|
||||
|
||||
// PlatformContains applies the Contains predicate on the "platform" field.
|
||||
func PlatformContains(v string) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldContains(FieldPlatform, v))
|
||||
}
|
||||
|
||||
// PlatformHasPrefix applies the HasPrefix predicate on the "platform" field.
|
||||
func PlatformHasPrefix(v string) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldHasPrefix(FieldPlatform, v))
|
||||
}
|
||||
|
||||
// PlatformHasSuffix applies the HasSuffix predicate on the "platform" field.
|
||||
func PlatformHasSuffix(v string) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldHasSuffix(FieldPlatform, v))
|
||||
}
|
||||
|
||||
// PlatformEqualFold applies the EqualFold predicate on the "platform" field.
|
||||
func PlatformEqualFold(v string) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEqualFold(FieldPlatform, v))
|
||||
}
|
||||
|
||||
// PlatformContainsFold applies the ContainsFold predicate on the "platform" field.
|
||||
func PlatformContainsFold(v string) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldContainsFold(FieldPlatform, v))
|
||||
}
|
||||
|
||||
// DailyLimitUsdEQ applies the EQ predicate on the "daily_limit_usd" field.
|
||||
func DailyLimitUsdEQ(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldDailyLimitUsd, v))
|
||||
}
|
||||
|
||||
// DailyLimitUsdNEQ applies the NEQ predicate on the "daily_limit_usd" field.
|
||||
func DailyLimitUsdNEQ(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldDailyLimitUsd, v))
|
||||
}
|
||||
|
||||
// DailyLimitUsdIn applies the In predicate on the "daily_limit_usd" field.
|
||||
func DailyLimitUsdIn(vs ...float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldIn(FieldDailyLimitUsd, vs...))
|
||||
}
|
||||
|
||||
// DailyLimitUsdNotIn applies the NotIn predicate on the "daily_limit_usd" field.
|
||||
func DailyLimitUsdNotIn(vs ...float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldDailyLimitUsd, vs...))
|
||||
}
|
||||
|
||||
// DailyLimitUsdGT applies the GT predicate on the "daily_limit_usd" field.
|
||||
func DailyLimitUsdGT(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldGT(FieldDailyLimitUsd, v))
|
||||
}
|
||||
|
||||
// DailyLimitUsdGTE applies the GTE predicate on the "daily_limit_usd" field.
|
||||
func DailyLimitUsdGTE(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldGTE(FieldDailyLimitUsd, v))
|
||||
}
|
||||
|
||||
// DailyLimitUsdLT applies the LT predicate on the "daily_limit_usd" field.
|
||||
func DailyLimitUsdLT(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldLT(FieldDailyLimitUsd, v))
|
||||
}
|
||||
|
||||
// DailyLimitUsdLTE applies the LTE predicate on the "daily_limit_usd" field.
|
||||
func DailyLimitUsdLTE(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldLTE(FieldDailyLimitUsd, v))
|
||||
}
|
||||
|
||||
// DailyLimitUsdIsNil applies the IsNil predicate on the "daily_limit_usd" field.
|
||||
func DailyLimitUsdIsNil() predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldIsNull(FieldDailyLimitUsd))
|
||||
}
|
||||
|
||||
// DailyLimitUsdNotNil applies the NotNil predicate on the "daily_limit_usd" field.
|
||||
func DailyLimitUsdNotNil() predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNotNull(FieldDailyLimitUsd))
|
||||
}
|
||||
|
||||
// WeeklyLimitUsdEQ applies the EQ predicate on the "weekly_limit_usd" field.
|
||||
func WeeklyLimitUsdEQ(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldWeeklyLimitUsd, v))
|
||||
}
|
||||
|
||||
// WeeklyLimitUsdNEQ applies the NEQ predicate on the "weekly_limit_usd" field.
|
||||
func WeeklyLimitUsdNEQ(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldWeeklyLimitUsd, v))
|
||||
}
|
||||
|
||||
// WeeklyLimitUsdIn applies the In predicate on the "weekly_limit_usd" field.
|
||||
func WeeklyLimitUsdIn(vs ...float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldIn(FieldWeeklyLimitUsd, vs...))
|
||||
}
|
||||
|
||||
// WeeklyLimitUsdNotIn applies the NotIn predicate on the "weekly_limit_usd" field.
|
||||
func WeeklyLimitUsdNotIn(vs ...float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldWeeklyLimitUsd, vs...))
|
||||
}
|
||||
|
||||
// WeeklyLimitUsdGT applies the GT predicate on the "weekly_limit_usd" field.
|
||||
func WeeklyLimitUsdGT(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldGT(FieldWeeklyLimitUsd, v))
|
||||
}
|
||||
|
||||
// WeeklyLimitUsdGTE applies the GTE predicate on the "weekly_limit_usd" field.
|
||||
func WeeklyLimitUsdGTE(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldGTE(FieldWeeklyLimitUsd, v))
|
||||
}
|
||||
|
||||
// WeeklyLimitUsdLT applies the LT predicate on the "weekly_limit_usd" field.
|
||||
func WeeklyLimitUsdLT(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldLT(FieldWeeklyLimitUsd, v))
|
||||
}
|
||||
|
||||
// WeeklyLimitUsdLTE applies the LTE predicate on the "weekly_limit_usd" field.
|
||||
func WeeklyLimitUsdLTE(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldLTE(FieldWeeklyLimitUsd, v))
|
||||
}
|
||||
|
||||
// WeeklyLimitUsdIsNil applies the IsNil predicate on the "weekly_limit_usd" field.
|
||||
func WeeklyLimitUsdIsNil() predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldIsNull(FieldWeeklyLimitUsd))
|
||||
}
|
||||
|
||||
// WeeklyLimitUsdNotNil applies the NotNil predicate on the "weekly_limit_usd" field.
|
||||
func WeeklyLimitUsdNotNil() predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNotNull(FieldWeeklyLimitUsd))
|
||||
}
|
||||
|
||||
// MonthlyLimitUsdEQ applies the EQ predicate on the "monthly_limit_usd" field.
|
||||
func MonthlyLimitUsdEQ(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldMonthlyLimitUsd, v))
|
||||
}
|
||||
|
||||
// MonthlyLimitUsdNEQ applies the NEQ predicate on the "monthly_limit_usd" field.
|
||||
func MonthlyLimitUsdNEQ(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldMonthlyLimitUsd, v))
|
||||
}
|
||||
|
||||
// MonthlyLimitUsdIn applies the In predicate on the "monthly_limit_usd" field.
|
||||
func MonthlyLimitUsdIn(vs ...float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldIn(FieldMonthlyLimitUsd, vs...))
|
||||
}
|
||||
|
||||
// MonthlyLimitUsdNotIn applies the NotIn predicate on the "monthly_limit_usd" field.
|
||||
func MonthlyLimitUsdNotIn(vs ...float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldMonthlyLimitUsd, vs...))
|
||||
}
|
||||
|
||||
// MonthlyLimitUsdGT applies the GT predicate on the "monthly_limit_usd" field.
|
||||
func MonthlyLimitUsdGT(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldGT(FieldMonthlyLimitUsd, v))
|
||||
}
|
||||
|
||||
// MonthlyLimitUsdGTE applies the GTE predicate on the "monthly_limit_usd" field.
|
||||
func MonthlyLimitUsdGTE(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldGTE(FieldMonthlyLimitUsd, v))
|
||||
}
|
||||
|
||||
// MonthlyLimitUsdLT applies the LT predicate on the "monthly_limit_usd" field.
|
||||
func MonthlyLimitUsdLT(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldLT(FieldMonthlyLimitUsd, v))
|
||||
}
|
||||
|
||||
// MonthlyLimitUsdLTE applies the LTE predicate on the "monthly_limit_usd" field.
|
||||
func MonthlyLimitUsdLTE(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldLTE(FieldMonthlyLimitUsd, v))
|
||||
}
|
||||
|
||||
// MonthlyLimitUsdIsNil applies the IsNil predicate on the "monthly_limit_usd" field.
|
||||
func MonthlyLimitUsdIsNil() predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldIsNull(FieldMonthlyLimitUsd))
|
||||
}
|
||||
|
||||
// MonthlyLimitUsdNotNil applies the NotNil predicate on the "monthly_limit_usd" field.
|
||||
func MonthlyLimitUsdNotNil() predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNotNull(FieldMonthlyLimitUsd))
|
||||
}
|
||||
|
||||
// DailyUsageUsdEQ applies the EQ predicate on the "daily_usage_usd" field.
|
||||
func DailyUsageUsdEQ(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldDailyUsageUsd, v))
|
||||
}
|
||||
|
||||
// DailyUsageUsdNEQ applies the NEQ predicate on the "daily_usage_usd" field.
|
||||
func DailyUsageUsdNEQ(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldDailyUsageUsd, v))
|
||||
}
|
||||
|
||||
// DailyUsageUsdIn applies the In predicate on the "daily_usage_usd" field.
|
||||
func DailyUsageUsdIn(vs ...float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldIn(FieldDailyUsageUsd, vs...))
|
||||
}
|
||||
|
||||
// DailyUsageUsdNotIn applies the NotIn predicate on the "daily_usage_usd" field.
|
||||
func DailyUsageUsdNotIn(vs ...float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldDailyUsageUsd, vs...))
|
||||
}
|
||||
|
||||
// DailyUsageUsdGT applies the GT predicate on the "daily_usage_usd" field.
|
||||
func DailyUsageUsdGT(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldGT(FieldDailyUsageUsd, v))
|
||||
}
|
||||
|
||||
// DailyUsageUsdGTE applies the GTE predicate on the "daily_usage_usd" field.
|
||||
func DailyUsageUsdGTE(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldGTE(FieldDailyUsageUsd, v))
|
||||
}
|
||||
|
||||
// DailyUsageUsdLT applies the LT predicate on the "daily_usage_usd" field.
|
||||
func DailyUsageUsdLT(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldLT(FieldDailyUsageUsd, v))
|
||||
}
|
||||
|
||||
// DailyUsageUsdLTE applies the LTE predicate on the "daily_usage_usd" field.
|
||||
func DailyUsageUsdLTE(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldLTE(FieldDailyUsageUsd, v))
|
||||
}
|
||||
|
||||
// WeeklyUsageUsdEQ applies the EQ predicate on the "weekly_usage_usd" field.
|
||||
func WeeklyUsageUsdEQ(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldWeeklyUsageUsd, v))
|
||||
}
|
||||
|
||||
// WeeklyUsageUsdNEQ applies the NEQ predicate on the "weekly_usage_usd" field.
|
||||
func WeeklyUsageUsdNEQ(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldWeeklyUsageUsd, v))
|
||||
}
|
||||
|
||||
// WeeklyUsageUsdIn applies the In predicate on the "weekly_usage_usd" field.
|
||||
func WeeklyUsageUsdIn(vs ...float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldIn(FieldWeeklyUsageUsd, vs...))
|
||||
}
|
||||
|
||||
// WeeklyUsageUsdNotIn applies the NotIn predicate on the "weekly_usage_usd" field.
|
||||
func WeeklyUsageUsdNotIn(vs ...float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldWeeklyUsageUsd, vs...))
|
||||
}
|
||||
|
||||
// WeeklyUsageUsdGT applies the GT predicate on the "weekly_usage_usd" field.
|
||||
func WeeklyUsageUsdGT(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldGT(FieldWeeklyUsageUsd, v))
|
||||
}
|
||||
|
||||
// WeeklyUsageUsdGTE applies the GTE predicate on the "weekly_usage_usd" field.
|
||||
func WeeklyUsageUsdGTE(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldGTE(FieldWeeklyUsageUsd, v))
|
||||
}
|
||||
|
||||
// WeeklyUsageUsdLT applies the LT predicate on the "weekly_usage_usd" field.
|
||||
func WeeklyUsageUsdLT(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldLT(FieldWeeklyUsageUsd, v))
|
||||
}
|
||||
|
||||
// WeeklyUsageUsdLTE applies the LTE predicate on the "weekly_usage_usd" field.
|
||||
func WeeklyUsageUsdLTE(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldLTE(FieldWeeklyUsageUsd, v))
|
||||
}
|
||||
|
||||
// MonthlyUsageUsdEQ applies the EQ predicate on the "monthly_usage_usd" field.
|
||||
func MonthlyUsageUsdEQ(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldMonthlyUsageUsd, v))
|
||||
}
|
||||
|
||||
// MonthlyUsageUsdNEQ applies the NEQ predicate on the "monthly_usage_usd" field.
|
||||
func MonthlyUsageUsdNEQ(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldMonthlyUsageUsd, v))
|
||||
}
|
||||
|
||||
// MonthlyUsageUsdIn applies the In predicate on the "monthly_usage_usd" field.
|
||||
func MonthlyUsageUsdIn(vs ...float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldIn(FieldMonthlyUsageUsd, vs...))
|
||||
}
|
||||
|
||||
// MonthlyUsageUsdNotIn applies the NotIn predicate on the "monthly_usage_usd" field.
|
||||
func MonthlyUsageUsdNotIn(vs ...float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldMonthlyUsageUsd, vs...))
|
||||
}
|
||||
|
||||
// MonthlyUsageUsdGT applies the GT predicate on the "monthly_usage_usd" field.
|
||||
func MonthlyUsageUsdGT(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldGT(FieldMonthlyUsageUsd, v))
|
||||
}
|
||||
|
||||
// MonthlyUsageUsdGTE applies the GTE predicate on the "monthly_usage_usd" field.
|
||||
func MonthlyUsageUsdGTE(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldGTE(FieldMonthlyUsageUsd, v))
|
||||
}
|
||||
|
||||
// MonthlyUsageUsdLT applies the LT predicate on the "monthly_usage_usd" field.
|
||||
func MonthlyUsageUsdLT(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldLT(FieldMonthlyUsageUsd, v))
|
||||
}
|
||||
|
||||
// MonthlyUsageUsdLTE applies the LTE predicate on the "monthly_usage_usd" field.
|
||||
func MonthlyUsageUsdLTE(v float64) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldLTE(FieldMonthlyUsageUsd, v))
|
||||
}
|
||||
|
||||
// DailyWindowStartEQ applies the EQ predicate on the "daily_window_start" field.
|
||||
func DailyWindowStartEQ(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldDailyWindowStart, v))
|
||||
}
|
||||
|
||||
// DailyWindowStartNEQ applies the NEQ predicate on the "daily_window_start" field.
|
||||
func DailyWindowStartNEQ(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldDailyWindowStart, v))
|
||||
}
|
||||
|
||||
// DailyWindowStartIn applies the In predicate on the "daily_window_start" field.
|
||||
func DailyWindowStartIn(vs ...time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldIn(FieldDailyWindowStart, vs...))
|
||||
}
|
||||
|
||||
// DailyWindowStartNotIn applies the NotIn predicate on the "daily_window_start" field.
|
||||
func DailyWindowStartNotIn(vs ...time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldDailyWindowStart, vs...))
|
||||
}
|
||||
|
||||
// DailyWindowStartGT applies the GT predicate on the "daily_window_start" field.
|
||||
func DailyWindowStartGT(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldGT(FieldDailyWindowStart, v))
|
||||
}
|
||||
|
||||
// DailyWindowStartGTE applies the GTE predicate on the "daily_window_start" field.
|
||||
func DailyWindowStartGTE(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldGTE(FieldDailyWindowStart, v))
|
||||
}
|
||||
|
||||
// DailyWindowStartLT applies the LT predicate on the "daily_window_start" field.
|
||||
func DailyWindowStartLT(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldLT(FieldDailyWindowStart, v))
|
||||
}
|
||||
|
||||
// DailyWindowStartLTE applies the LTE predicate on the "daily_window_start" field.
|
||||
func DailyWindowStartLTE(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldLTE(FieldDailyWindowStart, v))
|
||||
}
|
||||
|
||||
// DailyWindowStartIsNil applies the IsNil predicate on the "daily_window_start" field.
|
||||
func DailyWindowStartIsNil() predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldIsNull(FieldDailyWindowStart))
|
||||
}
|
||||
|
||||
// DailyWindowStartNotNil applies the NotNil predicate on the "daily_window_start" field.
|
||||
func DailyWindowStartNotNil() predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNotNull(FieldDailyWindowStart))
|
||||
}
|
||||
|
||||
// WeeklyWindowStartEQ applies the EQ predicate on the "weekly_window_start" field.
|
||||
func WeeklyWindowStartEQ(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldWeeklyWindowStart, v))
|
||||
}
|
||||
|
||||
// WeeklyWindowStartNEQ applies the NEQ predicate on the "weekly_window_start" field.
|
||||
func WeeklyWindowStartNEQ(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldWeeklyWindowStart, v))
|
||||
}
|
||||
|
||||
// WeeklyWindowStartIn applies the In predicate on the "weekly_window_start" field.
|
||||
func WeeklyWindowStartIn(vs ...time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldIn(FieldWeeklyWindowStart, vs...))
|
||||
}
|
||||
|
||||
// WeeklyWindowStartNotIn applies the NotIn predicate on the "weekly_window_start" field.
|
||||
func WeeklyWindowStartNotIn(vs ...time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldWeeklyWindowStart, vs...))
|
||||
}
|
||||
|
||||
// WeeklyWindowStartGT applies the GT predicate on the "weekly_window_start" field.
|
||||
func WeeklyWindowStartGT(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldGT(FieldWeeklyWindowStart, v))
|
||||
}
|
||||
|
||||
// WeeklyWindowStartGTE applies the GTE predicate on the "weekly_window_start" field.
|
||||
func WeeklyWindowStartGTE(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldGTE(FieldWeeklyWindowStart, v))
|
||||
}
|
||||
|
||||
// WeeklyWindowStartLT applies the LT predicate on the "weekly_window_start" field.
|
||||
func WeeklyWindowStartLT(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldLT(FieldWeeklyWindowStart, v))
|
||||
}
|
||||
|
||||
// WeeklyWindowStartLTE applies the LTE predicate on the "weekly_window_start" field.
|
||||
func WeeklyWindowStartLTE(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldLTE(FieldWeeklyWindowStart, v))
|
||||
}
|
||||
|
||||
// WeeklyWindowStartIsNil applies the IsNil predicate on the "weekly_window_start" field.
|
||||
func WeeklyWindowStartIsNil() predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldIsNull(FieldWeeklyWindowStart))
|
||||
}
|
||||
|
||||
// WeeklyWindowStartNotNil applies the NotNil predicate on the "weekly_window_start" field.
|
||||
func WeeklyWindowStartNotNil() predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNotNull(FieldWeeklyWindowStart))
|
||||
}
|
||||
|
||||
// MonthlyWindowStartEQ applies the EQ predicate on the "monthly_window_start" field.
|
||||
func MonthlyWindowStartEQ(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldEQ(FieldMonthlyWindowStart, v))
|
||||
}
|
||||
|
||||
// MonthlyWindowStartNEQ applies the NEQ predicate on the "monthly_window_start" field.
|
||||
func MonthlyWindowStartNEQ(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNEQ(FieldMonthlyWindowStart, v))
|
||||
}
|
||||
|
||||
// MonthlyWindowStartIn applies the In predicate on the "monthly_window_start" field.
|
||||
func MonthlyWindowStartIn(vs ...time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldIn(FieldMonthlyWindowStart, vs...))
|
||||
}
|
||||
|
||||
// MonthlyWindowStartNotIn applies the NotIn predicate on the "monthly_window_start" field.
|
||||
func MonthlyWindowStartNotIn(vs ...time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNotIn(FieldMonthlyWindowStart, vs...))
|
||||
}
|
||||
|
||||
// MonthlyWindowStartGT applies the GT predicate on the "monthly_window_start" field.
|
||||
func MonthlyWindowStartGT(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldGT(FieldMonthlyWindowStart, v))
|
||||
}
|
||||
|
||||
// MonthlyWindowStartGTE applies the GTE predicate on the "monthly_window_start" field.
|
||||
func MonthlyWindowStartGTE(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldGTE(FieldMonthlyWindowStart, v))
|
||||
}
|
||||
|
||||
// MonthlyWindowStartLT applies the LT predicate on the "monthly_window_start" field.
|
||||
func MonthlyWindowStartLT(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldLT(FieldMonthlyWindowStart, v))
|
||||
}
|
||||
|
||||
// MonthlyWindowStartLTE applies the LTE predicate on the "monthly_window_start" field.
|
||||
func MonthlyWindowStartLTE(v time.Time) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldLTE(FieldMonthlyWindowStart, v))
|
||||
}
|
||||
|
||||
// MonthlyWindowStartIsNil applies the IsNil predicate on the "monthly_window_start" field.
|
||||
func MonthlyWindowStartIsNil() predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldIsNull(FieldMonthlyWindowStart))
|
||||
}
|
||||
|
||||
// MonthlyWindowStartNotNil applies the NotNil predicate on the "monthly_window_start" field.
|
||||
func MonthlyWindowStartNotNil() predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.FieldNotNull(FieldMonthlyWindowStart))
|
||||
}
|
||||
|
||||
// HasUser applies the HasEdge predicate on the "user" edge.
|
||||
func HasUser() predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(func(s *sql.Selector) {
|
||||
step := sqlgraph.NewStep(
|
||||
sqlgraph.From(Table, FieldID),
|
||||
sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn),
|
||||
)
|
||||
sqlgraph.HasNeighbors(s, step)
|
||||
})
|
||||
}
|
||||
|
||||
// HasUserWith applies the HasEdge predicate on the "user" edge with a given conditions (other predicates).
|
||||
func HasUserWith(preds ...predicate.User) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(func(s *sql.Selector) {
|
||||
step := newUserStep()
|
||||
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
|
||||
for _, p := range preds {
|
||||
p(s)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// And groups predicates with the AND operator between them.
|
||||
func And(predicates ...predicate.UserPlatformQuota) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.AndPredicates(predicates...))
|
||||
}
|
||||
|
||||
// Or groups predicates with the OR operator between them.
|
||||
func Or(predicates ...predicate.UserPlatformQuota) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.OrPredicates(predicates...))
|
||||
}
|
||||
|
||||
// Not applies the not operator on the given predicate.
|
||||
func Not(p predicate.UserPlatformQuota) predicate.UserPlatformQuota {
|
||||
return predicate.UserPlatformQuota(sql.NotPredicates(p))
|
||||
}
|
||||
1513
backend/ent/userplatformquota_create.go
Normal file
1513
backend/ent/userplatformquota_create.go
Normal file
File diff suppressed because it is too large
Load Diff
88
backend/ent/userplatformquota_delete.go
Normal file
88
backend/ent/userplatformquota_delete.go
Normal file
@ -0,0 +1,88 @@
|
||||
// Code generated by ent, DO NOT EDIT.
|
||||
|
||||
package ent
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||
"entgo.io/ent/schema/field"
|
||||
"github.com/Wei-Shaw/sub2api/ent/predicate"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userplatformquota"
|
||||
)
|
||||
|
||||
// UserPlatformQuotaDelete is the builder for deleting a UserPlatformQuota entity.
|
||||
type UserPlatformQuotaDelete struct {
|
||||
config
|
||||
hooks []Hook
|
||||
mutation *UserPlatformQuotaMutation
|
||||
}
|
||||
|
||||
// Where appends a list predicates to the UserPlatformQuotaDelete builder.
|
||||
func (_d *UserPlatformQuotaDelete) Where(ps ...predicate.UserPlatformQuota) *UserPlatformQuotaDelete {
|
||||
_d.mutation.Where(ps...)
|
||||
return _d
|
||||
}
|
||||
|
||||
// Exec executes the deletion query and returns how many vertices were deleted.
|
||||
func (_d *UserPlatformQuotaDelete) Exec(ctx context.Context) (int, error) {
|
||||
return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
|
||||
}
|
||||
|
||||
// ExecX is like Exec, but panics if an error occurs.
|
||||
func (_d *UserPlatformQuotaDelete) ExecX(ctx context.Context) int {
|
||||
n, err := _d.Exec(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func (_d *UserPlatformQuotaDelete) sqlExec(ctx context.Context) (int, error) {
|
||||
_spec := sqlgraph.NewDeleteSpec(userplatformquota.Table, sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64))
|
||||
if ps := _d.mutation.predicates; len(ps) > 0 {
|
||||
_spec.Predicate = func(selector *sql.Selector) {
|
||||
for i := range ps {
|
||||
ps[i](selector)
|
||||
}
|
||||
}
|
||||
}
|
||||
affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
|
||||
if err != nil && sqlgraph.IsConstraintError(err) {
|
||||
err = &ConstraintError{msg: err.Error(), wrap: err}
|
||||
}
|
||||
_d.mutation.done = true
|
||||
return affected, err
|
||||
}
|
||||
|
||||
// UserPlatformQuotaDeleteOne is the builder for deleting a single UserPlatformQuota entity.
|
||||
type UserPlatformQuotaDeleteOne struct {
|
||||
_d *UserPlatformQuotaDelete
|
||||
}
|
||||
|
||||
// Where appends a list predicates to the UserPlatformQuotaDelete builder.
|
||||
func (_d *UserPlatformQuotaDeleteOne) Where(ps ...predicate.UserPlatformQuota) *UserPlatformQuotaDeleteOne {
|
||||
_d._d.mutation.Where(ps...)
|
||||
return _d
|
||||
}
|
||||
|
||||
// Exec executes the deletion query.
|
||||
func (_d *UserPlatformQuotaDeleteOne) Exec(ctx context.Context) error {
|
||||
n, err := _d._d.Exec(ctx)
|
||||
switch {
|
||||
case err != nil:
|
||||
return err
|
||||
case n == 0:
|
||||
return &NotFoundError{userplatformquota.Label}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// ExecX is like Exec, but panics if an error occurs.
|
||||
func (_d *UserPlatformQuotaDeleteOne) ExecX(ctx context.Context) {
|
||||
if err := _d.Exec(ctx); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
643
backend/ent/userplatformquota_query.go
Normal file
643
backend/ent/userplatformquota_query.go
Normal file
@ -0,0 +1,643 @@
|
||||
// Code generated by ent, DO NOT EDIT.
|
||||
|
||||
package ent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"entgo.io/ent"
|
||||
"entgo.io/ent/dialect"
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||
"entgo.io/ent/schema/field"
|
||||
"github.com/Wei-Shaw/sub2api/ent/predicate"
|
||||
"github.com/Wei-Shaw/sub2api/ent/user"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userplatformquota"
|
||||
)
|
||||
|
||||
// UserPlatformQuotaQuery is the builder for querying UserPlatformQuota entities.
|
||||
type UserPlatformQuotaQuery struct {
|
||||
config
|
||||
ctx *QueryContext
|
||||
order []userplatformquota.OrderOption
|
||||
inters []Interceptor
|
||||
predicates []predicate.UserPlatformQuota
|
||||
withUser *UserQuery
|
||||
modifiers []func(*sql.Selector)
|
||||
// intermediate query (i.e. traversal path).
|
||||
sql *sql.Selector
|
||||
path func(context.Context) (*sql.Selector, error)
|
||||
}
|
||||
|
||||
// Where adds a new predicate for the UserPlatformQuotaQuery builder.
|
||||
func (_q *UserPlatformQuotaQuery) Where(ps ...predicate.UserPlatformQuota) *UserPlatformQuotaQuery {
|
||||
_q.predicates = append(_q.predicates, ps...)
|
||||
return _q
|
||||
}
|
||||
|
||||
// Limit the number of records to be returned by this query.
|
||||
func (_q *UserPlatformQuotaQuery) Limit(limit int) *UserPlatformQuotaQuery {
|
||||
_q.ctx.Limit = &limit
|
||||
return _q
|
||||
}
|
||||
|
||||
// Offset to start from.
|
||||
func (_q *UserPlatformQuotaQuery) Offset(offset int) *UserPlatformQuotaQuery {
|
||||
_q.ctx.Offset = &offset
|
||||
return _q
|
||||
}
|
||||
|
||||
// Unique configures the query builder to filter duplicate records on query.
|
||||
// By default, unique is set to true, and can be disabled using this method.
|
||||
func (_q *UserPlatformQuotaQuery) Unique(unique bool) *UserPlatformQuotaQuery {
|
||||
_q.ctx.Unique = &unique
|
||||
return _q
|
||||
}
|
||||
|
||||
// Order specifies how the records should be ordered.
|
||||
func (_q *UserPlatformQuotaQuery) Order(o ...userplatformquota.OrderOption) *UserPlatformQuotaQuery {
|
||||
_q.order = append(_q.order, o...)
|
||||
return _q
|
||||
}
|
||||
|
||||
// QueryUser chains the current query on the "user" edge.
|
||||
func (_q *UserPlatformQuotaQuery) QueryUser() *UserQuery {
|
||||
query := (&UserClient{config: _q.config}).Query()
|
||||
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
|
||||
if err := _q.prepareQuery(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
selector := _q.sqlQuery(ctx)
|
||||
if err := selector.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
step := sqlgraph.NewStep(
|
||||
sqlgraph.From(userplatformquota.Table, userplatformquota.FieldID, selector),
|
||||
sqlgraph.To(user.Table, user.FieldID),
|
||||
sqlgraph.Edge(sqlgraph.M2O, true, userplatformquota.UserTable, userplatformquota.UserColumn),
|
||||
)
|
||||
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
|
||||
return fromU, nil
|
||||
}
|
||||
return query
|
||||
}
|
||||
|
||||
// First returns the first UserPlatformQuota entity from the query.
|
||||
// Returns a *NotFoundError when no UserPlatformQuota was found.
|
||||
func (_q *UserPlatformQuotaQuery) First(ctx context.Context) (*UserPlatformQuota, error) {
|
||||
nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(nodes) == 0 {
|
||||
return nil, &NotFoundError{userplatformquota.Label}
|
||||
}
|
||||
return nodes[0], nil
|
||||
}
|
||||
|
||||
// FirstX is like First, but panics if an error occurs.
|
||||
func (_q *UserPlatformQuotaQuery) FirstX(ctx context.Context) *UserPlatformQuota {
|
||||
node, err := _q.First(ctx)
|
||||
if err != nil && !IsNotFound(err) {
|
||||
panic(err)
|
||||
}
|
||||
return node
|
||||
}
|
||||
|
||||
// FirstID returns the first UserPlatformQuota ID from the query.
|
||||
// Returns a *NotFoundError when no UserPlatformQuota ID was found.
|
||||
func (_q *UserPlatformQuotaQuery) FirstID(ctx context.Context) (id int64, err error) {
|
||||
var ids []int64
|
||||
if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
|
||||
return
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
err = &NotFoundError{userplatformquota.Label}
|
||||
return
|
||||
}
|
||||
return ids[0], nil
|
||||
}
|
||||
|
||||
// FirstIDX is like FirstID, but panics if an error occurs.
|
||||
func (_q *UserPlatformQuotaQuery) FirstIDX(ctx context.Context) int64 {
|
||||
id, err := _q.FirstID(ctx)
|
||||
if err != nil && !IsNotFound(err) {
|
||||
panic(err)
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
// Only returns a single UserPlatformQuota entity found by the query, ensuring it only returns one.
|
||||
// Returns a *NotSingularError when more than one UserPlatformQuota entity is found.
|
||||
// Returns a *NotFoundError when no UserPlatformQuota entities are found.
|
||||
func (_q *UserPlatformQuotaQuery) Only(ctx context.Context) (*UserPlatformQuota, error) {
|
||||
nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch len(nodes) {
|
||||
case 1:
|
||||
return nodes[0], nil
|
||||
case 0:
|
||||
return nil, &NotFoundError{userplatformquota.Label}
|
||||
default:
|
||||
return nil, &NotSingularError{userplatformquota.Label}
|
||||
}
|
||||
}
|
||||
|
||||
// OnlyX is like Only, but panics if an error occurs.
|
||||
func (_q *UserPlatformQuotaQuery) OnlyX(ctx context.Context) *UserPlatformQuota {
|
||||
node, err := _q.Only(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return node
|
||||
}
|
||||
|
||||
// OnlyID is like Only, but returns the only UserPlatformQuota ID in the query.
|
||||
// Returns a *NotSingularError when more than one UserPlatformQuota ID is found.
|
||||
// Returns a *NotFoundError when no entities are found.
|
||||
func (_q *UserPlatformQuotaQuery) OnlyID(ctx context.Context) (id int64, err error) {
|
||||
var ids []int64
|
||||
if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
|
||||
return
|
||||
}
|
||||
switch len(ids) {
|
||||
case 1:
|
||||
id = ids[0]
|
||||
case 0:
|
||||
err = &NotFoundError{userplatformquota.Label}
|
||||
default:
|
||||
err = &NotSingularError{userplatformquota.Label}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// OnlyIDX is like OnlyID, but panics if an error occurs.
|
||||
func (_q *UserPlatformQuotaQuery) OnlyIDX(ctx context.Context) int64 {
|
||||
id, err := _q.OnlyID(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
// All executes the query and returns a list of UserPlatformQuotaSlice.
|
||||
func (_q *UserPlatformQuotaQuery) All(ctx context.Context) ([]*UserPlatformQuota, error) {
|
||||
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
|
||||
if err := _q.prepareQuery(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
qr := querierAll[[]*UserPlatformQuota, *UserPlatformQuotaQuery]()
|
||||
return withInterceptors[[]*UserPlatformQuota](ctx, _q, qr, _q.inters)
|
||||
}
|
||||
|
||||
// AllX is like All, but panics if an error occurs.
|
||||
func (_q *UserPlatformQuotaQuery) AllX(ctx context.Context) []*UserPlatformQuota {
|
||||
nodes, err := _q.All(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return nodes
|
||||
}
|
||||
|
||||
// IDs executes the query and returns a list of UserPlatformQuota IDs.
|
||||
func (_q *UserPlatformQuotaQuery) IDs(ctx context.Context) (ids []int64, err error) {
|
||||
if _q.ctx.Unique == nil && _q.path != nil {
|
||||
_q.Unique(true)
|
||||
}
|
||||
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
|
||||
if err = _q.Select(userplatformquota.FieldID).Scan(ctx, &ids); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
// IDsX is like IDs, but panics if an error occurs.
|
||||
func (_q *UserPlatformQuotaQuery) IDsX(ctx context.Context) []int64 {
|
||||
ids, err := _q.IDs(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// Count returns the count of the given query.
|
||||
func (_q *UserPlatformQuotaQuery) Count(ctx context.Context) (int, error) {
|
||||
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
|
||||
if err := _q.prepareQuery(ctx); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return withInterceptors[int](ctx, _q, querierCount[*UserPlatformQuotaQuery](), _q.inters)
|
||||
}
|
||||
|
||||
// CountX is like Count, but panics if an error occurs.
|
||||
func (_q *UserPlatformQuotaQuery) CountX(ctx context.Context) int {
|
||||
count, err := _q.Count(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// Exist returns true if the query has elements in the graph.
|
||||
func (_q *UserPlatformQuotaQuery) Exist(ctx context.Context) (bool, error) {
|
||||
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
|
||||
switch _, err := _q.FirstID(ctx); {
|
||||
case IsNotFound(err):
|
||||
return false, nil
|
||||
case err != nil:
|
||||
return false, fmt.Errorf("ent: check existence: %w", err)
|
||||
default:
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
// ExistX is like Exist, but panics if an error occurs.
|
||||
func (_q *UserPlatformQuotaQuery) ExistX(ctx context.Context) bool {
|
||||
exist, err := _q.Exist(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return exist
|
||||
}
|
||||
|
||||
// Clone returns a duplicate of the UserPlatformQuotaQuery builder, including all associated steps. It can be
|
||||
// used to prepare common query builders and use them differently after the clone is made.
|
||||
func (_q *UserPlatformQuotaQuery) Clone() *UserPlatformQuotaQuery {
|
||||
if _q == nil {
|
||||
return nil
|
||||
}
|
||||
return &UserPlatformQuotaQuery{
|
||||
config: _q.config,
|
||||
ctx: _q.ctx.Clone(),
|
||||
order: append([]userplatformquota.OrderOption{}, _q.order...),
|
||||
inters: append([]Interceptor{}, _q.inters...),
|
||||
predicates: append([]predicate.UserPlatformQuota{}, _q.predicates...),
|
||||
withUser: _q.withUser.Clone(),
|
||||
// clone intermediate query.
|
||||
sql: _q.sql.Clone(),
|
||||
path: _q.path,
|
||||
}
|
||||
}
|
||||
|
||||
// WithUser tells the query-builder to eager-load the nodes that are connected to
|
||||
// the "user" edge. The optional arguments are used to configure the query builder of the edge.
|
||||
func (_q *UserPlatformQuotaQuery) WithUser(opts ...func(*UserQuery)) *UserPlatformQuotaQuery {
|
||||
query := (&UserClient{config: _q.config}).Query()
|
||||
for _, opt := range opts {
|
||||
opt(query)
|
||||
}
|
||||
_q.withUser = query
|
||||
return _q
|
||||
}
|
||||
|
||||
// GroupBy is used to group vertices by one or more fields/columns.
|
||||
// It is often used with aggregate functions, like: count, max, mean, min, sum.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// var v []struct {
|
||||
// CreatedAt time.Time `json:"created_at,omitempty"`
|
||||
// Count int `json:"count,omitempty"`
|
||||
// }
|
||||
//
|
||||
// client.UserPlatformQuota.Query().
|
||||
// GroupBy(userplatformquota.FieldCreatedAt).
|
||||
// Aggregate(ent.Count()).
|
||||
// Scan(ctx, &v)
|
||||
func (_q *UserPlatformQuotaQuery) GroupBy(field string, fields ...string) *UserPlatformQuotaGroupBy {
|
||||
_q.ctx.Fields = append([]string{field}, fields...)
|
||||
grbuild := &UserPlatformQuotaGroupBy{build: _q}
|
||||
grbuild.flds = &_q.ctx.Fields
|
||||
grbuild.label = userplatformquota.Label
|
||||
grbuild.scan = grbuild.Scan
|
||||
return grbuild
|
||||
}
|
||||
|
||||
// Select allows the selection one or more fields/columns for the given query,
|
||||
// instead of selecting all fields in the entity.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// var v []struct {
|
||||
// CreatedAt time.Time `json:"created_at,omitempty"`
|
||||
// }
|
||||
//
|
||||
// client.UserPlatformQuota.Query().
|
||||
// Select(userplatformquota.FieldCreatedAt).
|
||||
// Scan(ctx, &v)
|
||||
func (_q *UserPlatformQuotaQuery) Select(fields ...string) *UserPlatformQuotaSelect {
|
||||
_q.ctx.Fields = append(_q.ctx.Fields, fields...)
|
||||
sbuild := &UserPlatformQuotaSelect{UserPlatformQuotaQuery: _q}
|
||||
sbuild.label = userplatformquota.Label
|
||||
sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
|
||||
return sbuild
|
||||
}
|
||||
|
||||
// Aggregate returns a UserPlatformQuotaSelect configured with the given aggregations.
|
||||
func (_q *UserPlatformQuotaQuery) Aggregate(fns ...AggregateFunc) *UserPlatformQuotaSelect {
|
||||
return _q.Select().Aggregate(fns...)
|
||||
}
|
||||
|
||||
func (_q *UserPlatformQuotaQuery) prepareQuery(ctx context.Context) error {
|
||||
for _, inter := range _q.inters {
|
||||
if inter == nil {
|
||||
return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
|
||||
}
|
||||
if trv, ok := inter.(Traverser); ok {
|
||||
if err := trv.Traverse(ctx, _q); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, f := range _q.ctx.Fields {
|
||||
if !userplatformquota.ValidColumn(f) {
|
||||
return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
|
||||
}
|
||||
}
|
||||
if _q.path != nil {
|
||||
prev, err := _q.path(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_q.sql = prev
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (_q *UserPlatformQuotaQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*UserPlatformQuota, error) {
|
||||
var (
|
||||
nodes = []*UserPlatformQuota{}
|
||||
_spec = _q.querySpec()
|
||||
loadedTypes = [1]bool{
|
||||
_q.withUser != nil,
|
||||
}
|
||||
)
|
||||
_spec.ScanValues = func(columns []string) ([]any, error) {
|
||||
return (*UserPlatformQuota).scanValues(nil, columns)
|
||||
}
|
||||
_spec.Assign = func(columns []string, values []any) error {
|
||||
node := &UserPlatformQuota{config: _q.config}
|
||||
nodes = append(nodes, node)
|
||||
node.Edges.loadedTypes = loadedTypes
|
||||
return node.assignValues(columns, values)
|
||||
}
|
||||
if len(_q.modifiers) > 0 {
|
||||
_spec.Modifiers = _q.modifiers
|
||||
}
|
||||
for i := range hooks {
|
||||
hooks[i](ctx, _spec)
|
||||
}
|
||||
if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(nodes) == 0 {
|
||||
return nodes, nil
|
||||
}
|
||||
if query := _q.withUser; query != nil {
|
||||
if err := _q.loadUser(ctx, query, nodes, nil,
|
||||
func(n *UserPlatformQuota, e *User) { n.Edges.User = e }); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return nodes, nil
|
||||
}
|
||||
|
||||
func (_q *UserPlatformQuotaQuery) loadUser(ctx context.Context, query *UserQuery, nodes []*UserPlatformQuota, init func(*UserPlatformQuota), assign func(*UserPlatformQuota, *User)) error {
|
||||
ids := make([]int64, 0, len(nodes))
|
||||
nodeids := make(map[int64][]*UserPlatformQuota)
|
||||
for i := range nodes {
|
||||
fk := nodes[i].UserID
|
||||
if _, ok := nodeids[fk]; !ok {
|
||||
ids = append(ids, fk)
|
||||
}
|
||||
nodeids[fk] = append(nodeids[fk], nodes[i])
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
query.Where(user.IDIn(ids...))
|
||||
neighbors, err := query.All(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, n := range neighbors {
|
||||
nodes, ok := nodeids[n.ID]
|
||||
if !ok {
|
||||
return fmt.Errorf(`unexpected foreign-key "user_id" returned %v`, n.ID)
|
||||
}
|
||||
for i := range nodes {
|
||||
assign(nodes[i], n)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (_q *UserPlatformQuotaQuery) sqlCount(ctx context.Context) (int, error) {
|
||||
_spec := _q.querySpec()
|
||||
if len(_q.modifiers) > 0 {
|
||||
_spec.Modifiers = _q.modifiers
|
||||
}
|
||||
_spec.Node.Columns = _q.ctx.Fields
|
||||
if len(_q.ctx.Fields) > 0 {
|
||||
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
|
||||
}
|
||||
return sqlgraph.CountNodes(ctx, _q.driver, _spec)
|
||||
}
|
||||
|
||||
func (_q *UserPlatformQuotaQuery) querySpec() *sqlgraph.QuerySpec {
|
||||
_spec := sqlgraph.NewQuerySpec(userplatformquota.Table, userplatformquota.Columns, sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64))
|
||||
_spec.From = _q.sql
|
||||
if unique := _q.ctx.Unique; unique != nil {
|
||||
_spec.Unique = *unique
|
||||
} else if _q.path != nil {
|
||||
_spec.Unique = true
|
||||
}
|
||||
if fields := _q.ctx.Fields; len(fields) > 0 {
|
||||
_spec.Node.Columns = make([]string, 0, len(fields))
|
||||
_spec.Node.Columns = append(_spec.Node.Columns, userplatformquota.FieldID)
|
||||
for i := range fields {
|
||||
if fields[i] != userplatformquota.FieldID {
|
||||
_spec.Node.Columns = append(_spec.Node.Columns, fields[i])
|
||||
}
|
||||
}
|
||||
if _q.withUser != nil {
|
||||
_spec.Node.AddColumnOnce(userplatformquota.FieldUserID)
|
||||
}
|
||||
}
|
||||
if ps := _q.predicates; len(ps) > 0 {
|
||||
_spec.Predicate = func(selector *sql.Selector) {
|
||||
for i := range ps {
|
||||
ps[i](selector)
|
||||
}
|
||||
}
|
||||
}
|
||||
if limit := _q.ctx.Limit; limit != nil {
|
||||
_spec.Limit = *limit
|
||||
}
|
||||
if offset := _q.ctx.Offset; offset != nil {
|
||||
_spec.Offset = *offset
|
||||
}
|
||||
if ps := _q.order; len(ps) > 0 {
|
||||
_spec.Order = func(selector *sql.Selector) {
|
||||
for i := range ps {
|
||||
ps[i](selector)
|
||||
}
|
||||
}
|
||||
}
|
||||
return _spec
|
||||
}
|
||||
|
||||
func (_q *UserPlatformQuotaQuery) sqlQuery(ctx context.Context) *sql.Selector {
|
||||
builder := sql.Dialect(_q.driver.Dialect())
|
||||
t1 := builder.Table(userplatformquota.Table)
|
||||
columns := _q.ctx.Fields
|
||||
if len(columns) == 0 {
|
||||
columns = userplatformquota.Columns
|
||||
}
|
||||
selector := builder.Select(t1.Columns(columns...)...).From(t1)
|
||||
if _q.sql != nil {
|
||||
selector = _q.sql
|
||||
selector.Select(selector.Columns(columns...)...)
|
||||
}
|
||||
if _q.ctx.Unique != nil && *_q.ctx.Unique {
|
||||
selector.Distinct()
|
||||
}
|
||||
for _, m := range _q.modifiers {
|
||||
m(selector)
|
||||
}
|
||||
for _, p := range _q.predicates {
|
||||
p(selector)
|
||||
}
|
||||
for _, p := range _q.order {
|
||||
p(selector)
|
||||
}
|
||||
if offset := _q.ctx.Offset; offset != nil {
|
||||
// limit is mandatory for offset clause. We start
|
||||
// with default value, and override it below if needed.
|
||||
selector.Offset(*offset).Limit(math.MaxInt32)
|
||||
}
|
||||
if limit := _q.ctx.Limit; limit != nil {
|
||||
selector.Limit(*limit)
|
||||
}
|
||||
return selector
|
||||
}
|
||||
|
||||
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
|
||||
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
|
||||
// either committed or rolled-back.
|
||||
func (_q *UserPlatformQuotaQuery) ForUpdate(opts ...sql.LockOption) *UserPlatformQuotaQuery {
|
||||
if _q.driver.Dialect() == dialect.Postgres {
|
||||
_q.Unique(false)
|
||||
}
|
||||
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
|
||||
s.ForUpdate(opts...)
|
||||
})
|
||||
return _q
|
||||
}
|
||||
|
||||
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
|
||||
// on any rows that are read. Other sessions can read the rows, but cannot modify them
|
||||
// until your transaction commits.
|
||||
func (_q *UserPlatformQuotaQuery) ForShare(opts ...sql.LockOption) *UserPlatformQuotaQuery {
|
||||
if _q.driver.Dialect() == dialect.Postgres {
|
||||
_q.Unique(false)
|
||||
}
|
||||
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
|
||||
s.ForShare(opts...)
|
||||
})
|
||||
return _q
|
||||
}
|
||||
|
||||
// UserPlatformQuotaGroupBy is the group-by builder for UserPlatformQuota entities.
|
||||
type UserPlatformQuotaGroupBy struct {
|
||||
selector
|
||||
build *UserPlatformQuotaQuery
|
||||
}
|
||||
|
||||
// Aggregate adds the given aggregation functions to the group-by query.
|
||||
func (_g *UserPlatformQuotaGroupBy) Aggregate(fns ...AggregateFunc) *UserPlatformQuotaGroupBy {
|
||||
_g.fns = append(_g.fns, fns...)
|
||||
return _g
|
||||
}
|
||||
|
||||
// Scan applies the selector query and scans the result into the given value.
|
||||
func (_g *UserPlatformQuotaGroupBy) Scan(ctx context.Context, v any) error {
|
||||
ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
|
||||
if err := _g.build.prepareQuery(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
return scanWithInterceptors[*UserPlatformQuotaQuery, *UserPlatformQuotaGroupBy](ctx, _g.build, _g, _g.build.inters, v)
|
||||
}
|
||||
|
||||
func (_g *UserPlatformQuotaGroupBy) sqlScan(ctx context.Context, root *UserPlatformQuotaQuery, v any) error {
|
||||
selector := root.sqlQuery(ctx).Select()
|
||||
aggregation := make([]string, 0, len(_g.fns))
|
||||
for _, fn := range _g.fns {
|
||||
aggregation = append(aggregation, fn(selector))
|
||||
}
|
||||
if len(selector.SelectedColumns()) == 0 {
|
||||
columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
|
||||
for _, f := range *_g.flds {
|
||||
columns = append(columns, selector.C(f))
|
||||
}
|
||||
columns = append(columns, aggregation...)
|
||||
selector.Select(columns...)
|
||||
}
|
||||
selector.GroupBy(selector.Columns(*_g.flds...)...)
|
||||
if err := selector.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
rows := &sql.Rows{}
|
||||
query, args := selector.Query()
|
||||
if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
return sql.ScanSlice(rows, v)
|
||||
}
|
||||
|
||||
// UserPlatformQuotaSelect is the builder for selecting fields of UserPlatformQuota entities.
|
||||
type UserPlatformQuotaSelect struct {
|
||||
*UserPlatformQuotaQuery
|
||||
selector
|
||||
}
|
||||
|
||||
// Aggregate adds the given aggregation functions to the selector query.
|
||||
func (_s *UserPlatformQuotaSelect) Aggregate(fns ...AggregateFunc) *UserPlatformQuotaSelect {
|
||||
_s.fns = append(_s.fns, fns...)
|
||||
return _s
|
||||
}
|
||||
|
||||
// Scan applies the selector query and scans the result into the given value.
|
||||
func (_s *UserPlatformQuotaSelect) Scan(ctx context.Context, v any) error {
|
||||
ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
|
||||
if err := _s.prepareQuery(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
return scanWithInterceptors[*UserPlatformQuotaQuery, *UserPlatformQuotaSelect](ctx, _s.UserPlatformQuotaQuery, _s, _s.inters, v)
|
||||
}
|
||||
|
||||
func (_s *UserPlatformQuotaSelect) sqlScan(ctx context.Context, root *UserPlatformQuotaQuery, v any) error {
|
||||
selector := root.sqlQuery(ctx)
|
||||
aggregation := make([]string, 0, len(_s.fns))
|
||||
for _, fn := range _s.fns {
|
||||
aggregation = append(aggregation, fn(selector))
|
||||
}
|
||||
switch n := len(*_s.selector.flds); {
|
||||
case n == 0 && len(aggregation) > 0:
|
||||
selector.Select(aggregation...)
|
||||
case n != 0 && len(aggregation) > 0:
|
||||
selector.AppendSelect(aggregation...)
|
||||
}
|
||||
rows := &sql.Rows{}
|
||||
query, args := selector.Query()
|
||||
if err := _s.driver.Query(ctx, query, args, rows); err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
return sql.ScanSlice(rows, v)
|
||||
}
|
||||
985
backend/ent/userplatformquota_update.go
Normal file
985
backend/ent/userplatformquota_update.go
Normal file
@ -0,0 +1,985 @@
|
||||
// Code generated by ent, DO NOT EDIT.
|
||||
|
||||
package ent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"entgo.io/ent/dialect/sql/sqlgraph"
|
||||
"entgo.io/ent/schema/field"
|
||||
"github.com/Wei-Shaw/sub2api/ent/predicate"
|
||||
"github.com/Wei-Shaw/sub2api/ent/user"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userplatformquota"
|
||||
)
|
||||
|
||||
// UserPlatformQuotaUpdate is the builder for updating UserPlatformQuota entities.
|
||||
type UserPlatformQuotaUpdate struct {
|
||||
config
|
||||
hooks []Hook
|
||||
mutation *UserPlatformQuotaMutation
|
||||
}
|
||||
|
||||
// Where appends a list predicates to the UserPlatformQuotaUpdate builder.
|
||||
func (_u *UserPlatformQuotaUpdate) Where(ps ...predicate.UserPlatformQuota) *UserPlatformQuotaUpdate {
|
||||
_u.mutation.Where(ps...)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetUpdatedAt sets the "updated_at" field.
|
||||
func (_u *UserPlatformQuotaUpdate) SetUpdatedAt(v time.Time) *UserPlatformQuotaUpdate {
|
||||
_u.mutation.SetUpdatedAt(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetDeletedAt sets the "deleted_at" field.
|
||||
func (_u *UserPlatformQuotaUpdate) SetDeletedAt(v time.Time) *UserPlatformQuotaUpdate {
|
||||
_u.mutation.SetDeletedAt(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil.
|
||||
func (_u *UserPlatformQuotaUpdate) SetNillableDeletedAt(v *time.Time) *UserPlatformQuotaUpdate {
|
||||
if v != nil {
|
||||
_u.SetDeletedAt(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearDeletedAt clears the value of the "deleted_at" field.
|
||||
func (_u *UserPlatformQuotaUpdate) ClearDeletedAt() *UserPlatformQuotaUpdate {
|
||||
_u.mutation.ClearDeletedAt()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetUserID sets the "user_id" field.
|
||||
func (_u *UserPlatformQuotaUpdate) SetUserID(v int64) *UserPlatformQuotaUpdate {
|
||||
_u.mutation.SetUserID(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableUserID sets the "user_id" field if the given value is not nil.
|
||||
func (_u *UserPlatformQuotaUpdate) SetNillableUserID(v *int64) *UserPlatformQuotaUpdate {
|
||||
if v != nil {
|
||||
_u.SetUserID(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetPlatform sets the "platform" field.
|
||||
func (_u *UserPlatformQuotaUpdate) SetPlatform(v string) *UserPlatformQuotaUpdate {
|
||||
_u.mutation.SetPlatform(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillablePlatform sets the "platform" field if the given value is not nil.
|
||||
func (_u *UserPlatformQuotaUpdate) SetNillablePlatform(v *string) *UserPlatformQuotaUpdate {
|
||||
if v != nil {
|
||||
_u.SetPlatform(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetDailyLimitUsd sets the "daily_limit_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdate) SetDailyLimitUsd(v float64) *UserPlatformQuotaUpdate {
|
||||
_u.mutation.ResetDailyLimitUsd()
|
||||
_u.mutation.SetDailyLimitUsd(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableDailyLimitUsd sets the "daily_limit_usd" field if the given value is not nil.
|
||||
func (_u *UserPlatformQuotaUpdate) SetNillableDailyLimitUsd(v *float64) *UserPlatformQuotaUpdate {
|
||||
if v != nil {
|
||||
_u.SetDailyLimitUsd(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddDailyLimitUsd adds value to the "daily_limit_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdate) AddDailyLimitUsd(v float64) *UserPlatformQuotaUpdate {
|
||||
_u.mutation.AddDailyLimitUsd(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearDailyLimitUsd clears the value of the "daily_limit_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdate) ClearDailyLimitUsd() *UserPlatformQuotaUpdate {
|
||||
_u.mutation.ClearDailyLimitUsd()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetWeeklyLimitUsd sets the "weekly_limit_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdate) SetWeeklyLimitUsd(v float64) *UserPlatformQuotaUpdate {
|
||||
_u.mutation.ResetWeeklyLimitUsd()
|
||||
_u.mutation.SetWeeklyLimitUsd(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableWeeklyLimitUsd sets the "weekly_limit_usd" field if the given value is not nil.
|
||||
func (_u *UserPlatformQuotaUpdate) SetNillableWeeklyLimitUsd(v *float64) *UserPlatformQuotaUpdate {
|
||||
if v != nil {
|
||||
_u.SetWeeklyLimitUsd(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddWeeklyLimitUsd adds value to the "weekly_limit_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdate) AddWeeklyLimitUsd(v float64) *UserPlatformQuotaUpdate {
|
||||
_u.mutation.AddWeeklyLimitUsd(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearWeeklyLimitUsd clears the value of the "weekly_limit_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdate) ClearWeeklyLimitUsd() *UserPlatformQuotaUpdate {
|
||||
_u.mutation.ClearWeeklyLimitUsd()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetMonthlyLimitUsd sets the "monthly_limit_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdate) SetMonthlyLimitUsd(v float64) *UserPlatformQuotaUpdate {
|
||||
_u.mutation.ResetMonthlyLimitUsd()
|
||||
_u.mutation.SetMonthlyLimitUsd(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableMonthlyLimitUsd sets the "monthly_limit_usd" field if the given value is not nil.
|
||||
func (_u *UserPlatformQuotaUpdate) SetNillableMonthlyLimitUsd(v *float64) *UserPlatformQuotaUpdate {
|
||||
if v != nil {
|
||||
_u.SetMonthlyLimitUsd(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddMonthlyLimitUsd adds value to the "monthly_limit_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdate) AddMonthlyLimitUsd(v float64) *UserPlatformQuotaUpdate {
|
||||
_u.mutation.AddMonthlyLimitUsd(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearMonthlyLimitUsd clears the value of the "monthly_limit_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdate) ClearMonthlyLimitUsd() *UserPlatformQuotaUpdate {
|
||||
_u.mutation.ClearMonthlyLimitUsd()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetDailyUsageUsd sets the "daily_usage_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdate) SetDailyUsageUsd(v float64) *UserPlatformQuotaUpdate {
|
||||
_u.mutation.ResetDailyUsageUsd()
|
||||
_u.mutation.SetDailyUsageUsd(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableDailyUsageUsd sets the "daily_usage_usd" field if the given value is not nil.
|
||||
func (_u *UserPlatformQuotaUpdate) SetNillableDailyUsageUsd(v *float64) *UserPlatformQuotaUpdate {
|
||||
if v != nil {
|
||||
_u.SetDailyUsageUsd(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddDailyUsageUsd adds value to the "daily_usage_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdate) AddDailyUsageUsd(v float64) *UserPlatformQuotaUpdate {
|
||||
_u.mutation.AddDailyUsageUsd(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetWeeklyUsageUsd sets the "weekly_usage_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdate) SetWeeklyUsageUsd(v float64) *UserPlatformQuotaUpdate {
|
||||
_u.mutation.ResetWeeklyUsageUsd()
|
||||
_u.mutation.SetWeeklyUsageUsd(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableWeeklyUsageUsd sets the "weekly_usage_usd" field if the given value is not nil.
|
||||
func (_u *UserPlatformQuotaUpdate) SetNillableWeeklyUsageUsd(v *float64) *UserPlatformQuotaUpdate {
|
||||
if v != nil {
|
||||
_u.SetWeeklyUsageUsd(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddWeeklyUsageUsd adds value to the "weekly_usage_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdate) AddWeeklyUsageUsd(v float64) *UserPlatformQuotaUpdate {
|
||||
_u.mutation.AddWeeklyUsageUsd(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetMonthlyUsageUsd sets the "monthly_usage_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdate) SetMonthlyUsageUsd(v float64) *UserPlatformQuotaUpdate {
|
||||
_u.mutation.ResetMonthlyUsageUsd()
|
||||
_u.mutation.SetMonthlyUsageUsd(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableMonthlyUsageUsd sets the "monthly_usage_usd" field if the given value is not nil.
|
||||
func (_u *UserPlatformQuotaUpdate) SetNillableMonthlyUsageUsd(v *float64) *UserPlatformQuotaUpdate {
|
||||
if v != nil {
|
||||
_u.SetMonthlyUsageUsd(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddMonthlyUsageUsd adds value to the "monthly_usage_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdate) AddMonthlyUsageUsd(v float64) *UserPlatformQuotaUpdate {
|
||||
_u.mutation.AddMonthlyUsageUsd(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetDailyWindowStart sets the "daily_window_start" field.
|
||||
func (_u *UserPlatformQuotaUpdate) SetDailyWindowStart(v time.Time) *UserPlatformQuotaUpdate {
|
||||
_u.mutation.SetDailyWindowStart(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableDailyWindowStart sets the "daily_window_start" field if the given value is not nil.
|
||||
func (_u *UserPlatformQuotaUpdate) SetNillableDailyWindowStart(v *time.Time) *UserPlatformQuotaUpdate {
|
||||
if v != nil {
|
||||
_u.SetDailyWindowStart(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearDailyWindowStart clears the value of the "daily_window_start" field.
|
||||
func (_u *UserPlatformQuotaUpdate) ClearDailyWindowStart() *UserPlatformQuotaUpdate {
|
||||
_u.mutation.ClearDailyWindowStart()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetWeeklyWindowStart sets the "weekly_window_start" field.
|
||||
func (_u *UserPlatformQuotaUpdate) SetWeeklyWindowStart(v time.Time) *UserPlatformQuotaUpdate {
|
||||
_u.mutation.SetWeeklyWindowStart(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableWeeklyWindowStart sets the "weekly_window_start" field if the given value is not nil.
|
||||
func (_u *UserPlatformQuotaUpdate) SetNillableWeeklyWindowStart(v *time.Time) *UserPlatformQuotaUpdate {
|
||||
if v != nil {
|
||||
_u.SetWeeklyWindowStart(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearWeeklyWindowStart clears the value of the "weekly_window_start" field.
|
||||
func (_u *UserPlatformQuotaUpdate) ClearWeeklyWindowStart() *UserPlatformQuotaUpdate {
|
||||
_u.mutation.ClearWeeklyWindowStart()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetMonthlyWindowStart sets the "monthly_window_start" field.
|
||||
func (_u *UserPlatformQuotaUpdate) SetMonthlyWindowStart(v time.Time) *UserPlatformQuotaUpdate {
|
||||
_u.mutation.SetMonthlyWindowStart(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableMonthlyWindowStart sets the "monthly_window_start" field if the given value is not nil.
|
||||
func (_u *UserPlatformQuotaUpdate) SetNillableMonthlyWindowStart(v *time.Time) *UserPlatformQuotaUpdate {
|
||||
if v != nil {
|
||||
_u.SetMonthlyWindowStart(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearMonthlyWindowStart clears the value of the "monthly_window_start" field.
|
||||
func (_u *UserPlatformQuotaUpdate) ClearMonthlyWindowStart() *UserPlatformQuotaUpdate {
|
||||
_u.mutation.ClearMonthlyWindowStart()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetUser sets the "user" edge to the User entity.
|
||||
func (_u *UserPlatformQuotaUpdate) SetUser(v *User) *UserPlatformQuotaUpdate {
|
||||
return _u.SetUserID(v.ID)
|
||||
}
|
||||
|
||||
// Mutation returns the UserPlatformQuotaMutation object of the builder.
|
||||
func (_u *UserPlatformQuotaUpdate) Mutation() *UserPlatformQuotaMutation {
|
||||
return _u.mutation
|
||||
}
|
||||
|
||||
// ClearUser clears the "user" edge to the User entity.
|
||||
func (_u *UserPlatformQuotaUpdate) ClearUser() *UserPlatformQuotaUpdate {
|
||||
_u.mutation.ClearUser()
|
||||
return _u
|
||||
}
|
||||
|
||||
// Save executes the query and returns the number of nodes affected by the update operation.
|
||||
func (_u *UserPlatformQuotaUpdate) Save(ctx context.Context) (int, error) {
|
||||
if err := _u.defaults(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
|
||||
}
|
||||
|
||||
// SaveX is like Save, but panics if an error occurs.
|
||||
func (_u *UserPlatformQuotaUpdate) SaveX(ctx context.Context) int {
|
||||
affected, err := _u.Save(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return affected
|
||||
}
|
||||
|
||||
// Exec executes the query.
|
||||
func (_u *UserPlatformQuotaUpdate) Exec(ctx context.Context) error {
|
||||
_, err := _u.Save(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// ExecX is like Exec, but panics if an error occurs.
|
||||
func (_u *UserPlatformQuotaUpdate) ExecX(ctx context.Context) {
|
||||
if err := _u.Exec(ctx); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// defaults sets the default values of the builder before save.
|
||||
func (_u *UserPlatformQuotaUpdate) defaults() error {
|
||||
if _, ok := _u.mutation.UpdatedAt(); !ok {
|
||||
if userplatformquota.UpdateDefaultUpdatedAt == nil {
|
||||
return fmt.Errorf("ent: uninitialized userplatformquota.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)")
|
||||
}
|
||||
v := userplatformquota.UpdateDefaultUpdatedAt()
|
||||
_u.mutation.SetUpdatedAt(v)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// check runs all checks and user-defined validators on the builder.
|
||||
func (_u *UserPlatformQuotaUpdate) check() error {
|
||||
if v, ok := _u.mutation.Platform(); ok {
|
||||
if err := userplatformquota.PlatformValidator(v); err != nil {
|
||||
return &ValidationError{Name: "platform", err: fmt.Errorf(`ent: validator failed for field "UserPlatformQuota.platform": %w`, err)}
|
||||
}
|
||||
}
|
||||
if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
|
||||
return errors.New(`ent: clearing a required unique edge "UserPlatformQuota.user"`)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (_u *UserPlatformQuotaUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
||||
if err := _u.check(); err != nil {
|
||||
return _node, err
|
||||
}
|
||||
_spec := sqlgraph.NewUpdateSpec(userplatformquota.Table, userplatformquota.Columns, sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64))
|
||||
if ps := _u.mutation.predicates; len(ps) > 0 {
|
||||
_spec.Predicate = func(selector *sql.Selector) {
|
||||
for i := range ps {
|
||||
ps[i](selector)
|
||||
}
|
||||
}
|
||||
}
|
||||
if value, ok := _u.mutation.UpdatedAt(); ok {
|
||||
_spec.SetField(userplatformquota.FieldUpdatedAt, field.TypeTime, value)
|
||||
}
|
||||
if value, ok := _u.mutation.DeletedAt(); ok {
|
||||
_spec.SetField(userplatformquota.FieldDeletedAt, field.TypeTime, value)
|
||||
}
|
||||
if _u.mutation.DeletedAtCleared() {
|
||||
_spec.ClearField(userplatformquota.FieldDeletedAt, field.TypeTime)
|
||||
}
|
||||
if value, ok := _u.mutation.Platform(); ok {
|
||||
_spec.SetField(userplatformquota.FieldPlatform, field.TypeString, value)
|
||||
}
|
||||
if value, ok := _u.mutation.DailyLimitUsd(); ok {
|
||||
_spec.SetField(userplatformquota.FieldDailyLimitUsd, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedDailyLimitUsd(); ok {
|
||||
_spec.AddField(userplatformquota.FieldDailyLimitUsd, field.TypeFloat64, value)
|
||||
}
|
||||
if _u.mutation.DailyLimitUsdCleared() {
|
||||
_spec.ClearField(userplatformquota.FieldDailyLimitUsd, field.TypeFloat64)
|
||||
}
|
||||
if value, ok := _u.mutation.WeeklyLimitUsd(); ok {
|
||||
_spec.SetField(userplatformquota.FieldWeeklyLimitUsd, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedWeeklyLimitUsd(); ok {
|
||||
_spec.AddField(userplatformquota.FieldWeeklyLimitUsd, field.TypeFloat64, value)
|
||||
}
|
||||
if _u.mutation.WeeklyLimitUsdCleared() {
|
||||
_spec.ClearField(userplatformquota.FieldWeeklyLimitUsd, field.TypeFloat64)
|
||||
}
|
||||
if value, ok := _u.mutation.MonthlyLimitUsd(); ok {
|
||||
_spec.SetField(userplatformquota.FieldMonthlyLimitUsd, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedMonthlyLimitUsd(); ok {
|
||||
_spec.AddField(userplatformquota.FieldMonthlyLimitUsd, field.TypeFloat64, value)
|
||||
}
|
||||
if _u.mutation.MonthlyLimitUsdCleared() {
|
||||
_spec.ClearField(userplatformquota.FieldMonthlyLimitUsd, field.TypeFloat64)
|
||||
}
|
||||
if value, ok := _u.mutation.DailyUsageUsd(); ok {
|
||||
_spec.SetField(userplatformquota.FieldDailyUsageUsd, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedDailyUsageUsd(); ok {
|
||||
_spec.AddField(userplatformquota.FieldDailyUsageUsd, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.WeeklyUsageUsd(); ok {
|
||||
_spec.SetField(userplatformquota.FieldWeeklyUsageUsd, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedWeeklyUsageUsd(); ok {
|
||||
_spec.AddField(userplatformquota.FieldWeeklyUsageUsd, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.MonthlyUsageUsd(); ok {
|
||||
_spec.SetField(userplatformquota.FieldMonthlyUsageUsd, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedMonthlyUsageUsd(); ok {
|
||||
_spec.AddField(userplatformquota.FieldMonthlyUsageUsd, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.DailyWindowStart(); ok {
|
||||
_spec.SetField(userplatformquota.FieldDailyWindowStart, field.TypeTime, value)
|
||||
}
|
||||
if _u.mutation.DailyWindowStartCleared() {
|
||||
_spec.ClearField(userplatformquota.FieldDailyWindowStart, field.TypeTime)
|
||||
}
|
||||
if value, ok := _u.mutation.WeeklyWindowStart(); ok {
|
||||
_spec.SetField(userplatformquota.FieldWeeklyWindowStart, field.TypeTime, value)
|
||||
}
|
||||
if _u.mutation.WeeklyWindowStartCleared() {
|
||||
_spec.ClearField(userplatformquota.FieldWeeklyWindowStart, field.TypeTime)
|
||||
}
|
||||
if value, ok := _u.mutation.MonthlyWindowStart(); ok {
|
||||
_spec.SetField(userplatformquota.FieldMonthlyWindowStart, field.TypeTime, value)
|
||||
}
|
||||
if _u.mutation.MonthlyWindowStartCleared() {
|
||||
_spec.ClearField(userplatformquota.FieldMonthlyWindowStart, field.TypeTime)
|
||||
}
|
||||
if _u.mutation.UserCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.M2O,
|
||||
Inverse: true,
|
||||
Table: userplatformquota.UserTable,
|
||||
Columns: []string{userplatformquota.UserColumn},
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
|
||||
},
|
||||
}
|
||||
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
|
||||
}
|
||||
if nodes := _u.mutation.UserIDs(); len(nodes) > 0 {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.M2O,
|
||||
Inverse: true,
|
||||
Table: userplatformquota.UserTable,
|
||||
Columns: []string{userplatformquota.UserColumn},
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
|
||||
},
|
||||
}
|
||||
for _, k := range nodes {
|
||||
edge.Target.Nodes = append(edge.Target.Nodes, k)
|
||||
}
|
||||
_spec.Edges.Add = append(_spec.Edges.Add, edge)
|
||||
}
|
||||
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
|
||||
if _, ok := err.(*sqlgraph.NotFoundError); ok {
|
||||
err = &NotFoundError{userplatformquota.Label}
|
||||
} else if sqlgraph.IsConstraintError(err) {
|
||||
err = &ConstraintError{msg: err.Error(), wrap: err}
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
_u.mutation.done = true
|
||||
return _node, nil
|
||||
}
|
||||
|
||||
// UserPlatformQuotaUpdateOne is the builder for updating a single UserPlatformQuota entity.
|
||||
type UserPlatformQuotaUpdateOne struct {
|
||||
config
|
||||
fields []string
|
||||
hooks []Hook
|
||||
mutation *UserPlatformQuotaMutation
|
||||
}
|
||||
|
||||
// SetUpdatedAt sets the "updated_at" field.
|
||||
func (_u *UserPlatformQuotaUpdateOne) SetUpdatedAt(v time.Time) *UserPlatformQuotaUpdateOne {
|
||||
_u.mutation.SetUpdatedAt(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetDeletedAt sets the "deleted_at" field.
|
||||
func (_u *UserPlatformQuotaUpdateOne) SetDeletedAt(v time.Time) *UserPlatformQuotaUpdateOne {
|
||||
_u.mutation.SetDeletedAt(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil.
|
||||
func (_u *UserPlatformQuotaUpdateOne) SetNillableDeletedAt(v *time.Time) *UserPlatformQuotaUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetDeletedAt(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearDeletedAt clears the value of the "deleted_at" field.
|
||||
func (_u *UserPlatformQuotaUpdateOne) ClearDeletedAt() *UserPlatformQuotaUpdateOne {
|
||||
_u.mutation.ClearDeletedAt()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetUserID sets the "user_id" field.
|
||||
func (_u *UserPlatformQuotaUpdateOne) SetUserID(v int64) *UserPlatformQuotaUpdateOne {
|
||||
_u.mutation.SetUserID(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableUserID sets the "user_id" field if the given value is not nil.
|
||||
func (_u *UserPlatformQuotaUpdateOne) SetNillableUserID(v *int64) *UserPlatformQuotaUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetUserID(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetPlatform sets the "platform" field.
|
||||
func (_u *UserPlatformQuotaUpdateOne) SetPlatform(v string) *UserPlatformQuotaUpdateOne {
|
||||
_u.mutation.SetPlatform(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillablePlatform sets the "platform" field if the given value is not nil.
|
||||
func (_u *UserPlatformQuotaUpdateOne) SetNillablePlatform(v *string) *UserPlatformQuotaUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetPlatform(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetDailyLimitUsd sets the "daily_limit_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdateOne) SetDailyLimitUsd(v float64) *UserPlatformQuotaUpdateOne {
|
||||
_u.mutation.ResetDailyLimitUsd()
|
||||
_u.mutation.SetDailyLimitUsd(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableDailyLimitUsd sets the "daily_limit_usd" field if the given value is not nil.
|
||||
func (_u *UserPlatformQuotaUpdateOne) SetNillableDailyLimitUsd(v *float64) *UserPlatformQuotaUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetDailyLimitUsd(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddDailyLimitUsd adds value to the "daily_limit_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdateOne) AddDailyLimitUsd(v float64) *UserPlatformQuotaUpdateOne {
|
||||
_u.mutation.AddDailyLimitUsd(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearDailyLimitUsd clears the value of the "daily_limit_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdateOne) ClearDailyLimitUsd() *UserPlatformQuotaUpdateOne {
|
||||
_u.mutation.ClearDailyLimitUsd()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetWeeklyLimitUsd sets the "weekly_limit_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdateOne) SetWeeklyLimitUsd(v float64) *UserPlatformQuotaUpdateOne {
|
||||
_u.mutation.ResetWeeklyLimitUsd()
|
||||
_u.mutation.SetWeeklyLimitUsd(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableWeeklyLimitUsd sets the "weekly_limit_usd" field if the given value is not nil.
|
||||
func (_u *UserPlatformQuotaUpdateOne) SetNillableWeeklyLimitUsd(v *float64) *UserPlatformQuotaUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetWeeklyLimitUsd(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddWeeklyLimitUsd adds value to the "weekly_limit_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdateOne) AddWeeklyLimitUsd(v float64) *UserPlatformQuotaUpdateOne {
|
||||
_u.mutation.AddWeeklyLimitUsd(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearWeeklyLimitUsd clears the value of the "weekly_limit_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdateOne) ClearWeeklyLimitUsd() *UserPlatformQuotaUpdateOne {
|
||||
_u.mutation.ClearWeeklyLimitUsd()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetMonthlyLimitUsd sets the "monthly_limit_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdateOne) SetMonthlyLimitUsd(v float64) *UserPlatformQuotaUpdateOne {
|
||||
_u.mutation.ResetMonthlyLimitUsd()
|
||||
_u.mutation.SetMonthlyLimitUsd(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableMonthlyLimitUsd sets the "monthly_limit_usd" field if the given value is not nil.
|
||||
func (_u *UserPlatformQuotaUpdateOne) SetNillableMonthlyLimitUsd(v *float64) *UserPlatformQuotaUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetMonthlyLimitUsd(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddMonthlyLimitUsd adds value to the "monthly_limit_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdateOne) AddMonthlyLimitUsd(v float64) *UserPlatformQuotaUpdateOne {
|
||||
_u.mutation.AddMonthlyLimitUsd(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearMonthlyLimitUsd clears the value of the "monthly_limit_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdateOne) ClearMonthlyLimitUsd() *UserPlatformQuotaUpdateOne {
|
||||
_u.mutation.ClearMonthlyLimitUsd()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetDailyUsageUsd sets the "daily_usage_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdateOne) SetDailyUsageUsd(v float64) *UserPlatformQuotaUpdateOne {
|
||||
_u.mutation.ResetDailyUsageUsd()
|
||||
_u.mutation.SetDailyUsageUsd(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableDailyUsageUsd sets the "daily_usage_usd" field if the given value is not nil.
|
||||
func (_u *UserPlatformQuotaUpdateOne) SetNillableDailyUsageUsd(v *float64) *UserPlatformQuotaUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetDailyUsageUsd(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddDailyUsageUsd adds value to the "daily_usage_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdateOne) AddDailyUsageUsd(v float64) *UserPlatformQuotaUpdateOne {
|
||||
_u.mutation.AddDailyUsageUsd(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetWeeklyUsageUsd sets the "weekly_usage_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdateOne) SetWeeklyUsageUsd(v float64) *UserPlatformQuotaUpdateOne {
|
||||
_u.mutation.ResetWeeklyUsageUsd()
|
||||
_u.mutation.SetWeeklyUsageUsd(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableWeeklyUsageUsd sets the "weekly_usage_usd" field if the given value is not nil.
|
||||
func (_u *UserPlatformQuotaUpdateOne) SetNillableWeeklyUsageUsd(v *float64) *UserPlatformQuotaUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetWeeklyUsageUsd(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddWeeklyUsageUsd adds value to the "weekly_usage_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdateOne) AddWeeklyUsageUsd(v float64) *UserPlatformQuotaUpdateOne {
|
||||
_u.mutation.AddWeeklyUsageUsd(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetMonthlyUsageUsd sets the "monthly_usage_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdateOne) SetMonthlyUsageUsd(v float64) *UserPlatformQuotaUpdateOne {
|
||||
_u.mutation.ResetMonthlyUsageUsd()
|
||||
_u.mutation.SetMonthlyUsageUsd(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableMonthlyUsageUsd sets the "monthly_usage_usd" field if the given value is not nil.
|
||||
func (_u *UserPlatformQuotaUpdateOne) SetNillableMonthlyUsageUsd(v *float64) *UserPlatformQuotaUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetMonthlyUsageUsd(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddMonthlyUsageUsd adds value to the "monthly_usage_usd" field.
|
||||
func (_u *UserPlatformQuotaUpdateOne) AddMonthlyUsageUsd(v float64) *UserPlatformQuotaUpdateOne {
|
||||
_u.mutation.AddMonthlyUsageUsd(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetDailyWindowStart sets the "daily_window_start" field.
|
||||
func (_u *UserPlatformQuotaUpdateOne) SetDailyWindowStart(v time.Time) *UserPlatformQuotaUpdateOne {
|
||||
_u.mutation.SetDailyWindowStart(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableDailyWindowStart sets the "daily_window_start" field if the given value is not nil.
|
||||
func (_u *UserPlatformQuotaUpdateOne) SetNillableDailyWindowStart(v *time.Time) *UserPlatformQuotaUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetDailyWindowStart(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearDailyWindowStart clears the value of the "daily_window_start" field.
|
||||
func (_u *UserPlatformQuotaUpdateOne) ClearDailyWindowStart() *UserPlatformQuotaUpdateOne {
|
||||
_u.mutation.ClearDailyWindowStart()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetWeeklyWindowStart sets the "weekly_window_start" field.
|
||||
func (_u *UserPlatformQuotaUpdateOne) SetWeeklyWindowStart(v time.Time) *UserPlatformQuotaUpdateOne {
|
||||
_u.mutation.SetWeeklyWindowStart(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableWeeklyWindowStart sets the "weekly_window_start" field if the given value is not nil.
|
||||
func (_u *UserPlatformQuotaUpdateOne) SetNillableWeeklyWindowStart(v *time.Time) *UserPlatformQuotaUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetWeeklyWindowStart(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearWeeklyWindowStart clears the value of the "weekly_window_start" field.
|
||||
func (_u *UserPlatformQuotaUpdateOne) ClearWeeklyWindowStart() *UserPlatformQuotaUpdateOne {
|
||||
_u.mutation.ClearWeeklyWindowStart()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetMonthlyWindowStart sets the "monthly_window_start" field.
|
||||
func (_u *UserPlatformQuotaUpdateOne) SetMonthlyWindowStart(v time.Time) *UserPlatformQuotaUpdateOne {
|
||||
_u.mutation.SetMonthlyWindowStart(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableMonthlyWindowStart sets the "monthly_window_start" field if the given value is not nil.
|
||||
func (_u *UserPlatformQuotaUpdateOne) SetNillableMonthlyWindowStart(v *time.Time) *UserPlatformQuotaUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetMonthlyWindowStart(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearMonthlyWindowStart clears the value of the "monthly_window_start" field.
|
||||
func (_u *UserPlatformQuotaUpdateOne) ClearMonthlyWindowStart() *UserPlatformQuotaUpdateOne {
|
||||
_u.mutation.ClearMonthlyWindowStart()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetUser sets the "user" edge to the User entity.
|
||||
func (_u *UserPlatformQuotaUpdateOne) SetUser(v *User) *UserPlatformQuotaUpdateOne {
|
||||
return _u.SetUserID(v.ID)
|
||||
}
|
||||
|
||||
// Mutation returns the UserPlatformQuotaMutation object of the builder.
|
||||
func (_u *UserPlatformQuotaUpdateOne) Mutation() *UserPlatformQuotaMutation {
|
||||
return _u.mutation
|
||||
}
|
||||
|
||||
// ClearUser clears the "user" edge to the User entity.
|
||||
func (_u *UserPlatformQuotaUpdateOne) ClearUser() *UserPlatformQuotaUpdateOne {
|
||||
_u.mutation.ClearUser()
|
||||
return _u
|
||||
}
|
||||
|
||||
// Where appends a list predicates to the UserPlatformQuotaUpdate builder.
|
||||
func (_u *UserPlatformQuotaUpdateOne) Where(ps ...predicate.UserPlatformQuota) *UserPlatformQuotaUpdateOne {
|
||||
_u.mutation.Where(ps...)
|
||||
return _u
|
||||
}
|
||||
|
||||
// Select allows selecting one or more fields (columns) of the returned entity.
|
||||
// The default is selecting all fields defined in the entity schema.
|
||||
func (_u *UserPlatformQuotaUpdateOne) Select(field string, fields ...string) *UserPlatformQuotaUpdateOne {
|
||||
_u.fields = append([]string{field}, fields...)
|
||||
return _u
|
||||
}
|
||||
|
||||
// Save executes the query and returns the updated UserPlatformQuota entity.
|
||||
func (_u *UserPlatformQuotaUpdateOne) Save(ctx context.Context) (*UserPlatformQuota, error) {
|
||||
if err := _u.defaults(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
|
||||
}
|
||||
|
||||
// SaveX is like Save, but panics if an error occurs.
|
||||
func (_u *UserPlatformQuotaUpdateOne) SaveX(ctx context.Context) *UserPlatformQuota {
|
||||
node, err := _u.Save(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return node
|
||||
}
|
||||
|
||||
// Exec executes the query on the entity.
|
||||
func (_u *UserPlatformQuotaUpdateOne) Exec(ctx context.Context) error {
|
||||
_, err := _u.Save(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// ExecX is like Exec, but panics if an error occurs.
|
||||
func (_u *UserPlatformQuotaUpdateOne) ExecX(ctx context.Context) {
|
||||
if err := _u.Exec(ctx); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// defaults sets the default values of the builder before save.
|
||||
func (_u *UserPlatformQuotaUpdateOne) defaults() error {
|
||||
if _, ok := _u.mutation.UpdatedAt(); !ok {
|
||||
if userplatformquota.UpdateDefaultUpdatedAt == nil {
|
||||
return fmt.Errorf("ent: uninitialized userplatformquota.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)")
|
||||
}
|
||||
v := userplatformquota.UpdateDefaultUpdatedAt()
|
||||
_u.mutation.SetUpdatedAt(v)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// check runs all checks and user-defined validators on the builder.
|
||||
func (_u *UserPlatformQuotaUpdateOne) check() error {
|
||||
if v, ok := _u.mutation.Platform(); ok {
|
||||
if err := userplatformquota.PlatformValidator(v); err != nil {
|
||||
return &ValidationError{Name: "platform", err: fmt.Errorf(`ent: validator failed for field "UserPlatformQuota.platform": %w`, err)}
|
||||
}
|
||||
}
|
||||
if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
|
||||
return errors.New(`ent: clearing a required unique edge "UserPlatformQuota.user"`)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (_u *UserPlatformQuotaUpdateOne) sqlSave(ctx context.Context) (_node *UserPlatformQuota, err error) {
|
||||
if err := _u.check(); err != nil {
|
||||
return _node, err
|
||||
}
|
||||
_spec := sqlgraph.NewUpdateSpec(userplatformquota.Table, userplatformquota.Columns, sqlgraph.NewFieldSpec(userplatformquota.FieldID, field.TypeInt64))
|
||||
id, ok := _u.mutation.ID()
|
||||
if !ok {
|
||||
return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "UserPlatformQuota.id" for update`)}
|
||||
}
|
||||
_spec.Node.ID.Value = id
|
||||
if fields := _u.fields; len(fields) > 0 {
|
||||
_spec.Node.Columns = make([]string, 0, len(fields))
|
||||
_spec.Node.Columns = append(_spec.Node.Columns, userplatformquota.FieldID)
|
||||
for _, f := range fields {
|
||||
if !userplatformquota.ValidColumn(f) {
|
||||
return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
|
||||
}
|
||||
if f != userplatformquota.FieldID {
|
||||
_spec.Node.Columns = append(_spec.Node.Columns, f)
|
||||
}
|
||||
}
|
||||
}
|
||||
if ps := _u.mutation.predicates; len(ps) > 0 {
|
||||
_spec.Predicate = func(selector *sql.Selector) {
|
||||
for i := range ps {
|
||||
ps[i](selector)
|
||||
}
|
||||
}
|
||||
}
|
||||
if value, ok := _u.mutation.UpdatedAt(); ok {
|
||||
_spec.SetField(userplatformquota.FieldUpdatedAt, field.TypeTime, value)
|
||||
}
|
||||
if value, ok := _u.mutation.DeletedAt(); ok {
|
||||
_spec.SetField(userplatformquota.FieldDeletedAt, field.TypeTime, value)
|
||||
}
|
||||
if _u.mutation.DeletedAtCleared() {
|
||||
_spec.ClearField(userplatformquota.FieldDeletedAt, field.TypeTime)
|
||||
}
|
||||
if value, ok := _u.mutation.Platform(); ok {
|
||||
_spec.SetField(userplatformquota.FieldPlatform, field.TypeString, value)
|
||||
}
|
||||
if value, ok := _u.mutation.DailyLimitUsd(); ok {
|
||||
_spec.SetField(userplatformquota.FieldDailyLimitUsd, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedDailyLimitUsd(); ok {
|
||||
_spec.AddField(userplatformquota.FieldDailyLimitUsd, field.TypeFloat64, value)
|
||||
}
|
||||
if _u.mutation.DailyLimitUsdCleared() {
|
||||
_spec.ClearField(userplatformquota.FieldDailyLimitUsd, field.TypeFloat64)
|
||||
}
|
||||
if value, ok := _u.mutation.WeeklyLimitUsd(); ok {
|
||||
_spec.SetField(userplatformquota.FieldWeeklyLimitUsd, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedWeeklyLimitUsd(); ok {
|
||||
_spec.AddField(userplatformquota.FieldWeeklyLimitUsd, field.TypeFloat64, value)
|
||||
}
|
||||
if _u.mutation.WeeklyLimitUsdCleared() {
|
||||
_spec.ClearField(userplatformquota.FieldWeeklyLimitUsd, field.TypeFloat64)
|
||||
}
|
||||
if value, ok := _u.mutation.MonthlyLimitUsd(); ok {
|
||||
_spec.SetField(userplatformquota.FieldMonthlyLimitUsd, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedMonthlyLimitUsd(); ok {
|
||||
_spec.AddField(userplatformquota.FieldMonthlyLimitUsd, field.TypeFloat64, value)
|
||||
}
|
||||
if _u.mutation.MonthlyLimitUsdCleared() {
|
||||
_spec.ClearField(userplatformquota.FieldMonthlyLimitUsd, field.TypeFloat64)
|
||||
}
|
||||
if value, ok := _u.mutation.DailyUsageUsd(); ok {
|
||||
_spec.SetField(userplatformquota.FieldDailyUsageUsd, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedDailyUsageUsd(); ok {
|
||||
_spec.AddField(userplatformquota.FieldDailyUsageUsd, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.WeeklyUsageUsd(); ok {
|
||||
_spec.SetField(userplatformquota.FieldWeeklyUsageUsd, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedWeeklyUsageUsd(); ok {
|
||||
_spec.AddField(userplatformquota.FieldWeeklyUsageUsd, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.MonthlyUsageUsd(); ok {
|
||||
_spec.SetField(userplatformquota.FieldMonthlyUsageUsd, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedMonthlyUsageUsd(); ok {
|
||||
_spec.AddField(userplatformquota.FieldMonthlyUsageUsd, field.TypeFloat64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.DailyWindowStart(); ok {
|
||||
_spec.SetField(userplatformquota.FieldDailyWindowStart, field.TypeTime, value)
|
||||
}
|
||||
if _u.mutation.DailyWindowStartCleared() {
|
||||
_spec.ClearField(userplatformquota.FieldDailyWindowStart, field.TypeTime)
|
||||
}
|
||||
if value, ok := _u.mutation.WeeklyWindowStart(); ok {
|
||||
_spec.SetField(userplatformquota.FieldWeeklyWindowStart, field.TypeTime, value)
|
||||
}
|
||||
if _u.mutation.WeeklyWindowStartCleared() {
|
||||
_spec.ClearField(userplatformquota.FieldWeeklyWindowStart, field.TypeTime)
|
||||
}
|
||||
if value, ok := _u.mutation.MonthlyWindowStart(); ok {
|
||||
_spec.SetField(userplatformquota.FieldMonthlyWindowStart, field.TypeTime, value)
|
||||
}
|
||||
if _u.mutation.MonthlyWindowStartCleared() {
|
||||
_spec.ClearField(userplatformquota.FieldMonthlyWindowStart, field.TypeTime)
|
||||
}
|
||||
if _u.mutation.UserCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.M2O,
|
||||
Inverse: true,
|
||||
Table: userplatformquota.UserTable,
|
||||
Columns: []string{userplatformquota.UserColumn},
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
|
||||
},
|
||||
}
|
||||
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
|
||||
}
|
||||
if nodes := _u.mutation.UserIDs(); len(nodes) > 0 {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.M2O,
|
||||
Inverse: true,
|
||||
Table: userplatformquota.UserTable,
|
||||
Columns: []string{userplatformquota.UserColumn},
|
||||
Bidi: false,
|
||||
Target: &sqlgraph.EdgeTarget{
|
||||
IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
|
||||
},
|
||||
}
|
||||
for _, k := range nodes {
|
||||
edge.Target.Nodes = append(edge.Target.Nodes, k)
|
||||
}
|
||||
_spec.Edges.Add = append(_spec.Edges.Add, edge)
|
||||
}
|
||||
_node = &UserPlatformQuota{config: _u.config}
|
||||
_spec.Assign = _node.assignValues
|
||||
_spec.ScanValues = _node.scanValues
|
||||
if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
|
||||
if _, ok := err.(*sqlgraph.NotFoundError); ok {
|
||||
err = &NotFoundError{userplatformquota.Label}
|
||||
} else if sqlgraph.IsConstraintError(err) {
|
||||
err = &ConstraintError{msg: err.Error(), wrap: err}
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
_u.mutation.done = true
|
||||
return _node, nil
|
||||
}
|
||||
@ -5,12 +5,14 @@ go 1.26.3
|
||||
require (
|
||||
entgo.io/ent v0.14.5
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2
|
||||
github.com/alicebob/miniredis/v2 v2.38.0
|
||||
github.com/alitto/pond/v2 v2.6.2
|
||||
github.com/andybalholm/brotli v1.2.0
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.3
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.10
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.10
|
||||
github.com/aws/aws-sdk-go-v2/service/s3 v1.96.2
|
||||
github.com/aws/smithy-go v1.24.2
|
||||
github.com/cespare/xxhash/v2 v2.3.0
|
||||
github.com/coder/websocket v1.8.14
|
||||
github.com/dgraph-io/ristretto v0.2.0
|
||||
@ -71,7 +73,6 @@ require (
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.11 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 // indirect
|
||||
github.com/aws/smithy-go v1.24.2 // indirect
|
||||
github.com/bmatcuk/doublestar v1.3.4 // indirect
|
||||
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect
|
||||
github.com/bytedance/sonic v1.9.1 // indirect
|
||||
@ -158,6 +159,7 @@ require (
|
||||
github.com/tklauser/numcpus v0.6.1 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.11 // indirect
|
||||
github.com/yuin/gopher-lua v1.1.1 // indirect
|
||||
github.com/yusufpapurcu/wmi v1.2.4 // indirect
|
||||
github.com/zclconf/go-cty v1.14.4 // indirect
|
||||
github.com/zclconf/go-cty-yaml v1.1.0 // indirect
|
||||
|
||||
@ -16,6 +16,8 @@ github.com/agext/levenshtein v1.2.3 h1:YB2fHEn0UJagG8T1rrWknE3ZQzWM06O8AMAatNn7l
|
||||
github.com/agext/levenshtein v1.2.3/go.mod h1:JEDfjyjHDjOF/1e4FlBE/PkbqA9OfWu2ki2W0IB5558=
|
||||
github.com/agiledragon/gomonkey v2.0.2+incompatible h1:eXKi9/piiC3cjJD1658mEE2o3NjkJ5vDLgYjCQu0Xlw=
|
||||
github.com/agiledragon/gomonkey v2.0.2+incompatible/go.mod h1:2NGfXu1a80LLr2cmWXGBDaHEjb1idR6+FVlX5T3D9hw=
|
||||
github.com/alicebob/miniredis/v2 v2.38.0 h1:nZAzCR+Lj+Vxk4ZXzm2NuKq2O33RXj1XxJ2e2uP9jiw=
|
||||
github.com/alicebob/miniredis/v2 v2.38.0/go.mod h1:TcL7YfarKPGDAthEtl5NBeHZfeUQj6OXMm/+iu5cLMM=
|
||||
github.com/alitto/pond/v2 v2.6.2 h1:Sphe40g0ILeM1pA2c2K+Th0DGU+pt0A/Kprr+WB24Pw=
|
||||
github.com/alitto/pond/v2 v2.6.2/go.mod h1:xkjYEgQ05RSpWdfSd1nM3OVv7TBhLdy7rMp3+2Nq+yE=
|
||||
github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ=
|
||||
@ -162,8 +164,6 @@ github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
||||
github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE=
|
||||
github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4=
|
||||
@ -370,6 +370,8 @@ github.com/wechatpay-apiv3/wechatpay-go v0.2.21 h1:uIyMpzvcaHA33W/QPtHstccw+X52H
|
||||
github.com/wechatpay-apiv3/wechatpay-go v0.2.21/go.mod h1:A254AUBVB6R+EqQFo3yTgeh7HtyqRRtN2w9hQSOrd4Q=
|
||||
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
|
||||
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
|
||||
github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M=
|
||||
github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw=
|
||||
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
|
||||
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
|
||||
github.com/zclconf/go-cty v1.14.4 h1:uXXczd9QDGsgu0i/QFR/hzI5NYCHLf6NQw/atrbnhq8=
|
||||
|
||||
@ -643,6 +643,12 @@ type ProxyProbeConfig struct {
|
||||
|
||||
type BillingConfig struct {
|
||||
CircuitBreaker CircuitBreakerConfig `mapstructure:"circuit_breaker"`
|
||||
// UserPlatformQuotaCacheTTLSeconds 用户 × 平台 quota 缓存 TTL(秒),默认 86400=1天,覆盖典型 daily 窗口。
|
||||
// 消费点:
|
||||
// - billing_cache_service.cacheWriteWorker 异步累加
|
||||
// - billing_cache_service.checkUserPlatformQuotaEligibility 首次缓存装载
|
||||
// 读写两端必须共用同一 TTL,避免缓存生命周期不一致导致 quota 计数漂移。
|
||||
UserPlatformQuotaCacheTTLSeconds int `mapstructure:"user_platform_quota_cache_ttl_seconds"`
|
||||
}
|
||||
|
||||
type CircuitBreakerConfig struct {
|
||||
@ -1544,6 +1550,7 @@ func setDefaults() {
|
||||
viper.SetDefault("billing.circuit_breaker.failure_threshold", 5)
|
||||
viper.SetDefault("billing.circuit_breaker.reset_timeout_seconds", 30)
|
||||
viper.SetDefault("billing.circuit_breaker.half_open_requests", 3)
|
||||
viper.SetDefault("billing.user_platform_quota_cache_ttl_seconds", 86400)
|
||||
|
||||
// Turnstile
|
||||
viper.SetDefault("turnstile.required", false)
|
||||
|
||||
@ -16,7 +16,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) {
|
||||
router := gin.New()
|
||||
adminSvc := newStubAdminService()
|
||||
|
||||
userHandler := NewUserHandler(adminSvc, nil)
|
||||
userHandler := NewUserHandler(adminSvc, nil, nil, nil)
|
||||
groupHandler := NewGroupHandler(adminSvc, nil, nil)
|
||||
proxyHandler := NewProxyHandler(adminSvc)
|
||||
redeemHandler := NewRedeemHandler(adminSvc, nil)
|
||||
|
||||
@ -24,6 +24,7 @@ type stubAdminService struct {
|
||||
updatedProxyIDs []int64
|
||||
updatedProxies []*service.UpdateProxyInput
|
||||
testedProxyIDs []int64
|
||||
getUserErr error
|
||||
createAccountErr error
|
||||
updateAccountErr error
|
||||
bulkUpdateAccountErr error
|
||||
@ -147,6 +148,9 @@ func (s *stubAdminService) ListUsers(ctx context.Context, page, pageSize int, fi
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetUser(ctx context.Context, id int64) (*service.User, error) {
|
||||
if s.getUserErr != nil {
|
||||
return nil, s.getUserErr
|
||||
}
|
||||
for i := range s.users {
|
||||
if s.users[i].ID == id {
|
||||
return &s.users[i], nil
|
||||
|
||||
@ -305,6 +305,13 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
payload.OpenAIFastPolicySettings = openaiFastPolicySettingsToDTO(fastPolicy)
|
||||
}
|
||||
|
||||
// Default platform quotas(JSON map)
|
||||
if platformQuotas, err := h.settingService.GetDefaultPlatformQuotas(c.Request.Context()); err != nil {
|
||||
slog.Error("default_platform_quotas_get_failed", "error", err)
|
||||
} else {
|
||||
payload.DefaultPlatformQuotas = platformQuotas
|
||||
}
|
||||
|
||||
response.Success(c, systemSettingsResponseData(payload, authSourceDefaults))
|
||||
}
|
||||
|
||||
@ -637,6 +644,18 @@ type UpdateSettingsRequest struct {
|
||||
|
||||
// OpenAI fast/flex policy (optional, only updated when provided)
|
||||
OpenAIFastPolicySettings *dto.OpenAIFastPolicySettings `json:"openai_fast_policy_settings,omitempty"`
|
||||
|
||||
// 系统全局 platform quota 默认值(整体替换语义:nil = 不修改,non-nil = 整体覆盖)。
|
||||
DefaultPlatformQuotas map[string]*service.DefaultPlatformQuotaSetting `json:"default_platform_quotas"`
|
||||
|
||||
// auth-source 层 platform quota 覆盖(override 语义:nil = 不修改,non-nil = 整体覆盖该 source 的 quota 配置)。
|
||||
AuthSourceEmailPlatformQuotas map[string]*service.DefaultPlatformQuotaSetting `json:"auth_source_default_email_platform_quotas"`
|
||||
AuthSourceLinuxDoPlatformQuotas map[string]*service.DefaultPlatformQuotaSetting `json:"auth_source_default_linuxdo_platform_quotas"`
|
||||
AuthSourceOIDCPlatformQuotas map[string]*service.DefaultPlatformQuotaSetting `json:"auth_source_default_oidc_platform_quotas"`
|
||||
AuthSourceWeChatPlatformQuotas map[string]*service.DefaultPlatformQuotaSetting `json:"auth_source_default_wechat_platform_quotas"`
|
||||
AuthSourceGitHubPlatformQuotas map[string]*service.DefaultPlatformQuotaSetting `json:"auth_source_default_github_platform_quotas"`
|
||||
AuthSourceGooglePlatformQuotas map[string]*service.DefaultPlatformQuotaSetting `json:"auth_source_default_google_platform_quotas"`
|
||||
AuthSourceDingTalkPlatformQuotas map[string]*service.DefaultPlatformQuotaSetting `json:"auth_source_default_dingtalk_platform_quotas"`
|
||||
}
|
||||
|
||||
// UpdateSettings 更新系统设置
|
||||
@ -1438,6 +1457,9 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
|
||||
settings := &service.SystemSettings{
|
||||
// 系统全局 platform quota 默认值(整体替换语义)
|
||||
DefaultPlatformQuotas: req.DefaultPlatformQuotas,
|
||||
|
||||
RegistrationEnabled: req.RegistrationEnabled,
|
||||
EmailVerifyEnabled: req.EmailVerifyEnabled,
|
||||
RegistrationEmailSuffixWhitelist: req.RegistrationEmailSuffixWhitelist,
|
||||
@ -1731,6 +1753,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}(),
|
||||
}
|
||||
|
||||
// req.AuthSourceXxxPlatformQuotas 为 nil 表示本次请求未包含该 source 的 quota 配置(保留 previousAuthSourceDefaults 中的值);
|
||||
// non-nil(含 empty map)表示整体覆盖:empty map = 清空该 source 的所有 quota 配置。
|
||||
authSourceDefaults := &service.AuthSourceDefaultSettings{
|
||||
Email: service.ProviderDefaultGrantSettings{
|
||||
Balance: float64ValueOrDefault(req.AuthSourceDefaultEmailBalance, previousAuthSourceDefaults.Email.Balance),
|
||||
@ -1738,6 +1762,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultEmailSubscriptions, previousAuthSourceDefaults.Email.Subscriptions),
|
||||
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultEmailGrantOnSignup, previousAuthSourceDefaults.Email.GrantOnSignup),
|
||||
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultEmailGrantOnFirstBind, previousAuthSourceDefaults.Email.GrantOnFirstBind),
|
||||
PlatformQuotas: platformQuotasValueOrDefault(req.AuthSourceEmailPlatformQuotas, previousAuthSourceDefaults.Email.PlatformQuotas),
|
||||
},
|
||||
LinuxDo: service.ProviderDefaultGrantSettings{
|
||||
Balance: float64ValueOrDefault(req.AuthSourceDefaultLinuxDoBalance, previousAuthSourceDefaults.LinuxDo.Balance),
|
||||
@ -1745,6 +1770,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultLinuxDoSubscriptions, previousAuthSourceDefaults.LinuxDo.Subscriptions),
|
||||
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultLinuxDoGrantOnSignup, previousAuthSourceDefaults.LinuxDo.GrantOnSignup),
|
||||
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultLinuxDoGrantOnFirstBind, previousAuthSourceDefaults.LinuxDo.GrantOnFirstBind),
|
||||
PlatformQuotas: platformQuotasValueOrDefault(req.AuthSourceLinuxDoPlatformQuotas, previousAuthSourceDefaults.LinuxDo.PlatformQuotas),
|
||||
},
|
||||
OIDC: service.ProviderDefaultGrantSettings{
|
||||
Balance: float64ValueOrDefault(req.AuthSourceDefaultOIDCBalance, previousAuthSourceDefaults.OIDC.Balance),
|
||||
@ -1752,6 +1778,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultOIDCSubscriptions, previousAuthSourceDefaults.OIDC.Subscriptions),
|
||||
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultOIDCGrantOnSignup, previousAuthSourceDefaults.OIDC.GrantOnSignup),
|
||||
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultOIDCGrantOnFirstBind, previousAuthSourceDefaults.OIDC.GrantOnFirstBind),
|
||||
PlatformQuotas: platformQuotasValueOrDefault(req.AuthSourceOIDCPlatformQuotas, previousAuthSourceDefaults.OIDC.PlatformQuotas),
|
||||
},
|
||||
WeChat: service.ProviderDefaultGrantSettings{
|
||||
Balance: float64ValueOrDefault(req.AuthSourceDefaultWeChatBalance, previousAuthSourceDefaults.WeChat.Balance),
|
||||
@ -1759,6 +1786,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultWeChatSubscriptions, previousAuthSourceDefaults.WeChat.Subscriptions),
|
||||
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultWeChatGrantOnSignup, previousAuthSourceDefaults.WeChat.GrantOnSignup),
|
||||
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultWeChatGrantOnFirstBind, previousAuthSourceDefaults.WeChat.GrantOnFirstBind),
|
||||
PlatformQuotas: platformQuotasValueOrDefault(req.AuthSourceWeChatPlatformQuotas, previousAuthSourceDefaults.WeChat.PlatformQuotas),
|
||||
},
|
||||
GitHub: service.ProviderDefaultGrantSettings{
|
||||
Balance: float64ValueOrDefault(req.AuthSourceDefaultGitHubBalance, previousAuthSourceDefaults.GitHub.Balance),
|
||||
@ -1766,6 +1794,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultGitHubSubscriptions, previousAuthSourceDefaults.GitHub.Subscriptions),
|
||||
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultGitHubGrantOnSignup, previousAuthSourceDefaults.GitHub.GrantOnSignup),
|
||||
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultGitHubGrantOnFirstBind, previousAuthSourceDefaults.GitHub.GrantOnFirstBind),
|
||||
PlatformQuotas: platformQuotasValueOrDefault(req.AuthSourceGitHubPlatformQuotas, previousAuthSourceDefaults.GitHub.PlatformQuotas),
|
||||
},
|
||||
Google: service.ProviderDefaultGrantSettings{
|
||||
Balance: float64ValueOrDefault(req.AuthSourceDefaultGoogleBalance, previousAuthSourceDefaults.Google.Balance),
|
||||
@ -1773,6 +1802,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultGoogleSubscriptions, previousAuthSourceDefaults.Google.Subscriptions),
|
||||
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultGoogleGrantOnSignup, previousAuthSourceDefaults.Google.GrantOnSignup),
|
||||
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultGoogleGrantOnFirstBind, previousAuthSourceDefaults.Google.GrantOnFirstBind),
|
||||
PlatformQuotas: platformQuotasValueOrDefault(req.AuthSourceGooglePlatformQuotas, previousAuthSourceDefaults.Google.PlatformQuotas),
|
||||
},
|
||||
DingTalk: service.ProviderDefaultGrantSettings{
|
||||
Balance: float64ValueOrDefault(req.AuthSourceDefaultDingTalkBalance, previousAuthSourceDefaults.DingTalk.Balance),
|
||||
@ -1780,6 +1810,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultDingTalkSubscriptions, previousAuthSourceDefaults.DingTalk.Subscriptions),
|
||||
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultDingTalkGrantOnSignup, previousAuthSourceDefaults.DingTalk.GrantOnSignup),
|
||||
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultDingTalkGrantOnFirstBind, previousAuthSourceDefaults.DingTalk.GrantOnFirstBind),
|
||||
PlatformQuotas: platformQuotasValueOrDefault(req.AuthSourceDingTalkPlatformQuotas, previousAuthSourceDefaults.DingTalk.PlatformQuotas),
|
||||
},
|
||||
ForceEmailOnThirdPartySignup: boolValueOrDefault(req.ForceEmailOnThirdPartySignup, previousAuthSourceDefaults.ForceEmailOnThirdPartySignup),
|
||||
}
|
||||
@ -2047,6 +2078,13 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
} else if fastPolicy != nil {
|
||||
payload.OpenAIFastPolicySettings = openaiFastPolicySettingsToDTO(fastPolicy)
|
||||
}
|
||||
|
||||
// Default platform quotas(JSON map)—— 与 GetSettings 一致,避免保存后响应缺失该字段
|
||||
if platformQuotas, err := h.settingService.GetDefaultPlatformQuotas(c.Request.Context()); err != nil {
|
||||
slog.Error("default_platform_quotas_get_failed", "error", err)
|
||||
} else {
|
||||
payload.DefaultPlatformQuotas = platformQuotas
|
||||
}
|
||||
response.Success(c, systemSettingsResponseData(payload, updatedAuthSourceDefaults))
|
||||
}
|
||||
|
||||
@ -2511,6 +2549,10 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.RiskControlEnabled != after.RiskControlEnabled {
|
||||
changed = append(changed, "risk_control_enabled")
|
||||
}
|
||||
// Default platform quotas(JSON map,整体比较)
|
||||
if !equalPlatformQuotaSettings(before.DefaultPlatformQuotas, after.DefaultPlatformQuotas) {
|
||||
changed = append(changed, service.SettingKeyDefaultPlatformQuotas)
|
||||
}
|
||||
changed = appendAuthSourceDefaultChanges(changed, beforeAuthSourceDefaults, afterAuthSourceDefaults)
|
||||
return changed
|
||||
}
|
||||
@ -2554,6 +2596,10 @@ func appendAuthSourceDefaultChanges(changed []string, before *service.AuthSource
|
||||
if field.before.GrantOnFirstBind != field.after.GrantOnFirstBind {
|
||||
changed = append(changed, "auth_source_default_"+field.name+"_grant_on_first_bind")
|
||||
}
|
||||
// Platform quotas diff:整体替换语义,发单个 JSON key。
|
||||
if !equalPlatformQuotaSettings(field.before.PlatformQuotas, field.after.PlatformQuotas) {
|
||||
changed = append(changed, service.SettingKeyAuthSourcePlatformQuotas(field.name))
|
||||
}
|
||||
}
|
||||
if before.ForceEmailOnThirdPartySignup != after.ForceEmailOnThirdPartySignup {
|
||||
changed = append(changed, "force_email_on_third_party_signup")
|
||||
@ -2621,6 +2667,17 @@ func defaultSubscriptionsValueOrDefault(input *[]dto.DefaultSubscriptionSetting,
|
||||
return result
|
||||
}
|
||||
|
||||
// platformQuotasValueOrDefault 处理 auth-source platform quota 的 nil 语义:
|
||||
// nil = 请求未包含该字段(保留 fallback),non-nil(含 empty map)= 整体覆盖。
|
||||
// 注意:JSON null 与字段省略等价——两者均反序列化为 nil map,因此都保留旧值;
|
||||
// 若要清空某 source 的所有 quota 配置,须显式发空对象 {}。
|
||||
func platformQuotasValueOrDefault(value, fallback map[string]*service.DefaultPlatformQuotaSetting) map[string]*service.DefaultPlatformQuotaSetting {
|
||||
if value == nil {
|
||||
return fallback
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func systemSettingsResponseData(settings dto.SystemSettings, authSourceDefaults *service.AuthSourceDefaultSettings) map[string]any {
|
||||
data := make(map[string]any)
|
||||
raw, err := json.Marshal(settings)
|
||||
@ -2666,6 +2723,13 @@ func systemSettingsResponseData(settings dto.SystemSettings, authSourceDefaults
|
||||
data["auth_source_default_google_subscriptions"] = authSourceDefaults.Google.Subscriptions
|
||||
data["auth_source_default_google_grant_on_signup"] = authSourceDefaults.Google.GrantOnSignup
|
||||
data["auth_source_default_google_grant_on_first_bind"] = authSourceDefaults.Google.GrantOnFirstBind
|
||||
data["auth_source_default_email_platform_quotas"] = authSourceDefaults.Email.PlatformQuotas
|
||||
data["auth_source_default_linuxdo_platform_quotas"] = authSourceDefaults.LinuxDo.PlatformQuotas
|
||||
data["auth_source_default_oidc_platform_quotas"] = authSourceDefaults.OIDC.PlatformQuotas
|
||||
data["auth_source_default_wechat_platform_quotas"] = authSourceDefaults.WeChat.PlatformQuotas
|
||||
data["auth_source_default_github_platform_quotas"] = authSourceDefaults.GitHub.PlatformQuotas
|
||||
data["auth_source_default_google_platform_quotas"] = authSourceDefaults.Google.PlatformQuotas
|
||||
data["auth_source_default_dingtalk_platform_quotas"] = authSourceDefaults.DingTalk.PlatformQuotas
|
||||
data["force_email_on_third_party_signup"] = authSourceDefaults.ForceEmailOnThirdPartySignup
|
||||
|
||||
return data
|
||||
@ -3552,3 +3616,48 @@ func emailTemplatePlaceholderUnion(events []service.NotificationEmailEventInfo)
|
||||
}
|
||||
return placeholders
|
||||
}
|
||||
|
||||
// equalNullableFloat compares two *float64 values treating nil as a distinct case.
|
||||
func equalNullableFloat(a, b *float64) bool {
|
||||
if a == nil && b == nil {
|
||||
return true
|
||||
}
|
||||
if a == nil || b == nil {
|
||||
return false
|
||||
}
|
||||
return *a == *b
|
||||
}
|
||||
|
||||
// slotOf returns the *float64 for the given window from a DefaultPlatformQuotaSetting.
|
||||
func slotOf(s *service.DefaultPlatformQuotaSetting, win string) *float64 {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
switch win {
|
||||
case "daily":
|
||||
return s.DailyLimitUSD
|
||||
case "weekly":
|
||||
return s.WeeklyLimitUSD
|
||||
case "monthly":
|
||||
return s.MonthlyLimitUSD
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// equalPlatformQuotaSettings reports whether two platform-quota maps are identical across all 12 slots.
|
||||
func equalPlatformQuotaSettings(before, after map[string]*service.DefaultPlatformQuotaSetting) bool {
|
||||
for _, platform := range service.AllowedQuotaPlatforms {
|
||||
b := before[platform]
|
||||
a := after[platform]
|
||||
if !equalNullableFloat(slotOf(b, "daily"), slotOf(a, "daily")) {
|
||||
return false
|
||||
}
|
||||
if !equalNullableFloat(slotOf(b, "weekly"), slotOf(a, "weekly")) {
|
||||
return false
|
||||
}
|
||||
if !equalNullableFloat(slotOf(b, "monthly"), slotOf(a, "monthly")) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
@ -0,0 +1,188 @@
|
||||
//go:build unit
|
||||
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDiffSettings_DetectsGlobalPlatformQuotaChange(t *testing.T) {
|
||||
five := 5.0
|
||||
ten := 10.0
|
||||
before := &service.SystemSettings{
|
||||
DefaultPlatformQuotas: map[string]*service.DefaultPlatformQuotaSetting{
|
||||
"anthropic": {DailyLimitUSD: &five},
|
||||
},
|
||||
}
|
||||
after := &service.SystemSettings{
|
||||
DefaultPlatformQuotas: map[string]*service.DefaultPlatformQuotaSetting{
|
||||
"anthropic": {DailyLimitUSD: &ten},
|
||||
},
|
||||
}
|
||||
|
||||
changed := diffSettings(before, after, nil, nil, UpdateSettingsRequest{})
|
||||
found := false
|
||||
for _, key := range changed {
|
||||
if key == service.SettingKeyDefaultPlatformQuotas {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("expected change detection for default platform quotas, got %v", changed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDiffSettings_NoChangeWhenEqual(t *testing.T) {
|
||||
five := 5.0
|
||||
before := &service.SystemSettings{
|
||||
DefaultPlatformQuotas: map[string]*service.DefaultPlatformQuotaSetting{
|
||||
"anthropic": {DailyLimitUSD: &five},
|
||||
},
|
||||
}
|
||||
after := &service.SystemSettings{
|
||||
DefaultPlatformQuotas: map[string]*service.DefaultPlatformQuotaSetting{
|
||||
"anthropic": {DailyLimitUSD: &five},
|
||||
},
|
||||
}
|
||||
|
||||
changed := diffSettings(before, after, nil, nil, UpdateSettingsRequest{})
|
||||
for _, key := range changed {
|
||||
if key == service.SettingKeyDefaultPlatformQuotas {
|
||||
t.Error("equal values should not be detected as changed")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEqualNullableFloat(t *testing.T) {
|
||||
five := 5.0
|
||||
five2 := 5.0
|
||||
ten := 10.0
|
||||
cases := []struct {
|
||||
a, b *float64
|
||||
want bool
|
||||
}{
|
||||
{nil, nil, true},
|
||||
{&five, nil, false},
|
||||
{nil, &five, false},
|
||||
{&five, &five2, true},
|
||||
{&five, &ten, false},
|
||||
}
|
||||
for _, c := range cases {
|
||||
if got := equalNullableFloat(c.a, c.b); got != c.want {
|
||||
t.Errorf("equalNullableFloat(%v, %v) = %v, want %v", c.a, c.b, got, c.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEqualPlatformQuotaSettings_DetectsPerWindowChange(t *testing.T) {
|
||||
five := 5.0
|
||||
ten := 10.0
|
||||
before := map[string]*service.DefaultPlatformQuotaSetting{
|
||||
"anthropic": {DailyLimitUSD: &five},
|
||||
}
|
||||
after := map[string]*service.DefaultPlatformQuotaSetting{
|
||||
"anthropic": {DailyLimitUSD: &ten},
|
||||
}
|
||||
if equalPlatformQuotaSettings(before, after) {
|
||||
t.Error("expected unequal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendAuthSourceDefaultChanges_DetectsPerWindow(t *testing.T) {
|
||||
five := 5.0
|
||||
ten := 10.0
|
||||
before := &service.AuthSourceDefaultSettings{
|
||||
LinuxDo: service.ProviderDefaultGrantSettings{
|
||||
PlatformQuotas: map[string]*service.DefaultPlatformQuotaSetting{
|
||||
"anthropic": {DailyLimitUSD: &five},
|
||||
},
|
||||
},
|
||||
}
|
||||
after := &service.AuthSourceDefaultSettings{
|
||||
LinuxDo: service.ProviderDefaultGrantSettings{
|
||||
PlatformQuotas: map[string]*service.DefaultPlatformQuotaSetting{
|
||||
"anthropic": {DailyLimitUSD: &ten},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
changed := appendAuthSourceDefaultChanges([]string{}, before, after)
|
||||
// 改动 B5:整体替换语义,审计 log 发单个 JSON key,而非展开 84 个扁平 key。
|
||||
key := service.SettingKeyAuthSourcePlatformQuotas("linuxdo")
|
||||
found := false
|
||||
for _, k := range changed {
|
||||
if k == key {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("expected %q in changed, got %v", key, changed)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSettingHandler_AuthSourcePlatformQuotas_PutGetRoundTrip 验证 Bug A 修复:
|
||||
// PUT 发 auth_source_default_email_platform_quotas,GET 能读回相同值(端到端往返)。
|
||||
func TestSettingHandler_AuthSourcePlatformQuotas_PutGetRoundTrip(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
repo := &settingHandlerRepoStub{
|
||||
values: map[string]string{
|
||||
service.SettingKeyPromoCodeEnabled: "true",
|
||||
},
|
||||
}
|
||||
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
|
||||
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
// PUT:发 email platform quota(openai monthly=20)
|
||||
putBody := map[string]any{
|
||||
"auth_source_default_email_platform_quotas": map[string]any{
|
||||
"openai": map[string]any{
|
||||
"monthly": 20,
|
||||
},
|
||||
},
|
||||
}
|
||||
rawBody, err := json.Marshal(putBody)
|
||||
require.NoError(t, err)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
handler.UpdateSettings(c)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
// 验证 DB 中写入了 JSON key
|
||||
jsonKey := service.SettingKeyAuthSourcePlatformQuotas("email")
|
||||
require.NotEmpty(t, repo.values[jsonKey], "expected JSON key to be written to DB")
|
||||
|
||||
// GET:验证响应中 auth_source_default_email_platform_quotas.openai.monthly = 20
|
||||
rec2 := httptest.NewRecorder()
|
||||
c2, _ := gin.CreateTestContext(rec2)
|
||||
c2.Request = httptest.NewRequest(http.MethodGet, "/api/v1/admin/settings", nil)
|
||||
handler.GetSettings(c2)
|
||||
require.Equal(t, http.StatusOK, rec2.Code)
|
||||
|
||||
var resp response.Response
|
||||
require.NoError(t, json.Unmarshal(rec2.Body.Bytes(), &resp))
|
||||
data, ok := resp.Data.(map[string]any)
|
||||
require.True(t, ok)
|
||||
|
||||
emailPQ, ok := data["auth_source_default_email_platform_quotas"].(map[string]any)
|
||||
require.True(t, ok, "expected auth_source_default_email_platform_quotas to be a map")
|
||||
openaiPQ, ok := emailPQ["openai"].(map[string]any)
|
||||
require.True(t, ok, "expected openai entry in email platform quotas")
|
||||
monthly, ok := openaiPQ["monthly"].(float64)
|
||||
require.True(t, ok, "expected monthly to be float64")
|
||||
require.Equal(t, float64(20), monthly, "expected openai monthly=20")
|
||||
}
|
||||
@ -2,10 +2,15 @@ package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/quotaview"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
@ -20,15 +25,24 @@ type UserWithConcurrency struct {
|
||||
|
||||
// UserHandler handles admin user management
|
||||
type UserHandler struct {
|
||||
adminService service.AdminService
|
||||
concurrencyService *service.ConcurrencyService
|
||||
adminService service.AdminService
|
||||
concurrencyService *service.ConcurrencyService
|
||||
userPlatformQuotaRepo service.UserPlatformQuotaRepository // T13 admin quota view
|
||||
billingCache service.BillingCache // T17/T18 缓存失效(PUT/POST 路径)
|
||||
}
|
||||
|
||||
// NewUserHandler creates a new admin user handler
|
||||
func NewUserHandler(adminService service.AdminService, concurrencyService *service.ConcurrencyService) *UserHandler {
|
||||
func NewUserHandler(
|
||||
adminService service.AdminService,
|
||||
concurrencyService *service.ConcurrencyService,
|
||||
userPlatformQuotaRepo service.UserPlatformQuotaRepository,
|
||||
billingCache service.BillingCache,
|
||||
) *UserHandler {
|
||||
return &UserHandler{
|
||||
adminService: adminService,
|
||||
concurrencyService: concurrencyService,
|
||||
adminService: adminService,
|
||||
concurrencyService: concurrencyService,
|
||||
userPlatformQuotaRepo: userPlatformQuotaRepo,
|
||||
billingCache: billingCache,
|
||||
}
|
||||
}
|
||||
|
||||
@ -537,3 +551,294 @@ func (h *UserHandler) BatchUpdateConcurrency(c *gin.Context) {
|
||||
}
|
||||
response.Success(c, gin.H{"affected": affected})
|
||||
}
|
||||
|
||||
// GetUserPlatformQuotas GET /admin/users/:id/platform-quotas
|
||||
// admin 视角:D14 lazy 归零 + 暴露 *_window_start 调试字段
|
||||
func (h *UserHandler) GetUserPlatformQuotas(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
userID, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "invalid user id")
|
||||
return
|
||||
}
|
||||
if h.userPlatformQuotaRepo == nil {
|
||||
response.Success(c, map[string]any{"platform_quotas": []any{}})
|
||||
return
|
||||
}
|
||||
// 校验用户存在:与 PUT/POST 路径一致,不存在返回 404 而非空数组(避免 admin 界面误判用户存在)。
|
||||
if _, err := h.adminService.GetUser(c.Request.Context(), userID); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
records, err := h.userPlatformQuotaRepo.ListByUser(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
out := make([]map[string]any, 0, len(records))
|
||||
for _, r := range records {
|
||||
out = append(out, quotaview.LazyZeroQuotaForResponse(r, now, true)) // true = 暴露 window_start
|
||||
}
|
||||
response.Success(c, map[string]any{"platform_quotas": out})
|
||||
}
|
||||
|
||||
// UpdateUserPlatformQuotasRequest is the body for PUT /admin/users/:id/platform-quotas.
|
||||
type UpdateUserPlatformQuotasRequest struct {
|
||||
Quotas []PlatformQuotaInput `json:"quotas" binding:"required"`
|
||||
}
|
||||
|
||||
// PlatformQuotaInput 单平台限额输入;limit 字段为 nil 表示不限制。
|
||||
type PlatformQuotaInput struct {
|
||||
Platform string `json:"platform" binding:"required"`
|
||||
DailyLimitUSD *float64 `json:"daily_limit_usd"`
|
||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
|
||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
|
||||
}
|
||||
|
||||
// platform 合法性由 service.IsAllowedQuotaPlatform / service.AllowedQuotaPlatforms 统一判断(单一源)。
|
||||
|
||||
// UpdateUserPlatformQuotas PUT /admin/users/:id/platform-quotas
|
||||
// 全量替换该用户所有平台限额。
|
||||
func (h *UserHandler) UpdateUserPlatformQuotas(c *gin.Context) {
|
||||
if h.userPlatformQuotaRepo == nil {
|
||||
response.Error(c, 503, "platform quota service not available")
|
||||
return
|
||||
}
|
||||
|
||||
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid user ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateUserPlatformQuotasRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.Quotas) > 4 {
|
||||
response.BadRequest(c, "quotas length must be <= 4")
|
||||
return
|
||||
}
|
||||
seen := make(map[string]struct{}, len(req.Quotas))
|
||||
for _, q := range req.Quotas {
|
||||
if !service.IsAllowedQuotaPlatform(q.Platform) {
|
||||
response.BadRequest(c, "invalid platform: "+q.Platform)
|
||||
return
|
||||
}
|
||||
if _, dup := seen[q.Platform]; dup {
|
||||
response.BadRequest(c, "duplicate platform: "+q.Platform)
|
||||
return
|
||||
}
|
||||
seen[q.Platform] = struct{}{}
|
||||
// daily_limit_usd / weekly_limit_usd / monthly_limit_usd 的语义:
|
||||
// nil / not set → 无限额(完全放行)
|
||||
// 0 → 完全禁用(任何请求都会被拒绝,因为 usage >= 0 恒成立)
|
||||
// > 0 → USD 限额上限
|
||||
// 拦截 NaN / ±Inf:客户端可发送超大数(如 1e308 × 2)使 JSON 反序列化得到 +Inf,
|
||||
// 进入 DB 后 cache check 中 usage >= limit 永不成立,limit 等同失效。
|
||||
for _, f := range []struct {
|
||||
name string
|
||||
val *float64
|
||||
}{
|
||||
{"daily_limit_usd", q.DailyLimitUSD},
|
||||
{"weekly_limit_usd", q.WeeklyLimitUSD},
|
||||
{"monthly_limit_usd", q.MonthlyLimitUSD},
|
||||
} {
|
||||
if f.val == nil {
|
||||
continue
|
||||
}
|
||||
v := *f.val
|
||||
if v < 0 {
|
||||
response.BadRequest(c, f.name+" must be >= 0")
|
||||
return
|
||||
}
|
||||
if math.IsNaN(v) || math.IsInf(v, 0) {
|
||||
response.BadRequest(c, f.name+" must be a finite number")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
records := make([]service.UserPlatformQuotaRecord, 0, len(req.Quotas))
|
||||
for _, q := range req.Quotas {
|
||||
records = append(records, service.UserPlatformQuotaRecord{
|
||||
UserID: userID,
|
||||
Platform: q.Platform,
|
||||
DailyLimitUSD: q.DailyLimitUSD,
|
||||
WeeklyLimitUSD: q.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: q.MonthlyLimitUSD,
|
||||
})
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
// 校验用户是否存在,避免 FK 违反导致 500;用户不存在时返回 404。
|
||||
if _, err := h.adminService.GetUser(ctx, userID); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
// 在 UpsertForUser 之前抓取 before snapshot 用于审计 before/after 对比。
|
||||
// ListByUser 失败不阻断主操作(best-effort),仅记录降级 warn。
|
||||
beforeRecords, beforeErr := h.userPlatformQuotaRepo.ListByUser(ctx, userID)
|
||||
if beforeErr != nil {
|
||||
slog.Warn("quota audit before snapshot failed", "user_id", userID, "err", beforeErr)
|
||||
}
|
||||
if err := h.userPlatformQuotaRepo.UpsertForUser(ctx, userID, records); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
beforeByPlatform := make(map[string]service.UserPlatformQuotaRecord, len(beforeRecords))
|
||||
for _, r := range beforeRecords {
|
||||
beforeByPlatform[r.Platform] = r
|
||||
}
|
||||
afterPlatforms := make(map[string]struct{}, len(records))
|
||||
for _, r := range records {
|
||||
afterPlatforms[r.Platform] = struct{}{}
|
||||
}
|
||||
changes := make([]map[string]any, 0, len(records))
|
||||
for _, r := range records {
|
||||
entry := map[string]any{
|
||||
"platform": r.Platform,
|
||||
"daily_limit_usd": r.DailyLimitUSD,
|
||||
"weekly_limit_usd": r.WeeklyLimitUSD,
|
||||
"monthly_limit_usd": r.MonthlyLimitUSD,
|
||||
}
|
||||
if prev, ok := beforeByPlatform[r.Platform]; ok {
|
||||
entry["before_daily_limit_usd"] = prev.DailyLimitUSD
|
||||
entry["before_weekly_limit_usd"] = prev.WeeklyLimitUSD
|
||||
entry["before_monthly_limit_usd"] = prev.MonthlyLimitUSD
|
||||
}
|
||||
changes = append(changes, entry)
|
||||
}
|
||||
// 补 removed 条目:before 存在但 after 缺失 = 该平台被软删除。
|
||||
// 缺少这条记录,审计消费方无法察觉"管理员把某平台从配额列表移除"的操作(合规盲区)。
|
||||
for _, prev := range beforeRecords {
|
||||
if _, kept := afterPlatforms[prev.Platform]; kept {
|
||||
continue
|
||||
}
|
||||
changes = append(changes, map[string]any{
|
||||
"platform": prev.Platform,
|
||||
"removed": true,
|
||||
"before_daily_limit_usd": prev.DailyLimitUSD,
|
||||
"before_weekly_limit_usd": prev.WeeklyLimitUSD,
|
||||
"before_monthly_limit_usd": prev.MonthlyLimitUSD,
|
||||
})
|
||||
}
|
||||
// before_snapshot_available 让审计消费方能识别 changes 中是否带 before_* 字段;
|
||||
// false 时所有 entry 都会缺失 before_*_limit_usd,仅有 after 视图。
|
||||
slog.Info("admin.quota_updated",
|
||||
"actor_admin_id", getAdminIDFromContext(c),
|
||||
"target_user_id", userID,
|
||||
"platform_count", len(records),
|
||||
"before_snapshot_available", beforeErr == nil,
|
||||
"changes", changes)
|
||||
|
||||
// 失效 cache:对全部允许的 platform 统一 invalidate。
|
||||
// Trade-off:精确失效(仅 req 涉及平台 + 被软删平台)需 upsert 前额外 ListByUser,
|
||||
// 增加一次 DB 查询和逻辑复杂度。由于 AllowedQuotaPlatforms 只有 4 个元素,
|
||||
// 全量 invalidate 的额外开销可接受,且能可靠覆盖软删除场景。
|
||||
if h.billingCache != nil {
|
||||
for _, p := range service.AllowedQuotaPlatforms {
|
||||
if err := h.billingCache.DeleteUserPlatformQuotaCache(ctx, userID, p); err != nil {
|
||||
slog.Warn("quota cache invalidation failed", "user_id", userID, "platform", p, "err", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 返回最新状态
|
||||
now := time.Now().UTC()
|
||||
records2, err := h.userPlatformQuotaRepo.ListByUser(ctx, userID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
out := make([]map[string]any, 0, len(records2))
|
||||
for i := range records2 {
|
||||
out = append(out, quotaview.LazyZeroQuotaForResponse(records2[i], now, true))
|
||||
}
|
||||
response.Success(c, map[string]any{"platform_quotas": out})
|
||||
}
|
||||
|
||||
// ResetUserPlatformQuotaWindowRequest is the body for POST /admin/users/:id/platform-quotas/reset.
|
||||
type ResetUserPlatformQuotaWindowRequest struct {
|
||||
Platform string `json:"platform" binding:"required"`
|
||||
Window string `json:"window" binding:"required"`
|
||||
}
|
||||
|
||||
var allowedWindowsForQuotaReset = map[string]struct{}{
|
||||
"daily": {},
|
||||
"weekly": {},
|
||||
"monthly": {},
|
||||
}
|
||||
|
||||
// ResetUserPlatformQuotaWindow POST /admin/users/:id/platform-quotas/reset
|
||||
// 立即归零指定 (platform, window) 的用量并更新 window_start。
|
||||
func (h *UserHandler) ResetUserPlatformQuotaWindow(c *gin.Context) {
|
||||
if h.userPlatformQuotaRepo == nil {
|
||||
response.Error(c, 503, "platform quota service not available")
|
||||
return
|
||||
}
|
||||
|
||||
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid user ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req ResetUserPlatformQuotaWindowRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if !service.IsAllowedQuotaPlatform(req.Platform) {
|
||||
response.BadRequest(c, "invalid platform: "+req.Platform)
|
||||
return
|
||||
}
|
||||
if _, ok := allowedWindowsForQuotaReset[req.Window]; !ok {
|
||||
response.BadRequest(c, "invalid window: "+req.Window)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
// 校验用户是否存在,避免对不存在的用户执行操作返回误导性的 500。
|
||||
if _, err := h.adminService.GetUser(ctx, userID); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
if err := h.userPlatformQuotaRepo.ResetExpiredWindow(ctx, userID, req.Platform, req.Window, now); err != nil {
|
||||
if errors.Is(err, service.ErrUserPlatformQuotaNotFound) {
|
||||
response.NotFound(c, "user platform quota not found")
|
||||
return
|
||||
}
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
slog.Info("admin.quota_window_reset",
|
||||
"actor_admin_id", getAdminIDFromContext(c),
|
||||
"target_user_id", userID,
|
||||
"platform", req.Platform,
|
||||
"window", req.Window)
|
||||
|
||||
if h.billingCache != nil {
|
||||
if err := h.billingCache.DeleteUserPlatformQuotaCache(ctx, userID, req.Platform); err != nil {
|
||||
slog.Warn("quota cache invalidation failed", "user_id", userID, "platform", req.Platform, "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
records, err := h.userPlatformQuotaRepo.ListByUser(ctx, userID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
out := make([]map[string]any, 0, len(records))
|
||||
for i := range records {
|
||||
out = append(out, quotaview.LazyZeroQuotaForResponse(records[i], now, true))
|
||||
}
|
||||
response.Success(c, map[string]any{"platform_quotas": out})
|
||||
}
|
||||
|
||||
@ -35,7 +35,7 @@ func TestUserHandlerListIncludesActivityFieldsAndSortParams(t *testing.T) {
|
||||
UpdatedAt: lastLoginAt,
|
||||
},
|
||||
}
|
||||
handler := NewUserHandler(adminSvc, nil)
|
||||
handler := NewUserHandler(adminSvc, nil, nil, nil)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
@ -89,7 +89,7 @@ func TestUserHandlerGetByIDIncludesActivityFields(t *testing.T) {
|
||||
UpdatedAt: lastLoginAt,
|
||||
},
|
||||
}
|
||||
handler := NewUserHandler(adminSvc, nil)
|
||||
handler := NewUserHandler(adminSvc, nil, nil, nil)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
|
||||
301
backend/internal/handler/admin/user_platform_quota_admin_test.go
Normal file
301
backend/internal/handler/admin/user_platform_quota_admin_test.go
Normal file
@ -0,0 +1,301 @@
|
||||
//go:build unit
|
||||
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// upsertCapturingQuotaRepo 实现 service.UserPlatformQuotaRepository,捕获 UpsertForUser 调用。
|
||||
type upsertCapturingQuotaRepo struct {
|
||||
service.UserPlatformQuotaRepository
|
||||
listRecords []service.UserPlatformQuotaRecord
|
||||
listErr error
|
||||
upsertCalls []upsertCall
|
||||
upsertErr error
|
||||
resetCalls []resetCall
|
||||
resetErr error
|
||||
}
|
||||
|
||||
type upsertCall struct {
|
||||
userID int64
|
||||
records []service.UserPlatformQuotaRecord
|
||||
}
|
||||
type resetCall struct {
|
||||
userID int64
|
||||
platform string
|
||||
window string
|
||||
newStart time.Time
|
||||
}
|
||||
|
||||
func (r *upsertCapturingQuotaRepo) ListByUser(_ context.Context, _ int64) ([]service.UserPlatformQuotaRecord, error) {
|
||||
return r.listRecords, r.listErr
|
||||
}
|
||||
func (r *upsertCapturingQuotaRepo) UpsertForUser(_ context.Context, userID int64, records []service.UserPlatformQuotaRecord) error {
|
||||
cloned := make([]service.UserPlatformQuotaRecord, len(records))
|
||||
copy(cloned, records)
|
||||
r.upsertCalls = append(r.upsertCalls, upsertCall{userID: userID, records: cloned})
|
||||
return r.upsertErr
|
||||
}
|
||||
func (r *upsertCapturingQuotaRepo) ResetExpiredWindow(_ context.Context, userID int64, platform string, window string, newStart time.Time) error {
|
||||
r.resetCalls = append(r.resetCalls, resetCall{userID, platform, window, newStart})
|
||||
return r.resetErr
|
||||
}
|
||||
|
||||
// billingCacheStub 实现 service.BillingCache 中本测试关心的 Delete 方法;其他方法 panic。
|
||||
type billingCacheStub struct {
|
||||
service.BillingCache
|
||||
deleteCalls []deleteCall
|
||||
deleteErr error
|
||||
}
|
||||
|
||||
type deleteCall struct {
|
||||
userID int64
|
||||
platform string
|
||||
}
|
||||
|
||||
func (b *billingCacheStub) DeleteUserPlatformQuotaCache(_ context.Context, userID int64, platform string) error {
|
||||
b.deleteCalls = append(b.deleteCalls, deleteCall{userID, platform})
|
||||
return b.deleteErr
|
||||
}
|
||||
|
||||
func buildTestHandler(repo service.UserPlatformQuotaRepository, cache service.BillingCache) *UserHandler {
|
||||
return &UserHandler{
|
||||
userPlatformQuotaRepo: repo,
|
||||
billingCache: cache,
|
||||
adminService: newStubAdminService(),
|
||||
}
|
||||
}
|
||||
|
||||
func putReq(t *testing.T, body string) (*gin.Context, *httptest.ResponseRecorder) {
|
||||
t.Helper()
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
req, _ := http.NewRequest(http.MethodPut, "/", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
c.Request = req
|
||||
c.Params = []gin.Param{{Key: "id", Value: "42"}}
|
||||
return c, w
|
||||
}
|
||||
|
||||
func TestUpdateUserPlatformQuotas_Success(t *testing.T) {
|
||||
repo := &upsertCapturingQuotaRepo{}
|
||||
cache := &billingCacheStub{}
|
||||
h := buildTestHandler(repo, cache)
|
||||
|
||||
body := `{"quotas":[
|
||||
{"platform":"anthropic","daily_limit_usd":10.0,"weekly_limit_usd":null,"monthly_limit_usd":100.0},
|
||||
{"platform":"openai","daily_limit_usd":null,"weekly_limit_usd":null,"monthly_limit_usd":null}
|
||||
]}`
|
||||
c, w := putReq(t, body)
|
||||
h.UpdateUserPlatformQuotas(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if len(repo.upsertCalls) != 1 {
|
||||
t.Fatalf("UpsertForUser should be called once, got %d", len(repo.upsertCalls))
|
||||
}
|
||||
if repo.upsertCalls[0].userID != 42 || len(repo.upsertCalls[0].records) != 2 {
|
||||
t.Errorf("unexpected upsert call: %+v", repo.upsertCalls[0])
|
||||
}
|
||||
// 缓存失效:请求中 2 个 platform + 软删除的 2 个 platform(gemini, antigravity)= 4 次
|
||||
if len(cache.deleteCalls) != 4 {
|
||||
t.Errorf("expected 4 cache delete calls, got %d: %+v", len(cache.deleteCalls), cache.deleteCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateUserPlatformQuotas_RejectsDuplicatePlatform(t *testing.T) {
|
||||
h := buildTestHandler(&upsertCapturingQuotaRepo{}, &billingCacheStub{})
|
||||
body := `{"quotas":[
|
||||
{"platform":"anthropic","daily_limit_usd":1},
|
||||
{"platform":"anthropic","daily_limit_usd":2}
|
||||
]}`
|
||||
c, w := putReq(t, body)
|
||||
h.UpdateUserPlatformQuotas(c)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateUserPlatformQuotas_RejectsInvalidPlatform(t *testing.T) {
|
||||
h := buildTestHandler(&upsertCapturingQuotaRepo{}, &billingCacheStub{})
|
||||
body := `{"quotas":[{"platform":"unknown","daily_limit_usd":1}]}`
|
||||
c, w := putReq(t, body)
|
||||
h.UpdateUserPlatformQuotas(c)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateUserPlatformQuotas_RejectsNegativeLimit(t *testing.T) {
|
||||
h := buildTestHandler(&upsertCapturingQuotaRepo{}, &billingCacheStub{})
|
||||
body := `{"quotas":[{"platform":"anthropic","daily_limit_usd":-1}]}`
|
||||
c, w := putReq(t, body)
|
||||
h.UpdateUserPlatformQuotas(c)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateUserPlatformQuotas_RejectsTooManyEntries(t *testing.T) {
|
||||
h := buildTestHandler(&upsertCapturingQuotaRepo{}, &billingCacheStub{})
|
||||
body := `{"quotas":[
|
||||
{"platform":"anthropic"},{"platform":"openai"},{"platform":"gemini"},{"platform":"antigravity"},{"platform":"anthropic"}
|
||||
]}`
|
||||
c, w := putReq(t, body)
|
||||
h.UpdateUserPlatformQuotas(c)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateUserPlatformQuotas_ReturnsLatestState(t *testing.T) {
|
||||
repo := &upsertCapturingQuotaRepo{
|
||||
listRecords: []service.UserPlatformQuotaRecord{
|
||||
{UserID: 42, Platform: "anthropic"},
|
||||
},
|
||||
}
|
||||
cache := &billingCacheStub{}
|
||||
h := buildTestHandler(repo, cache)
|
||||
|
||||
body := `{"quotas":[{"platform":"anthropic","daily_limit_usd":10}]}`
|
||||
c, w := putReq(t, body)
|
||||
h.UpdateUserPlatformQuotas(c)
|
||||
if !strings.Contains(w.Body.String(), `"platform_quotas"`) {
|
||||
t.Errorf("response should contain platform_quotas array: %s", w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
// ───────── T4: Reset 测试 ─────────
|
||||
|
||||
func postReq(t *testing.T, body string) (*gin.Context, *httptest.ResponseRecorder) {
|
||||
t.Helper()
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
req, _ := http.NewRequest(http.MethodPost, "/", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
c.Request = req
|
||||
c.Params = []gin.Param{{Key: "id", Value: "42"}}
|
||||
return c, w
|
||||
}
|
||||
|
||||
func TestResetUserPlatformQuotaWindow_Success(t *testing.T) {
|
||||
repo := &upsertCapturingQuotaRepo{}
|
||||
cache := &billingCacheStub{}
|
||||
h := buildTestHandler(repo, cache)
|
||||
body := `{"platform":"anthropic","window":"daily"}`
|
||||
c, w := postReq(t, body)
|
||||
h.ResetUserPlatformQuotaWindow(c)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if len(repo.resetCalls) != 1 {
|
||||
t.Fatalf("ResetExpiredWindow should be called once, got %d", len(repo.resetCalls))
|
||||
}
|
||||
if repo.resetCalls[0].userID != 42 ||
|
||||
repo.resetCalls[0].platform != "anthropic" ||
|
||||
repo.resetCalls[0].window != "daily" {
|
||||
t.Errorf("unexpected reset call: %+v", repo.resetCalls[0])
|
||||
}
|
||||
if len(cache.deleteCalls) != 1 ||
|
||||
cache.deleteCalls[0].userID != 42 ||
|
||||
cache.deleteCalls[0].platform != "anthropic" {
|
||||
t.Errorf("expected 1 cache delete for anthropic, got %+v", cache.deleteCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResetUserPlatformQuotaWindow_RejectsInvalidWindow(t *testing.T) {
|
||||
h := buildTestHandler(&upsertCapturingQuotaRepo{}, &billingCacheStub{})
|
||||
c, w := postReq(t, `{"platform":"anthropic","window":"yearly"}`)
|
||||
h.ResetUserPlatformQuotaWindow(c)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResetUserPlatformQuotaWindow_RejectsInvalidPlatform(t *testing.T) {
|
||||
h := buildTestHandler(&upsertCapturingQuotaRepo{}, &billingCacheStub{})
|
||||
c, w := postReq(t, `{"platform":"unknown","window":"daily"}`)
|
||||
h.ResetUserPlatformQuotaWindow(c)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResetUserPlatformQuotaWindow_NotFound(t *testing.T) {
|
||||
// handler 检查 service.ErrUserPlatformQuotaNotFound(由 adapter 包装而来)
|
||||
repo := &upsertCapturingQuotaRepo{resetErr: service.ErrUserPlatformQuotaNotFound}
|
||||
h := buildTestHandler(repo, &billingCacheStub{})
|
||||
c, w := postReq(t, `{"platform":"anthropic","window":"daily"}`)
|
||||
h.ResetUserPlatformQuotaWindow(c)
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("expected 404, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateUserPlatformQuotas_JSONErrorOnRepoFailure(t *testing.T) {
|
||||
repo := &upsertCapturingQuotaRepo{upsertErr: errors.New("db down")}
|
||||
cache := &billingCacheStub{}
|
||||
h := buildTestHandler(repo, cache)
|
||||
body := `{"quotas":[{"platform":"anthropic","daily_limit_usd":10}]}`
|
||||
c, w := putReq(t, body)
|
||||
h.UpdateUserPlatformQuotas(c)
|
||||
if w.Code < 500 {
|
||||
t.Errorf("expected 5xx, got %d", w.Code)
|
||||
}
|
||||
// 返回 JSON 错误响应
|
||||
var body2 map[string]any
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &body2); err != nil {
|
||||
t.Errorf("expected JSON error body, got: %s", w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateUserPlatformQuotas_UserNotFound(t *testing.T) {
|
||||
repo := &upsertCapturingQuotaRepo{}
|
||||
cache := &billingCacheStub{}
|
||||
adminSvc := newStubAdminService()
|
||||
adminSvc.getUserErr = service.ErrUserNotFound
|
||||
h := &UserHandler{
|
||||
userPlatformQuotaRepo: repo,
|
||||
billingCache: cache,
|
||||
adminService: adminSvc,
|
||||
}
|
||||
body := `{"quotas":[{"platform":"anthropic","daily_limit_usd":10}]}`
|
||||
c, w := putReq(t, body)
|
||||
h.UpdateUserPlatformQuotas(c)
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("expected 404 when user not found, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestResetUserPlatformQuotaWindow_UserNotFound(t *testing.T) {
|
||||
repo := &upsertCapturingQuotaRepo{}
|
||||
cache := &billingCacheStub{}
|
||||
adminSvc := newStubAdminService()
|
||||
adminSvc.getUserErr = service.ErrUserNotFound
|
||||
h := &UserHandler{
|
||||
userPlatformQuotaRepo: repo,
|
||||
billingCache: cache,
|
||||
adminService: adminSvc,
|
||||
}
|
||||
c, w := postReq(t, `{"platform":"anthropic","window":"daily"}`)
|
||||
h.ResetUserPlatformQuotaWindow(c)
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("expected 404 when user not found, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,124 @@
|
||||
//go:build unit
|
||||
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type fakeQuotaRepoForAdmin struct {
|
||||
service.UserPlatformQuotaRepository
|
||||
records []service.UserPlatformQuotaRecord
|
||||
err error
|
||||
}
|
||||
|
||||
func (f *fakeQuotaRepoForAdmin) ListByUser(_ context.Context, _ int64) ([]service.UserPlatformQuotaRecord, error) {
|
||||
return f.records, f.err
|
||||
}
|
||||
|
||||
func newAdminQuotaTestContext(w *httptest.ResponseRecorder) *gin.Context {
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
req, _ := http.NewRequest(http.MethodGet, "/", nil)
|
||||
c.Request = req
|
||||
return c
|
||||
}
|
||||
|
||||
func TestAdminGetUserPlatformQuotas_IncludesWindowStart(t *testing.T) {
|
||||
start := time.Now().Add(-1 * time.Hour)
|
||||
repo := &fakeQuotaRepoForAdmin{records: []service.UserPlatformQuotaRecord{{
|
||||
UserID: 99, Platform: "anthropic",
|
||||
DailyUsageUSD: 1.0, DailyWindowStart: &start,
|
||||
}}}
|
||||
h := &UserHandler{userPlatformQuotaRepo: repo, adminService: newStubAdminService()}
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c := newAdminQuotaTestContext(w)
|
||||
c.Params = []gin.Param{{Key: "id", Value: "99"}}
|
||||
h.GetUserPlatformQuotas(c)
|
||||
|
||||
if w.Code != 200 {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if !strings.Contains(w.Body.String(), `"daily_window_start"`) {
|
||||
t.Errorf("admin response missing daily_window_start, got: %s", w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminGetUserPlatformQuotas_InvalidIDReturns400(t *testing.T) {
|
||||
h := &UserHandler{userPlatformQuotaRepo: &fakeQuotaRepoForAdmin{}}
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c := newAdminQuotaTestContext(w)
|
||||
c.Params = []gin.Param{{Key: "id", Value: "abc"}}
|
||||
h.GetUserPlatformQuotas(c)
|
||||
if w.Code < 400 || w.Code >= 500 {
|
||||
t.Errorf("invalid id should yield 4xx, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminGetUserPlatformQuotas_EmptyReturnsEmptyArray(t *testing.T) {
|
||||
repo := &fakeQuotaRepoForAdmin{records: nil}
|
||||
h := &UserHandler{userPlatformQuotaRepo: repo, adminService: newStubAdminService()}
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c := newAdminQuotaTestContext(w)
|
||||
c.Params = []gin.Param{{Key: "id", Value: "99"}}
|
||||
h.GetUserPlatformQuotas(c)
|
||||
if w.Code != 200 {
|
||||
t.Errorf("empty list should be 200, got %d", w.Code)
|
||||
}
|
||||
var body map[string]any
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("response is not valid JSON: %v", err)
|
||||
}
|
||||
data, ok := body["data"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("response missing data object: %v", body)
|
||||
}
|
||||
quotas, ok := data["platform_quotas"].([]any)
|
||||
if !ok {
|
||||
t.Fatalf("data.platform_quotas missing or wrong type: %v", data)
|
||||
}
|
||||
if len(quotas) != 0 {
|
||||
t.Errorf("expected empty platform_quotas, got %d entries: %v", len(quotas), quotas)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminGetUserPlatformQuotas_NilRepoReturnsEmpty(t *testing.T) {
|
||||
h := &UserHandler{userPlatformQuotaRepo: nil}
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c := newAdminQuotaTestContext(w)
|
||||
c.Params = []gin.Param{{Key: "id", Value: "1"}}
|
||||
h.GetUserPlatformQuotas(c)
|
||||
if w.Code != 200 {
|
||||
t.Errorf("nil repo should return 200 empty, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAdminGetUserPlatformQuotas_UserNotFoundReturns404 验证 GET 在用户不存在时返回 404
|
||||
// (与 PUT / POST reset 端点行为一致;review fix:原实现返回空数组会让 admin 界面误判用户存在)
|
||||
func TestAdminGetUserPlatformQuotas_UserNotFoundReturns404(t *testing.T) {
|
||||
adminSvc := newStubAdminService()
|
||||
adminSvc.getUserErr = infraerrors.NotFound("USER_NOT_FOUND", "user not found")
|
||||
repo := &fakeQuotaRepoForAdmin{records: nil}
|
||||
h := &UserHandler{userPlatformQuotaRepo: repo, adminService: adminSvc}
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c := newAdminQuotaTestContext(w)
|
||||
c.Params = []gin.Param{{Key: "id", Value: "999"}}
|
||||
h.GetUserPlatformQuotas(c)
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("expected 404 for non-existent user, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
@ -2233,6 +2233,7 @@ CREATE TABLE IF NOT EXISTS user_affiliates (
|
||||
nil,
|
||||
options.defaultSubAssigner,
|
||||
affiliateService,
|
||||
nil,
|
||||
)
|
||||
userSvc := service.NewUserService(userRepo, nil, nil, nil)
|
||||
var totpSvc *service.TotpService
|
||||
|
||||
@ -35,7 +35,7 @@ func TestAuthHandlerRevokeAllSessionsInvalidatesAccessTokens(t *testing.T) {
|
||||
ExpireHour: 1,
|
||||
},
|
||||
}
|
||||
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil)
|
||||
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
handler := &AuthHandler{authService: authService}
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
@ -1400,6 +1400,7 @@ func newWeChatOAuthTestHandlerWithSettings(t *testing.T, invitationEnabled bool,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
|
||||
return &AuthHandler{
|
||||
|
||||
@ -3,6 +3,8 @@ package dto
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// CustomMenuItem represents a user-configured custom menu entry.
|
||||
@ -246,6 +248,9 @@ type SystemSettings struct {
|
||||
|
||||
// OpenAI fast/flex policy
|
||||
OpenAIFastPolicySettings *OpenAIFastPolicySettings `json:"openai_fast_policy_settings,omitempty"`
|
||||
|
||||
// 系统全局默认平台配额(key = platform,nil/缺省 = 不限制)
|
||||
DefaultPlatformQuotas map[string]*service.DefaultPlatformQuotaSetting `json:"default_platform_quotas,omitempty"`
|
||||
}
|
||||
|
||||
type DefaultSubscriptionSetting struct {
|
||||
|
||||
@ -6,6 +6,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
@ -250,7 +251,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 2. 【新增】Wait后二次检查余额/订阅
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil {
|
||||
reqLog.Info("gateway.billing_eligibility_check_failed", zap.Error(err))
|
||||
status, code, message, retryAfter := billingErrorDetails(err)
|
||||
if retryAfter > 0 {
|
||||
@ -508,10 +509,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
quotaPlatform := service.QuotaPlatform(c.Request.Context(), apiKey)
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
ParsedRequest: parsedReq,
|
||||
QuotaPlatform: quotaPlatform,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
@ -808,7 +811,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
fallbackAPIKey := cloneAPIKeyWithGroup(apiKey, fallbackGroup)
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), fallbackAPIKey.User, fallbackAPIKey, fallbackGroup, nil); err != nil {
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), fallbackAPIKey.User, fallbackAPIKey, fallbackGroup, nil, service.PlatformFromAPIKey(fallbackAPIKey)); err != nil {
|
||||
status, code, message, retryAfter := billingErrorDetails(err)
|
||||
if retryAfter > 0 {
|
||||
c.Header("Retry-After", strconv.Itoa(retryAfter))
|
||||
@ -900,10 +903,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
quotaPlatform := service.QuotaPlatform(c.Request.Context(), currentAPIKey)
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
ParsedRequest: parsedReq,
|
||||
QuotaPlatform: quotaPlatform,
|
||||
APIKey: currentAPIKey,
|
||||
User: currentAPIKey.User,
|
||||
Account: account,
|
||||
@ -1582,7 +1587,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
|
||||
// 校验 billing eligibility(订阅/余额)
|
||||
// 【注意】不计算并发,但需要校验订阅/余额
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil {
|
||||
status, code, message, retryAfter := billingErrorDetails(err)
|
||||
if retryAfter > 0 {
|
||||
c.Header("Retry-After", strconv.Itoa(retryAfter))
|
||||
@ -1830,6 +1835,36 @@ func sendMockInterceptResponse(c *gin.Context, model string, interceptType Inter
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// extractQuotaResetSeconds 从 quota 错误的 metadata 中提取 window_resets_at 并计算
|
||||
// 距重置剩余秒数。fallback 路径必须返回 ≥1 秒,避免客户端立即重试无限循环。
|
||||
func extractQuotaResetSeconds(err error) int {
|
||||
const fallback = 60
|
||||
appErr := pkgerrors.FromError(err)
|
||||
if appErr == nil {
|
||||
return fallback
|
||||
}
|
||||
raw, ok := appErr.Metadata["window_resets_at"]
|
||||
if !ok || raw == "" {
|
||||
return fallback
|
||||
}
|
||||
resetAt, parseErr := time.Parse(time.RFC3339, raw)
|
||||
if parseErr != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.gateway.billing"),
|
||||
zap.String("raw", raw),
|
||||
zap.Error(parseErr),
|
||||
).Warn("quota.invalid_window_resets_at_format")
|
||||
return fallback
|
||||
}
|
||||
secs := time.Until(resetAt).Seconds()
|
||||
if secs <= 0 {
|
||||
// reset 时间已过:cache 与 DB 应该正在自愈,返回 fallback 让客户端按常规节奏退避,
|
||||
// 避免返回 1 秒导致客户端立即重试仍触发限额的退避循环。
|
||||
return fallback
|
||||
}
|
||||
return int(math.Ceil(secs))
|
||||
}
|
||||
|
||||
func billingErrorDetails(err error) (status int, code, message string, retryAfter int) {
|
||||
if errors.Is(err, service.ErrBillingServiceUnavailable) {
|
||||
msg := pkgerrors.Message(err)
|
||||
@ -1857,6 +1892,14 @@ func billingErrorDetails(err error) (status int, code, message string, retryAfte
|
||||
retrySeconds := 60 - int(time.Now().Unix()%60)
|
||||
return http.StatusTooManyRequests, "rate_limit_exceeded", msg, retrySeconds
|
||||
}
|
||||
if errors.Is(err, service.ErrUserPlatformDailyQuotaExhausted) ||
|
||||
errors.Is(err, service.ErrUserPlatformWeeklyQuotaExhausted) ||
|
||||
errors.Is(err, service.ErrUserPlatformMonthlyQuotaExhausted) {
|
||||
// 与 RPM 超限一致映射 429 + Retry-After,让 SDK 自动退避(而非 403 直接失败)。
|
||||
// 错误码用 rate_limit_exceeded 与 OpenAI 兼容客户端一致;细分类型由 ErrCode + window_resets_at metadata 区分。
|
||||
msg := pkgerrors.Message(err)
|
||||
return http.StatusTooManyRequests, "rate_limit_exceeded", msg, extractQuotaResetSeconds(err)
|
||||
}
|
||||
msg := pkgerrors.Message(err)
|
||||
if msg == "" {
|
||||
logger.L().With(
|
||||
|
||||
@ -1,8 +1,10 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -52,3 +54,75 @@ func TestBillingErrorDetails_UnknownErrorFallsBackTo403(t *testing.T) {
|
||||
require.Equal(t, "billing_error", code)
|
||||
require.NotEmpty(t, msg)
|
||||
}
|
||||
|
||||
func TestExtractQuotaResetSeconds_T19_HappyPath(t *testing.T) {
|
||||
err := service.ErrUserPlatformDailyQuotaExhausted.WithMetadata(map[string]string{
|
||||
"window_resets_at": time.Now().Add(10 * time.Second).UTC().Format(time.RFC3339),
|
||||
})
|
||||
got := extractQuotaResetSeconds(err)
|
||||
if got < 10 || got > 11 {
|
||||
t.Errorf("T19: got %d, want 10 or 11 (math.Ceil boundary)", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractQuotaResetSeconds_T20_NoMetadataFallback(t *testing.T) {
|
||||
if got := extractQuotaResetSeconds(errors.New("naked error")); got != 60 {
|
||||
t.Errorf("T20: got %d, want 60 fallback", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractQuotaResetSeconds_T21_BadFormatFallback(t *testing.T) {
|
||||
err := service.ErrUserPlatformDailyQuotaExhausted.WithMetadata(map[string]string{
|
||||
"window_resets_at": "not-a-time",
|
||||
})
|
||||
if got := extractQuotaResetSeconds(err); got != 60 {
|
||||
t.Errorf("T21: got %d, want 60 fallback", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractQuotaResetSeconds_T22_PastResetFallsBackToDefault(t *testing.T) {
|
||||
// 当 window_resets_at 已过去时返回 fallback (60s) 而非 1s:
|
||||
// 1 秒会导致客户端立即重试仍触发限额的退避循环;
|
||||
// 60s 让客户端按常规节奏退避,cache/DB 自愈期间不会反复打抖。
|
||||
err := service.ErrUserPlatformDailyQuotaExhausted.WithMetadata(map[string]string{
|
||||
"window_resets_at": time.Now().Add(-5 * time.Second).UTC().Format(time.RFC3339),
|
||||
})
|
||||
if got := extractQuotaResetSeconds(err); got != 60 {
|
||||
t.Errorf("T22: got %d, want 60 (fallback on past reset)", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBillingErrorDetails_T10_QuotaExhaustedReturns429WithRetryAfter(t *testing.T) {
|
||||
// quota 超限映射 429 + Retry-After(RFC 6585 / 与 RPM 一致),
|
||||
// 让 SDK(OpenAI 兼容客户端等)能按 Retry-After 自动退避。
|
||||
// 旧实现用 403 导致客户端不退避直接报错。
|
||||
// 三个窗口共用同一映射分支,循环覆盖避免漏测某个窗口的 status/code。
|
||||
cases := []struct {
|
||||
name string
|
||||
err error
|
||||
}{
|
||||
{"daily", service.ErrUserPlatformDailyQuotaExhausted.WithMetadata(map[string]string{
|
||||
"window_resets_at": time.Now().Add(60 * time.Minute).UTC().Format(time.RFC3339),
|
||||
})},
|
||||
{"weekly", service.ErrUserPlatformWeeklyQuotaExhausted.WithMetadata(map[string]string{
|
||||
"window_resets_at": time.Now().Add(60 * time.Minute).UTC().Format(time.RFC3339),
|
||||
})},
|
||||
{"monthly", service.ErrUserPlatformMonthlyQuotaExhausted.WithMetadata(map[string]string{
|
||||
"window_resets_at": time.Now().Add(60 * time.Minute).UTC().Format(time.RFC3339),
|
||||
})},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
status, code, _, retryAfter := billingErrorDetails(tc.err)
|
||||
if status != http.StatusTooManyRequests {
|
||||
t.Errorf("status = %d, want 429", status)
|
||||
}
|
||||
if code != "rate_limit_exceeded" {
|
||||
t.Errorf("code = %q, want rate_limit_exceeded", code)
|
||||
}
|
||||
if retryAfter < 3599 || retryAfter > 3601 {
|
||||
t.Errorf("retryAfter = %d, want ~3600", retryAfter)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -140,7 +140,7 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 2. Re-check billing
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil {
|
||||
reqLog.Info("gateway.cc.billing_check_failed", zap.Error(err))
|
||||
status, code, message, retryAfter := billingErrorDetails(err)
|
||||
if retryAfter > 0 {
|
||||
@ -291,9 +291,11 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
inboundEndpoint := GetInboundEndpoint(c)
|
||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||
|
||||
quotaPlatform := service.QuotaPlatform(c.Request.Context(), apiKey)
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
QuotaPlatform: quotaPlatform,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
|
||||
@ -145,7 +145,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 2. Re-check billing
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil {
|
||||
reqLog.Info("gateway.responses.billing_check_failed", zap.Error(err))
|
||||
status, code, message, retryAfter := billingErrorDetails(err)
|
||||
if retryAfter > 0 {
|
||||
@ -266,9 +266,11 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
|
||||
inboundEndpoint := GetInboundEndpoint(c)
|
||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||
|
||||
quotaPlatform := service.QuotaPlatform(c.Request.Context(), apiKey)
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
QuotaPlatform: quotaPlatform,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
|
||||
@ -172,11 +172,12 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
|
||||
nil, // channelService
|
||||
nil, // resolver
|
||||
nil, // balanceNotifyService
|
||||
nil, // userPlatformQuotaRepo
|
||||
)
|
||||
|
||||
// RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg)
|
||||
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg, nil)
|
||||
|
||||
concurrencySvc := service.NewConcurrencyService(&fakeConcurrencyCache{})
|
||||
concurrencyHelper := NewConcurrencyHelper(concurrencySvc, SSEPingFormatClaude, 0)
|
||||
|
||||
@ -43,7 +43,7 @@ func newGatewayModelsHandlerForTest(repo service.AccountRepository) *GatewayHand
|
||||
gatewayService: service.NewGatewayService(
|
||||
repo,
|
||||
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
|
||||
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
|
||||
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@ -247,7 +247,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 2) billing eligibility check (after wait)
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil {
|
||||
reqLog.Info("gemini.billing_eligibility_check_failed", zap.Error(err))
|
||||
status, _, message, retryAfter := billingErrorDetails(err)
|
||||
if retryAfter > 0 {
|
||||
@ -527,9 +527,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
inboundEndpoint := GetInboundEndpoint(c)
|
||||
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
||||
quotaPlatform := service.QuotaPlatform(c.Request.Context(), apiKey)
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{
|
||||
Result: result,
|
||||
QuotaPlatform: quotaPlatform,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
|
||||
@ -206,7 +206,7 @@ func TestOpenAIGatewayHandlerResponses_TextOnlyNotRejectedByImageConcurrency(t *
|
||||
|
||||
h := &OpenAIGatewayHandler{
|
||||
gatewayService: &service.OpenAIGatewayService{},
|
||||
billingCacheService: service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, &config.Config{RunMode: config.RunModeSimple}),
|
||||
billingCacheService: service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, &config.Config{RunMode: config.RunModeSimple}, nil),
|
||||
apiKeyService: &service.APIKeyService{},
|
||||
concurrencyHelper: &ConcurrencyHelper{concurrencyService: service.NewConcurrencyService(&helperConcurrencyCacheStub{userSeq: []bool{true}})},
|
||||
cfg: &config.Config{Gateway: config.GatewayConfig{ImageConcurrency: config.ImageConcurrencyConfig{
|
||||
|
||||
@ -106,7 +106,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
defer userReleaseFunc()
|
||||
}
|
||||
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil {
|
||||
reqLog.Info("openai_chat_completions.billing_eligibility_check_failed", zap.Error(err))
|
||||
status, code, message, retryAfter := billingErrorDetails(err)
|
||||
if retryAfter > 0 {
|
||||
|
||||
@ -243,7 +243,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 2. Re-check billing eligibility after wait
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil {
|
||||
reqLog.Info("openai.billing_eligibility_check_failed", zap.Error(err))
|
||||
status, code, message, retryAfter := billingErrorDetails(err)
|
||||
if retryAfter > 0 {
|
||||
@ -648,7 +648,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
defer userReleaseFunc()
|
||||
}
|
||||
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil {
|
||||
reqLog.Info("openai_messages.billing_eligibility_check_failed", zap.Error(err))
|
||||
status, code, message, retryAfter := billingErrorDetails(err)
|
||||
if retryAfter > 0 {
|
||||
@ -1235,7 +1235,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
||||
currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc)
|
||||
|
||||
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
||||
if err := h.billingCacheService.CheckBillingEligibility(ctx, apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
if err := h.billingCacheService.CheckBillingEligibility(ctx, apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil {
|
||||
reqLog.Info("openai.websocket_billing_eligibility_check_failed", zap.Error(err))
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "billing check failed")
|
||||
return
|
||||
|
||||
@ -1201,7 +1201,7 @@ func runOpenAIResponsesWebSocketUsageLogCase(t *testing.T, tc openAIResponsesWSU
|
||||
}, nil, nil, nil)
|
||||
}
|
||||
|
||||
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg)
|
||||
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg, nil)
|
||||
gatewaySvc := service.NewOpenAIGatewayService(
|
||||
accountRepo,
|
||||
usageRepo,
|
||||
@ -1223,6 +1223,7 @@ func runOpenAIResponsesWebSocketUsageLogCase(t *testing.T, tc openAIResponsesWSU
|
||||
channelSvc,
|
||||
nil,
|
||||
nil,
|
||||
nil, // userPlatformQuotaRepo
|
||||
)
|
||||
|
||||
cache := &concurrencyCacheMock{
|
||||
|
||||
@ -123,7 +123,7 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
|
||||
defer userReleaseFunc()
|
||||
}
|
||||
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription, service.QuotaPlatform(c.Request.Context(), apiKey)); err != nil {
|
||||
reqLog.Info("openai.images.billing_eligibility_check_failed", zap.Error(err))
|
||||
status, code, message, retryAfter := billingErrorDetails(err)
|
||||
if retryAfter > 0 {
|
||||
|
||||
104
backend/internal/handler/quotaview/helpers.go
Normal file
104
backend/internal/handler/quotaview/helpers.go
Normal file
@ -0,0 +1,104 @@
|
||||
// Package quotaview provides shared quota response helpers for user and admin handlers.
|
||||
// Extracted to avoid import cycles between handler and handler/admin packages.
|
||||
package quotaview
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// LazyZeroQuotaForResponse 按 D14 规则把过期档位归零(不写 DB)。
|
||||
// includeWindowStart=true 时输出 *_window_start 字段(admin 视角调试用)
|
||||
func LazyZeroQuotaForResponse(r service.UserPlatformQuotaRecord, now time.Time, includeWindowStart bool) map[string]any {
|
||||
daily := buildWindowSlice(r.DailyUsageUSD, r.DailyLimitUSD, r.DailyWindowStart, NeedsDailyReset(r.DailyWindowStart, now), nextDailyResetTime(now), includeWindowStart)
|
||||
weekly := buildWindowSlice(r.WeeklyUsageUSD, r.WeeklyLimitUSD, r.WeeklyWindowStart, NeedsWeeklyReset(r.WeeklyWindowStart, now), nextWeeklyResetTime(now), includeWindowStart)
|
||||
monthly := buildWindowSlice(r.MonthlyUsageUSD, r.MonthlyLimitUSD, r.MonthlyWindowStart, NeedsMonthlyReset(r.MonthlyWindowStart, now), NextMonthlyResetTimeFrom(r.MonthlyWindowStart, now), includeWindowStart)
|
||||
out := map[string]any{
|
||||
"platform": r.Platform,
|
||||
"daily_usage_usd": daily.usage,
|
||||
"daily_limit_usd": daily.limit,
|
||||
"daily_window_resets_at": daily.resetsAt,
|
||||
"weekly_usage_usd": weekly.usage,
|
||||
"weekly_limit_usd": weekly.limit,
|
||||
"weekly_window_resets_at": weekly.resetsAt,
|
||||
"monthly_usage_usd": monthly.usage,
|
||||
"monthly_limit_usd": monthly.limit,
|
||||
"monthly_window_resets_at": monthly.resetsAt,
|
||||
}
|
||||
if includeWindowStart {
|
||||
out["daily_window_start"] = daily.windowStart
|
||||
out["weekly_window_start"] = weekly.windowStart
|
||||
out["monthly_window_start"] = monthly.windowStart
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
type windowSlice struct {
|
||||
usage float64
|
||||
limit *float64
|
||||
resetsAt *string
|
||||
windowStart *string
|
||||
}
|
||||
|
||||
func buildWindowSlice(usage float64, limit *float64, start *time.Time, expired bool, nextReset time.Time, includeStart bool) windowSlice {
|
||||
out := windowSlice{usage: usage, limit: limit}
|
||||
if expired {
|
||||
out.usage = 0
|
||||
out.resetsAt = nil
|
||||
} else if start != nil {
|
||||
s := nextReset.Format(time.RFC3339)
|
||||
out.resetsAt = &s
|
||||
}
|
||||
if includeStart && start != nil {
|
||||
s := start.Format(time.RFC3339)
|
||||
out.windowStart = &s
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// NeedsDailyReset 判断日窗口是否已过期:start 早于「全局时区当天 0 点」即过期。
|
||||
// 时区跟随 timezone.Location()(全局服务器时区),与 billing / repo 写入的 window_start 同口径。
|
||||
func NeedsDailyReset(start *time.Time, now time.Time) bool {
|
||||
if start == nil {
|
||||
return false
|
||||
}
|
||||
return start.Before(timezone.StartOfDay(now))
|
||||
}
|
||||
|
||||
func NeedsWeeklyReset(start *time.Time, now time.Time) bool {
|
||||
if start == nil {
|
||||
return false
|
||||
}
|
||||
return start.Before(timezone.StartOfWeek(now))
|
||||
}
|
||||
|
||||
// NeedsMonthlyReset 30 天滚动窗口语义(与订阅模式 NeedsMonthlyReset 一致)。
|
||||
func NeedsMonthlyReset(start *time.Time, now time.Time) bool {
|
||||
if start == nil {
|
||||
return false
|
||||
}
|
||||
return now.Sub(*start) >= 30*24*time.Hour
|
||||
}
|
||||
|
||||
func nextDailyResetTime(now time.Time) time.Time {
|
||||
return timezone.StartOfDay(now).AddDate(0, 0, 1)
|
||||
}
|
||||
|
||||
func nextWeeklyResetTime(now time.Time) time.Time {
|
||||
return timezone.StartOfWeek(now).AddDate(0, 0, 7)
|
||||
}
|
||||
|
||||
// NextMonthlyResetTimeFrom 计算 30 天滚动月度窗口的下次重置时间。
|
||||
// 语义:
|
||||
// - start != nil → 返回 start + 30d(与 billing_cache_service.nextMonthlyResetFrom 一致)
|
||||
// - start == nil → 退化为 now + 30d(保留旧行为,避免 nil 崩溃)
|
||||
//
|
||||
// 导出(首字母大写)以允许测试直接调用。
|
||||
func NextMonthlyResetTimeFrom(start *time.Time, now time.Time) time.Time {
|
||||
if start == nil {
|
||||
return now.Add(30 * 24 * time.Hour)
|
||||
}
|
||||
return start.Add(30 * 24 * time.Hour)
|
||||
}
|
||||
133
backend/internal/handler/quotaview/helpers_test.go
Normal file
133
backend/internal/handler/quotaview/helpers_test.go
Normal file
@ -0,0 +1,133 @@
|
||||
package quotaview
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// TestNextMonthlyResetTimeFrom_FromStart 验证:start 已知时返回 start+30d,不随 now 漂移。
|
||||
func TestNextMonthlyResetTimeFrom_FromStart(t *testing.T) {
|
||||
t0 := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
now := t0.Add(15 * 24 * time.Hour) // t0 + 15d
|
||||
want := t0.Add(30 * 24 * time.Hour) // t0 + 30d
|
||||
|
||||
got := NextMonthlyResetTimeFrom(&t0, now)
|
||||
if !got.Equal(want) {
|
||||
t.Errorf("NextMonthlyResetTimeFrom: want %v, got %v", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestNextMonthlyResetTimeFrom_NilStart 验证:start=nil 时退化为 now+30d(不 panic)。
|
||||
func TestNextMonthlyResetTimeFrom_NilStart(t *testing.T) {
|
||||
now := time.Date(2024, 3, 15, 12, 0, 0, 0, time.UTC)
|
||||
want := now.Add(30 * 24 * time.Hour)
|
||||
|
||||
got := NextMonthlyResetTimeFrom(nil, now)
|
||||
if !got.Equal(want) {
|
||||
t.Errorf("NextMonthlyResetTimeFrom(nil): want %v, got %v", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLazyZeroQuotaForResponse_MonthlyResetsAt_NotDrifting 验证:
|
||||
// 连续两次以不同 now 调用、但 MonthlyWindowStart 相同的 record,
|
||||
// monthly_window_resets_at 始终等于 windowStart+30d,不随 now 漂移。
|
||||
func TestLazyZeroQuotaForResponse_MonthlyResetsAt_NotDrifting(t *testing.T) {
|
||||
windowStart := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
wantResetsAt := windowStart.Add(30 * 24 * time.Hour).Format(time.RFC3339)
|
||||
|
||||
r := service.UserPlatformQuotaRecord{
|
||||
Platform: "openai",
|
||||
MonthlyUsageUSD: 5.0,
|
||||
MonthlyWindowStart: &windowStart,
|
||||
}
|
||||
|
||||
// 第一次调用:now = windowStart + 5d
|
||||
now1 := windowStart.Add(5 * 24 * time.Hour)
|
||||
out1 := LazyZeroQuotaForResponse(r, now1, false)
|
||||
resetsAt1, ok1 := out1["monthly_window_resets_at"]
|
||||
if !ok1 || resetsAt1 == nil {
|
||||
t.Fatal("first call: monthly_window_resets_at should be set for active window")
|
||||
}
|
||||
s1, ok := resetsAt1.(*string)
|
||||
if !ok || s1 == nil {
|
||||
t.Fatalf("first call: monthly_window_resets_at should be *string, got %T", resetsAt1)
|
||||
}
|
||||
if *s1 != wantResetsAt {
|
||||
t.Errorf("first call: want %s, got %s", wantResetsAt, *s1)
|
||||
}
|
||||
|
||||
// 第二次调用:now = windowStart + 10d(不同 now,但 resetsAt 应不变)
|
||||
now2 := windowStart.Add(10 * 24 * time.Hour)
|
||||
out2 := LazyZeroQuotaForResponse(r, now2, false)
|
||||
resetsAt2, ok2 := out2["monthly_window_resets_at"]
|
||||
if !ok2 || resetsAt2 == nil {
|
||||
t.Fatal("second call: monthly_window_resets_at should be set for active window")
|
||||
}
|
||||
s2, ok := resetsAt2.(*string)
|
||||
if !ok || s2 == nil {
|
||||
t.Fatalf("second call: monthly_window_resets_at should be *string, got %T", resetsAt2)
|
||||
}
|
||||
if *s2 != wantResetsAt {
|
||||
t.Errorf("second call: want %s, got %s", wantResetsAt, *s2)
|
||||
}
|
||||
|
||||
// 两次结果必须相等
|
||||
if *s1 != *s2 {
|
||||
t.Errorf("resetsAt drifted between calls: %s vs %s", *s1, *s2)
|
||||
}
|
||||
}
|
||||
|
||||
// TestNeedsDailyReset_FollowsServerTimezone 验证日窗口过期判断按全局时区(北京 0 点)而非 UTC。
|
||||
func TestNeedsDailyReset_FollowsServerTimezone(t *testing.T) {
|
||||
if err := timezone.Init("Asia/Shanghai"); err != nil {
|
||||
t.Fatalf("Init: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = timezone.Init("UTC") })
|
||||
|
||||
// now = 2026-05-25 23:00 UTC = 2026-05-26 07:00 +08(北京 5/26)
|
||||
now := time.Date(2026, 5, 25, 23, 0, 0, 0, time.UTC)
|
||||
|
||||
// start = 2026-05-25 10:00 UTC = 2026-05-25 18:00 +08(北京 5/25)→ 应判定为过期
|
||||
startPrevBeijingDay := time.Date(2026, 5, 25, 10, 0, 0, 0, time.UTC)
|
||||
if !NeedsDailyReset(&startPrevBeijingDay, now) {
|
||||
t.Error("上一个北京日的窗口应判定为过期")
|
||||
}
|
||||
|
||||
// start = 2026-05-25 20:00 UTC = 2026-05-26 04:00 +08(北京 5/26 同日)→ 不应过期
|
||||
startSameBeijingDay := time.Date(2026, 5, 25, 20, 0, 0, 0, time.UTC)
|
||||
if NeedsDailyReset(&startSameBeijingDay, now) {
|
||||
t.Error("同一北京日的窗口不应判定为过期")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNextDailyResetTime_FollowsServerTimezone 验证下次日重置 = 次日北京 0 点。
|
||||
func TestNextDailyResetTime_FollowsServerTimezone(t *testing.T) {
|
||||
if err := timezone.Init("Asia/Shanghai"); err != nil {
|
||||
t.Fatalf("Init: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = timezone.Init("UTC") })
|
||||
|
||||
now := time.Date(2026, 5, 25, 23, 0, 0, 0, time.UTC) // 北京 5/26 07:00
|
||||
want := time.Date(2026, 5, 27, 0, 0, 0, 0, timezone.Location()) // 北京 5/27 00:00
|
||||
if got := nextDailyResetTime(now); !got.Equal(want) {
|
||||
t.Errorf("nextDailyResetTime = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// TestNextWeeklyResetTime_FollowsServerTimezone 验证下次周重置 = 下周一北京 0 点。
|
||||
func TestNextWeeklyResetTime_FollowsServerTimezone(t *testing.T) {
|
||||
if err := timezone.Init("Asia/Shanghai"); err != nil {
|
||||
t.Fatalf("Init: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = timezone.Init("UTC") })
|
||||
|
||||
// 北京 2026-05-26(周二)→ 下周一是 2026-06-01
|
||||
now := time.Date(2026, 5, 25, 23, 0, 0, 0, time.UTC) // 北京 5/26 07:00 周二
|
||||
want := time.Date(2026, 6, 1, 0, 0, 0, 0, timezone.Location())
|
||||
if got := nextWeeklyResetTime(now); !got.Equal(want) {
|
||||
t.Errorf("nextWeeklyResetTime = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
@ -3,8 +3,10 @@ package handler
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/quotaview"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
@ -14,11 +16,12 @@ import (
|
||||
|
||||
// UserHandler handles user-related requests
|
||||
type UserHandler struct {
|
||||
userService *service.UserService
|
||||
authService *service.AuthService
|
||||
emailService *service.EmailService
|
||||
emailCache service.EmailCache
|
||||
affiliateService *service.AffiliateService
|
||||
userService *service.UserService
|
||||
authService *service.AuthService
|
||||
emailService *service.EmailService
|
||||
emailCache service.EmailCache
|
||||
affiliateService *service.AffiliateService
|
||||
userPlatformQuotaRepo service.UserPlatformQuotaRepository
|
||||
}
|
||||
|
||||
// NewUserHandler creates a new UserHandler
|
||||
@ -28,16 +31,44 @@ func NewUserHandler(
|
||||
emailService *service.EmailService,
|
||||
emailCache service.EmailCache,
|
||||
affiliateService *service.AffiliateService,
|
||||
userPlatformQuotaRepo service.UserPlatformQuotaRepository,
|
||||
) *UserHandler {
|
||||
return &UserHandler{
|
||||
userService: userService,
|
||||
authService: authService,
|
||||
emailService: emailService,
|
||||
emailCache: emailCache,
|
||||
affiliateService: affiliateService,
|
||||
userService: userService,
|
||||
authService: authService,
|
||||
emailService: emailService,
|
||||
emailCache: emailCache,
|
||||
affiliateService: affiliateService,
|
||||
userPlatformQuotaRepo: userPlatformQuotaRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// GetMyPlatformQuotas GET /user/platform-quotas
|
||||
// 返回当前 JWT 用户的 platform quota 状态。
|
||||
// D14: 对每条记录逐档判断窗口过期,过期档位 usage=0、window_resets_at=null(不写 DB)
|
||||
func (h *UserHandler) GetMyPlatformQuotas(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
if h.userPlatformQuotaRepo == nil {
|
||||
response.Success(c, map[string]any{"platform_quotas": []any{}})
|
||||
return
|
||||
}
|
||||
records, err := h.userPlatformQuotaRepo.ListByUser(c.Request.Context(), subject.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
out := make([]map[string]any, 0, len(records))
|
||||
for _, r := range records {
|
||||
out = append(out, quotaview.LazyZeroQuotaForResponse(r, now, false))
|
||||
}
|
||||
response.Success(c, map[string]any{"platform_quotas": out})
|
||||
}
|
||||
|
||||
// ChangePasswordRequest represents the change password request payload
|
||||
type ChangePasswordRequest struct {
|
||||
OldPassword string `json:"old_password" binding:"required"`
|
||||
|
||||
@ -87,8 +87,12 @@ func (s *userHandlerRepoStub) ListWithFilters(context.Context, pagination.Pagina
|
||||
func (s *userHandlerRepoStub) UpdateBalance(context.Context, int64, float64) error { return nil }
|
||||
func (s *userHandlerRepoStub) DeductBalance(context.Context, int64, float64) error { return nil }
|
||||
func (s *userHandlerRepoStub) UpdateConcurrency(context.Context, int64, int) error { return nil }
|
||||
func (s *userHandlerRepoStub) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
|
||||
func (s *userHandlerRepoStub) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
|
||||
func (s *userHandlerRepoStub) BatchSetConcurrency(context.Context, []int64, int) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (s *userHandlerRepoStub) BatchAddConcurrency(context.Context, []int64, int) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (s *userHandlerRepoStub) ExistsByEmail(context.Context, string) (bool, error) { return false, nil }
|
||||
func (s *userHandlerRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
|
||||
return 0, nil
|
||||
@ -144,7 +148,7 @@ func TestUserHandlerUpdateProfileReturnsAvatarURL(t *testing.T) {
|
||||
Status: service.StatusActive,
|
||||
},
|
||||
}
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil, nil)
|
||||
|
||||
body := []byte(`{"avatar_url":"https://cdn.example.com/avatar.png"}`)
|
||||
recorder := httptest.NewRecorder()
|
||||
@ -202,7 +206,7 @@ func TestUserHandlerGetProfileReturnsIdentitySummaries(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil, nil)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
@ -272,20 +276,20 @@ func TestUserHandlerGetProfileReturnsLegacyCompatibilityFields(t *testing.T) {
|
||||
AvatarURL: "https://cdn.example.com/linuxdo.png",
|
||||
AvatarSource: "remote_url",
|
||||
},
|
||||
identities: []service.UserAuthIdentityRecord{
|
||||
{
|
||||
ProviderType: "linuxdo",
|
||||
ProviderKey: "linuxdo",
|
||||
ProviderSubject: "linuxdo-subject-21",
|
||||
VerifiedAt: &verifiedAt,
|
||||
Metadata: map[string]any{
|
||||
"username": "linuxdo-handle",
|
||||
"avatar_url": "https://cdn.example.com/linuxdo.png",
|
||||
},
|
||||
identities: []service.UserAuthIdentityRecord{
|
||||
{
|
||||
ProviderType: "linuxdo",
|
||||
ProviderKey: "linuxdo",
|
||||
ProviderSubject: "linuxdo-subject-21",
|
||||
VerifiedAt: &verifiedAt,
|
||||
Metadata: map[string]any{
|
||||
"username": "linuxdo-handle",
|
||||
"avatar_url": "https://cdn.example.com/linuxdo.png",
|
||||
},
|
||||
},
|
||||
}
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
|
||||
},
|
||||
}
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil, nil)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
@ -364,7 +368,7 @@ func TestUserHandlerGetProfileDoesNotInferEditedProfileSourcesWithoutMatchingIde
|
||||
},
|
||||
},
|
||||
}
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil, nil)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
@ -513,8 +517,8 @@ func TestUserHandlerBindEmailIdentityReturnsProfileResponse(t *testing.T) {
|
||||
},
|
||||
}
|
||||
emailService := service.NewEmailService(nil, emailCache)
|
||||
authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil)
|
||||
authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil, nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil, nil)
|
||||
|
||||
body := []byte(`{"email":"new@example.com","verify_code":"123456","password":"new-password"}`)
|
||||
recorder := httptest.NewRecorder()
|
||||
@ -568,7 +572,7 @@ func TestUserHandlerUnbindIdentityReturnsUpdatedProfile(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil, nil)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
@ -627,8 +631,8 @@ func TestUserHandlerUnbindIdentityRevokesAllUserSessionsWhenAuthServiceConfigure
|
||||
ExpireHour: 1,
|
||||
},
|
||||
}
|
||||
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil)
|
||||
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil, nil)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
@ -670,8 +674,8 @@ func TestUserHandlerUnbindIdentityDoesNotRevokeSessionsWhenNothingWasUnbound(t *
|
||||
ExpireHour: 1,
|
||||
},
|
||||
}
|
||||
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil)
|
||||
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil, nil)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
@ -714,8 +718,8 @@ func TestUserHandlerBindEmailIdentityRejectsWrongCurrentPasswordForBoundEmail(t
|
||||
},
|
||||
}
|
||||
emailService := service.NewEmailService(nil, emailCache)
|
||||
authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil)
|
||||
authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil, nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil, nil)
|
||||
|
||||
body := []byte(`{"email":"new@example.com","verify_code":"123456","password":"wrong-password"}`)
|
||||
recorder := httptest.NewRecorder()
|
||||
@ -752,7 +756,7 @@ func TestUserHandlerStartIdentityBindingReturnsAuthorizeURL(t *testing.T) {
|
||||
Status: service.StatusActive,
|
||||
},
|
||||
}
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil, nil)
|
||||
|
||||
body := []byte(`{"provider":"wechat","redirect_to":"/settings/profile"}`)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
212
backend/internal/handler/user_platform_quotas_handler_test.go
Normal file
212
backend/internal/handler/user_platform_quotas_handler_test.go
Normal file
@ -0,0 +1,212 @@
|
||||
//go:build unit
|
||||
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/quotaview"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// fakeQuotaRepoForUserHandler 实现 service.UserPlatformQuotaRepository 最小子集
|
||||
type fakeQuotaRepoForUserHandler struct {
|
||||
service.UserPlatformQuotaRepository
|
||||
records []service.UserPlatformQuotaRecord
|
||||
}
|
||||
|
||||
func (f *fakeQuotaRepoForUserHandler) ListByUser(_ context.Context, _ int64) ([]service.UserPlatformQuotaRecord, error) {
|
||||
return f.records, nil
|
||||
}
|
||||
|
||||
func TestGetMyPlatformQuotas_EmptyReturns200WithEmptyArray(t *testing.T) {
|
||||
repo := &fakeQuotaRepoForUserHandler{records: nil}
|
||||
h := &UserHandler{userPlatformQuotaRepo: repo}
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/user/platform-quotas", nil)
|
||||
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 42})
|
||||
h.GetMyPlatformQuotas(c)
|
||||
if w.Code != 200 {
|
||||
t.Fatalf("expected 200, got %d. body: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var body struct {
|
||||
Code int `json:"code"`
|
||||
Data struct {
|
||||
PlatformQuotas []any `json:"platform_quotas"`
|
||||
} `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("unmarshal error: %v, body: %s", err, w.Body.String())
|
||||
}
|
||||
if body.Code != 0 {
|
||||
t.Errorf("expected code=0, got %d", body.Code)
|
||||
}
|
||||
if body.Data.PlatformQuotas == nil {
|
||||
// nil 和 empty slice 均视为可接受(JSON 可能序列化为 null 或 [])
|
||||
// 此断言只验证 HTTP 200 + code=0 即可
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetMyPlatformQuotas_D14_LazyZeroForExpiredWindow(t *testing.T) {
|
||||
pastStart := time.Now().UTC().AddDate(0, 0, -2)
|
||||
daily := 5.0
|
||||
repo := &fakeQuotaRepoForUserHandler{records: []service.UserPlatformQuotaRecord{{
|
||||
UserID: 42,
|
||||
Platform: "anthropic",
|
||||
DailyLimitUSD: &daily,
|
||||
DailyUsageUSD: 3.0,
|
||||
DailyWindowStart: &pastStart,
|
||||
}}}
|
||||
h := &UserHandler{userPlatformQuotaRepo: repo}
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/user/platform-quotas", nil)
|
||||
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 42})
|
||||
h.GetMyPlatformQuotas(c)
|
||||
|
||||
if w.Code != 200 {
|
||||
t.Fatalf("expected 200, got %d. body: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
// 解析 response,验证过期 daily 的 usage_usd=0 且 window_resets_at=null
|
||||
body := w.Body.String()
|
||||
if !strings.Contains(body, `"daily_usage_usd":0`) {
|
||||
t.Errorf("expected daily_usage_usd:0 in body, got: %s", body)
|
||||
}
|
||||
if !strings.Contains(body, `"daily_window_resets_at":null`) {
|
||||
t.Errorf("expected daily_window_resets_at:null in body, got: %s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetMyPlatformQuotas_NilRepo_Returns200Empty(t *testing.T) {
|
||||
h := &UserHandler{userPlatformQuotaRepo: nil}
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/user/platform-quotas", nil)
|
||||
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 99})
|
||||
h.GetMyPlatformQuotas(c)
|
||||
if w.Code != 200 {
|
||||
t.Fatalf("expected 200, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetMyPlatformQuotas_NoAuth_Returns401(t *testing.T) {
|
||||
h := &UserHandler{userPlatformQuotaRepo: nil}
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/user/platform-quotas", nil)
|
||||
// 不设置 auth subject
|
||||
h.GetMyPlatformQuotas(c)
|
||||
if w.Code != 401 {
|
||||
t.Fatalf("expected 401, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLazyZeroQuotaForResponse_UserViewStripsWindowStart(t *testing.T) {
|
||||
start := time.Now().UTC().Add(-1 * time.Hour)
|
||||
r := service.UserPlatformQuotaRecord{
|
||||
Platform: "anthropic",
|
||||
DailyUsageUSD: 1.0,
|
||||
DailyWindowStart: &start,
|
||||
}
|
||||
out := quotaview.LazyZeroQuotaForResponse(r, time.Now().UTC(), false)
|
||||
if _, ok := out["daily_window_start"]; ok {
|
||||
t.Error("user view should not include daily_window_start")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLazyZeroQuotaForResponse_AdminViewIncludesWindowStart(t *testing.T) {
|
||||
start := time.Now().UTC().Add(-1 * time.Hour)
|
||||
r := service.UserPlatformQuotaRecord{
|
||||
Platform: "anthropic",
|
||||
DailyWindowStart: &start,
|
||||
}
|
||||
out := quotaview.LazyZeroQuotaForResponse(r, time.Now().UTC(), true)
|
||||
if _, ok := out["daily_window_start"]; !ok {
|
||||
t.Error("admin view should include daily_window_start")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLazyZeroQuotaForResponse_ActiveWindowPreservesUsage(t *testing.T) {
|
||||
// 今天的窗口起始时间(不过期):按全局时区取当天 0 点,与 view 层同口径
|
||||
now := time.Now()
|
||||
today := timezone.StartOfDay(now)
|
||||
usage := 2.5
|
||||
r := service.UserPlatformQuotaRecord{
|
||||
Platform: "openai",
|
||||
DailyUsageUSD: usage,
|
||||
DailyWindowStart: &today,
|
||||
}
|
||||
out := quotaview.LazyZeroQuotaForResponse(r, now, false)
|
||||
if out["daily_usage_usd"] != usage {
|
||||
t.Errorf("expected daily_usage_usd=%v, got %v", usage, out["daily_usage_usd"])
|
||||
}
|
||||
// 活跃窗口应有 resets_at(非 nil)
|
||||
if out["daily_window_resets_at"] == nil {
|
||||
t.Error("active window should have daily_window_resets_at set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNeedsDailyReset_NilStart_ReturnsFalse(t *testing.T) {
|
||||
if quotaview.NeedsDailyReset(nil, time.Now().UTC()) {
|
||||
t.Error("nil start should not need reset")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNeedsDailyReset_OldStart_ReturnsTrue(t *testing.T) {
|
||||
old := time.Now().UTC().AddDate(0, 0, -1)
|
||||
if !quotaview.NeedsDailyReset(&old, time.Now().UTC()) {
|
||||
t.Error("yesterday start should need daily reset")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNeedsWeeklyReset_NilStart_ReturnsFalse(t *testing.T) {
|
||||
if quotaview.NeedsWeeklyReset(nil, time.Now().UTC()) {
|
||||
t.Error("nil start should not need weekly reset")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNeedsMonthlyReset_NilStart_ReturnsFalse(t *testing.T) {
|
||||
if quotaview.NeedsMonthlyReset(nil, time.Now().UTC()) {
|
||||
t.Error("nil start should not need monthly reset")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNeedsMonthlyReset_30DayRolling 验证 30 天滚动语义(C-NEW-1)。
|
||||
func TestNeedsMonthlyReset_30DayRolling_Expired(t *testing.T) {
|
||||
start := time.Now().UTC().Add(-31 * 24 * time.Hour) // 31 天前,已过期
|
||||
if !quotaview.NeedsMonthlyReset(&start, time.Now().UTC()) {
|
||||
t.Error("31 days ago should need monthly reset (30-day rolling)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNeedsMonthlyReset_30DayRolling_Active(t *testing.T) {
|
||||
start := time.Now().UTC().Add(-15 * 24 * time.Hour) // 15 天前,窗口有效
|
||||
if quotaview.NeedsMonthlyReset(&start, time.Now().UTC()) {
|
||||
t.Error("15 days ago should NOT need monthly reset (30-day rolling, still active)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNeedsMonthlyReset_CrossMonthBoundary 验证跨自然月时 30 天未满不重置(旧自然月语义会提前重置)。
|
||||
func TestNeedsMonthlyReset_CrossMonthBoundary(t *testing.T) {
|
||||
// 窗口起始 4 月 20 日;5 月 1 日仅过了 11 天,不足 30 天,不应重置
|
||||
windowStart := time.Date(2026, 4, 20, 0, 0, 0, 0, time.UTC)
|
||||
now := time.Date(2026, 5, 1, 0, 0, 0, 0, time.UTC)
|
||||
if quotaview.NeedsMonthlyReset(&windowStart, now) {
|
||||
t.Error("cross-month boundary within 30 days should NOT trigger reset (30-day rolling)")
|
||||
}
|
||||
}
|
||||
@ -135,3 +135,29 @@ func TestDSTAwareness(t *testing.T) {
|
||||
_ = Now()
|
||||
_ = StartOfDay(Now())
|
||||
}
|
||||
|
||||
func TestStartOfWeek_Boundaries(t *testing.T) {
|
||||
if err := Init("Asia/Shanghai"); err != nil {
|
||||
t.Fatalf("Init: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = Init("UTC") })
|
||||
|
||||
loc := Location()
|
||||
wantMon := time.Date(2026, 5, 18, 0, 0, 0, 0, loc) // 2026-05-18 是周一
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
in time.Time
|
||||
}{
|
||||
{"friday", time.Date(2026, 5, 22, 14, 30, 0, 0, loc)},
|
||||
{"sunday", time.Date(2026, 5, 24, 10, 0, 0, 0, loc)},
|
||||
{"monday-self", time.Date(2026, 5, 18, 9, 15, 30, 0, loc)},
|
||||
}
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
if got := StartOfWeek(c.in); !got.Equal(wantMon) {
|
||||
t.Errorf("StartOfWeek(%v) = %v, want %v", c.in, got, wantMon)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -328,3 +328,174 @@ func (c *billingCache) InvalidateAPIKeyRateLimit(ctx context.Context, keyID int6
|
||||
key := billingRateLimitKey(keyID)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// user × platform quota 缓存
|
||||
// ============================================
|
||||
|
||||
// userPlatformQuotaCacheKey 构造 Redis key
|
||||
func userPlatformQuotaCacheKey(userID int64, platform string) string {
|
||||
return fmt.Sprintf("billing:user_platform_quota:%d:%s", userID, platform)
|
||||
}
|
||||
|
||||
func (c *billingCache) GetUserPlatformQuotaCache(ctx context.Context, userID int64, platform string) (*service.UserPlatformQuotaCacheEntry, bool, error) {
|
||||
key := userPlatformQuotaCacheKey(userID, platform)
|
||||
fields := []string{
|
||||
"daily_usage", "weekly_usage", "monthly_usage", "version", "schema_version",
|
||||
"daily_limit", "weekly_limit", "monthly_limit",
|
||||
"daily_window_start", "weekly_window_start", "monthly_window_start",
|
||||
}
|
||||
vals, err := c.rdb.HMGet(ctx, key, fields...).Result()
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
// 前4个全为nil → key 不存在
|
||||
if vals[0] == nil && vals[1] == nil && vals[2] == nil && vals[3] == nil {
|
||||
return nil, false, nil
|
||||
}
|
||||
parseFloat := func(v any) float64 {
|
||||
if v == nil {
|
||||
return 0
|
||||
}
|
||||
s, ok := v.(string)
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
f, _ := strconv.ParseFloat(s, 64)
|
||||
return f
|
||||
}
|
||||
parseFloatPtr := func(v any) *float64 {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
s, ok := v.(string)
|
||||
if !ok || s == "" {
|
||||
return nil
|
||||
}
|
||||
f, err := strconv.ParseFloat(s, 64)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return &f
|
||||
}
|
||||
parseTimePtr := func(v any) *time.Time {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
s, ok := v.(string)
|
||||
if !ok || s == "" {
|
||||
return nil
|
||||
}
|
||||
n, err := strconv.ParseInt(s, 10, 64)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
t := time.Unix(n, 0).UTC()
|
||||
return &t
|
||||
}
|
||||
parseInt64 := func(v any) int64 {
|
||||
if v == nil {
|
||||
return 0
|
||||
}
|
||||
s, ok := v.(string)
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
n, _ := strconv.ParseInt(s, 10, 64)
|
||||
return n
|
||||
}
|
||||
return &service.UserPlatformQuotaCacheEntry{
|
||||
DailyUsageUSD: parseFloat(vals[0]),
|
||||
WeeklyUsageUSD: parseFloat(vals[1]),
|
||||
MonthlyUsageUSD: parseFloat(vals[2]),
|
||||
Version: parseInt64(vals[3]),
|
||||
SchemaVersion: parseInt64(vals[4]),
|
||||
DailyLimitUSD: parseFloatPtr(vals[5]),
|
||||
WeeklyLimitUSD: parseFloatPtr(vals[6]),
|
||||
MonthlyLimitUSD: parseFloatPtr(vals[7]),
|
||||
DailyWindowStart: parseTimePtr(vals[8]),
|
||||
WeeklyWindowStart: parseTimePtr(vals[9]),
|
||||
MonthlyWindowStart: parseTimePtr(vals[10]),
|
||||
}, true, nil
|
||||
}
|
||||
|
||||
func (c *billingCache) SetUserPlatformQuotaCache(ctx context.Context, userID int64, platform string, entry *service.UserPlatformQuotaCacheEntry, ttl time.Duration) error {
|
||||
if entry == nil {
|
||||
return nil
|
||||
}
|
||||
key := userPlatformQuotaCacheKey(userID, platform)
|
||||
pipe := c.rdb.TxPipeline()
|
||||
|
||||
// 浮点可空字段:nil → 空字符串(读取时 parseFloatPtr 返回 nil,表示无限额)
|
||||
fmtFloatPtr := func(p *float64) string {
|
||||
if p == nil {
|
||||
return ""
|
||||
}
|
||||
return strconv.FormatFloat(*p, 'f', -1, 64)
|
||||
}
|
||||
// time.Time 可空字段:nil → 空字符串;有值 → unix 秒
|
||||
fmtTimePtr := func(p *time.Time) string {
|
||||
if p == nil {
|
||||
return ""
|
||||
}
|
||||
return strconv.FormatInt(p.Unix(), 10)
|
||||
}
|
||||
|
||||
pipe.HSet(ctx, key,
|
||||
"daily_usage", entry.DailyUsageUSD,
|
||||
"weekly_usage", entry.WeeklyUsageUSD,
|
||||
"monthly_usage", entry.MonthlyUsageUSD,
|
||||
"version", entry.Version,
|
||||
"schema_version", entry.SchemaVersion,
|
||||
"daily_limit", fmtFloatPtr(entry.DailyLimitUSD),
|
||||
"weekly_limit", fmtFloatPtr(entry.WeeklyLimitUSD),
|
||||
"monthly_limit", fmtFloatPtr(entry.MonthlyLimitUSD),
|
||||
"daily_window_start", fmtTimePtr(entry.DailyWindowStart),
|
||||
"weekly_window_start", fmtTimePtr(entry.WeeklyWindowStart),
|
||||
"monthly_window_start", fmtTimePtr(entry.MonthlyWindowStart),
|
||||
)
|
||||
pipe.Expire(ctx, key, ttl)
|
||||
_, err := pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *billingCache) DeleteUserPlatformQuotaCache(ctx context.Context, userID int64, platform string) error {
|
||||
return c.rdb.Del(ctx, userPlatformQuotaCacheKey(userID, platform)).Err()
|
||||
}
|
||||
|
||||
// updateUserPlatformQuotaUsageScript 缓存累加:EXISTS + schema_version 双重守卫。
|
||||
// 旧版 entry(schema_version != ARGV[3],包括缺字段的 0 值)不参与累加,由上层走 DB fallback 后
|
||||
// SetCache 重建为新版 entry —— 若此处仍累加,上层覆盖时会丢失这部分增量,导致 Redis usage 比真实偏小。
|
||||
// key 不存在同样跳过(由下次 SetCache 重建)。
|
||||
// KEYS[1] = hash key
|
||||
// ARGV[1] = cost (string float)
|
||||
// ARGV[2] = ttl seconds
|
||||
// ARGV[3] = expected schema_version (Go 侧 UserPlatformQuotaCacheSchemaV1)
|
||||
const updateUserPlatformQuotaUsageScript = `
|
||||
if redis.call("EXISTS", KEYS[1]) == 0 then
|
||||
return 0
|
||||
end
|
||||
local ver = redis.call("HGET", KEYS[1], "schema_version")
|
||||
if ver == false or tonumber(ver) ~= tonumber(ARGV[3]) then
|
||||
return 0
|
||||
end
|
||||
redis.call("HINCRBYFLOAT", KEYS[1], "daily_usage", ARGV[1])
|
||||
redis.call("HINCRBYFLOAT", KEYS[1], "weekly_usage", ARGV[1])
|
||||
redis.call("HINCRBYFLOAT", KEYS[1], "monthly_usage", ARGV[1])
|
||||
redis.call("HINCRBY", KEYS[1], "version", 1)
|
||||
redis.call("EXPIRE", KEYS[1], ARGV[2])
|
||||
return 1
|
||||
`
|
||||
|
||||
func (c *billingCache) IncrUserPlatformQuotaUsageCache(ctx context.Context, userID int64, platform string, cost float64, ttl time.Duration) error {
|
||||
key := userPlatformQuotaCacheKey(userID, platform)
|
||||
_, err := c.rdb.Eval(ctx, updateUserPlatformQuotaUsageScript, []string{key},
|
||||
strconv.FormatFloat(cost, 'f', -1, 64),
|
||||
int(ttl.Seconds()),
|
||||
service.UserPlatformQuotaCacheSchemaV1,
|
||||
).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -0,0 +1,134 @@
|
||||
//go:build unit
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
func newMiniRedisCache(t *testing.T) (*billingCache, *miniredis.Miniredis) {
|
||||
t.Helper()
|
||||
mr := miniredis.RunT(t)
|
||||
rdb := redis.NewClient(&redis.Options{Addr: mr.Addr()})
|
||||
return &billingCache{rdb: rdb}, mr
|
||||
}
|
||||
|
||||
func TestUserPlatformQuotaCache_GetMissReturnsNotFound(t *testing.T) {
|
||||
c, _ := newMiniRedisCache(t)
|
||||
entry, ok, err := c.GetUserPlatformQuotaCache(context.Background(), 1, "anthropic")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if ok || entry != nil {
|
||||
t.Errorf("expected miss, got ok=%v entry=%v", ok, entry)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserPlatformQuotaCache_SetThenGet(t *testing.T) {
|
||||
c, _ := newMiniRedisCache(t)
|
||||
ctx := context.Background()
|
||||
dailyLimit := 20.0
|
||||
ts := time.Date(2024, 5, 1, 0, 0, 0, 0, time.UTC)
|
||||
in := &service.UserPlatformQuotaCacheEntry{
|
||||
DailyUsageUSD: 1.5,
|
||||
WeeklyUsageUSD: 3.0,
|
||||
MonthlyUsageUSD: 10.0,
|
||||
Version: 7,
|
||||
SchemaVersion: service.UserPlatformQuotaCacheSchemaV1,
|
||||
DailyLimitUSD: &dailyLimit,
|
||||
DailyWindowStart: &ts,
|
||||
}
|
||||
if err := c.SetUserPlatformQuotaCache(ctx, 1, "openai", in, time.Minute); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
got, ok, err := c.GetUserPlatformQuotaCache(ctx, 1, "openai")
|
||||
if err != nil || !ok {
|
||||
t.Fatalf("get: ok=%v err=%v", ok, err)
|
||||
}
|
||||
if got.DailyUsageUSD != 1.5 || got.WeeklyUsageUSD != 3.0 || got.MonthlyUsageUSD != 10.0 || got.Version != 7 {
|
||||
t.Errorf("got = %+v, want %+v", got, in)
|
||||
}
|
||||
if got.SchemaVersion != service.UserPlatformQuotaCacheSchemaV1 {
|
||||
t.Errorf("SchemaVersion = %d, want %d", got.SchemaVersion, service.UserPlatformQuotaCacheSchemaV1)
|
||||
}
|
||||
if got.DailyLimitUSD == nil || *got.DailyLimitUSD != dailyLimit {
|
||||
t.Errorf("DailyLimitUSD = %v, want %v", got.DailyLimitUSD, dailyLimit)
|
||||
}
|
||||
if got.DailyWindowStart == nil || !got.DailyWindowStart.Equal(ts) {
|
||||
t.Errorf("DailyWindowStart = %v, want %v", got.DailyWindowStart, ts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserPlatformQuotaCache_NilLimitSetThenGet(t *testing.T) {
|
||||
c, _ := newMiniRedisCache(t)
|
||||
ctx := context.Background()
|
||||
in := &service.UserPlatformQuotaCacheEntry{
|
||||
DailyUsageUSD: 1.0,
|
||||
SchemaVersion: service.UserPlatformQuotaCacheSchemaV1,
|
||||
// DailyLimitUSD nil → 无限额
|
||||
}
|
||||
if err := c.SetUserPlatformQuotaCache(ctx, 1, "openai", in, time.Minute); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
got, ok, err := c.GetUserPlatformQuotaCache(ctx, 1, "openai")
|
||||
if err != nil || !ok {
|
||||
t.Fatalf("get: ok=%v err=%v", ok, err)
|
||||
}
|
||||
if got.DailyLimitUSD != nil {
|
||||
t.Errorf("DailyLimitUSD should be nil for unlimited, got %v", got.DailyLimitUSD)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserPlatformQuotaCache_IncrMissIsNoop(t *testing.T) {
|
||||
c, _ := newMiniRedisCache(t)
|
||||
if err := c.IncrUserPlatformQuotaUsageCache(context.Background(), 1, "openai", 0.5, time.Minute); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, ok, _ := c.GetUserPlatformQuotaCache(context.Background(), 1, "openai")
|
||||
if ok {
|
||||
t.Error("expected key absent after no-op incr")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserPlatformQuotaCache_IncrHitAccumulates(t *testing.T) {
|
||||
c, _ := newMiniRedisCache(t)
|
||||
ctx := context.Background()
|
||||
// SchemaVersion 必须显式设为 V1,否则 Lua 脚本会因 schema 不匹配而 return 0,跳过累加。
|
||||
_ = c.SetUserPlatformQuotaCache(ctx, 1, "openai", &service.UserPlatformQuotaCacheEntry{
|
||||
Version: 1,
|
||||
SchemaVersion: service.UserPlatformQuotaCacheSchemaV1,
|
||||
}, time.Minute)
|
||||
if err := c.IncrUserPlatformQuotaUsageCache(ctx, 1, "openai", 0.5, time.Minute); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := c.IncrUserPlatformQuotaUsageCache(ctx, 1, "openai", 0.25, time.Minute); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
got, _, _ := c.GetUserPlatformQuotaCache(ctx, 1, "openai")
|
||||
if got.DailyUsageUSD != 0.75 || got.WeeklyUsageUSD != 0.75 || got.MonthlyUsageUSD != 0.75 {
|
||||
t.Errorf("got %+v, want daily/weekly/monthly=0.75", got)
|
||||
}
|
||||
if got.Version != 3 {
|
||||
t.Errorf("version = %d, want 3 (initial 1 + 2 incr)", got.Version)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserPlatformQuotaCache_Delete(t *testing.T) {
|
||||
c, _ := newMiniRedisCache(t)
|
||||
ctx := context.Background()
|
||||
_ = c.SetUserPlatformQuotaCache(ctx, 1, "openai", &service.UserPlatformQuotaCacheEntry{Version: 1}, time.Minute)
|
||||
if err := c.DeleteUserPlatformQuotaCache(ctx, 1, "openai"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, ok, _ := c.GetUserPlatformQuotaCache(ctx, 1, "openai")
|
||||
if ok {
|
||||
t.Error("expected miss after delete")
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,91 @@
|
||||
//go:build unit
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type fakeRepoForAdapter struct {
|
||||
upsertCalledWith []UserPlatformQuotaRecord
|
||||
upsertCalledUserID int64
|
||||
upsertErr error
|
||||
resetCalledWith [4]any // userID, platform, window, newStart
|
||||
resetErr error
|
||||
}
|
||||
|
||||
func (f *fakeRepoForAdapter) BulkInsertInitial(_ context.Context, _ []UserPlatformQuotaRecord) error {
|
||||
return nil
|
||||
}
|
||||
func (f *fakeRepoForAdapter) GetByUserPlatform(_ context.Context, _ int64, _ string) (*UserPlatformQuotaRecord, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (f *fakeRepoForAdapter) ListByUser(_ context.Context, _ int64) ([]UserPlatformQuotaRecord, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (f *fakeRepoForAdapter) IncrementUsageWithReset(_ context.Context, _ int64, _ string, _ float64, _ time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (f *fakeRepoForAdapter) ResetExpiredWindow(_ context.Context, userID int64, platform string, window string, newStart time.Time) error {
|
||||
f.resetCalledWith = [4]any{userID, platform, window, newStart}
|
||||
return f.resetErr
|
||||
}
|
||||
func (f *fakeRepoForAdapter) UpsertForUser(_ context.Context, userID int64, records []UserPlatformQuotaRecord) error {
|
||||
f.upsertCalledUserID = userID
|
||||
f.upsertCalledWith = records
|
||||
return f.upsertErr
|
||||
}
|
||||
|
||||
func TestGenericAdapter_UpsertForUser_ForwardsRecords(t *testing.T) {
|
||||
fake := &fakeRepoForAdapter{}
|
||||
adapter := NewUserPlatformQuotaServiceAdapter(fake)
|
||||
|
||||
err := adapter.UpsertForUser(context.Background(), 42, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected err: %v", err)
|
||||
}
|
||||
if fake.upsertCalledUserID != 42 {
|
||||
t.Errorf("forwarded userID = %d, want 42", fake.upsertCalledUserID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenericAdapter_UpsertForUser_PropagatesError(t *testing.T) {
|
||||
wantErr := errors.New("boom")
|
||||
fake := &fakeRepoForAdapter{upsertErr: wantErr}
|
||||
adapter := NewUserPlatformQuotaServiceAdapter(fake)
|
||||
|
||||
err := adapter.UpsertForUser(context.Background(), 1, nil)
|
||||
if !errors.Is(err, wantErr) {
|
||||
t.Errorf("expected %v, got %v", wantErr, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenericAdapter_ResetExpiredWindow_ForwardsAllParams(t *testing.T) {
|
||||
fake := &fakeRepoForAdapter{}
|
||||
adapter := NewUserPlatformQuotaServiceAdapter(fake)
|
||||
|
||||
now := time.Date(2026, 5, 23, 10, 0, 0, 0, time.UTC)
|
||||
if err := adapter.ResetExpiredWindow(context.Background(), 7, "openai", "weekly", now); err != nil {
|
||||
t.Fatalf("unexpected err: %v", err)
|
||||
}
|
||||
if fake.resetCalledWith[0].(int64) != 7 ||
|
||||
fake.resetCalledWith[1].(string) != "openai" ||
|
||||
fake.resetCalledWith[2].(string) != "weekly" ||
|
||||
!fake.resetCalledWith[3].(time.Time).Equal(now) {
|
||||
t.Errorf("forwarded params mismatch: %+v", fake.resetCalledWith)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenericAdapter_ResetExpiredWindow_PropagatesError(t *testing.T) {
|
||||
wantErr := errors.New("not found")
|
||||
fake := &fakeRepoForAdapter{resetErr: wantErr}
|
||||
adapter := NewUserPlatformQuotaServiceAdapter(fake)
|
||||
|
||||
err := adapter.ResetExpiredWindow(context.Background(), 1, "a", "daily", time.Now())
|
||||
if !errors.Is(err, wantErr) {
|
||||
t.Errorf("expected %v, got %v", wantErr, err)
|
||||
}
|
||||
}
|
||||
416
backend/internal/repository/user_platform_quota_repo.go
Normal file
416
backend/internal/repository/user_platform_quota_repo.go
Normal file
@ -0,0 +1,416 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userplatformquota"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
)
|
||||
|
||||
// UserPlatformQuotaRecord 是 repository 层的传输结构体,
|
||||
// 与 ent.UserPlatformQuota 实体解耦,供业务层使用。
|
||||
type UserPlatformQuotaRecord struct {
|
||||
UserID int64
|
||||
Platform string
|
||||
DailyLimitUSD *float64
|
||||
WeeklyLimitUSD *float64
|
||||
MonthlyLimitUSD *float64
|
||||
DailyUsageUSD float64
|
||||
WeeklyUsageUSD float64
|
||||
MonthlyUsageUSD float64
|
||||
DailyWindowStart *time.Time
|
||||
WeeklyWindowStart *time.Time
|
||||
MonthlyWindowStart *time.Time
|
||||
}
|
||||
|
||||
// ErrUserPlatformQuotaNotFound 用于 ResetExpiredWindow 等需要"必须命中已有记录"的方法。
|
||||
var ErrUserPlatformQuotaNotFound = fmt.Errorf("user platform quota record not found")
|
||||
|
||||
// UserPlatformQuotaRepository 定义用户平台配额的数据访问接口。
|
||||
type UserPlatformQuotaRepository interface {
|
||||
// BulkInsertInitial 幂等批量插入初始配额记录(ON CONFLICT DO NOTHING)。
|
||||
BulkInsertInitial(ctx context.Context, records []UserPlatformQuotaRecord) error
|
||||
// GetByUserPlatform 查询单条配额记录,未找到时返回 (nil, nil)。
|
||||
GetByUserPlatform(ctx context.Context, userID int64, platform string) (*UserPlatformQuotaRecord, error)
|
||||
// ListByUser 查询用户的所有平台配额记录(排除软删除)。
|
||||
ListByUser(ctx context.Context, userID int64) ([]UserPlatformQuotaRecord, error)
|
||||
// IncrementUsageWithReset 原子地累加用量,若窗口已过期则先重置再累加。
|
||||
IncrementUsageWithReset(ctx context.Context, userID int64, platform string, cost float64, now time.Time) error
|
||||
// ResetExpiredWindow 重置指定窗口(daily/weekly/monthly)的用量与起始时间。
|
||||
ResetExpiredWindow(ctx context.Context, userID int64, platform string, window string, newStart time.Time) error
|
||||
// UpsertForUser 全量替换该用户所有平台限额配置(详见 service.UserPlatformQuotaRepository.UpsertForUser)。
|
||||
UpsertForUser(ctx context.Context, userID int64, records []UserPlatformQuotaRecord) error
|
||||
}
|
||||
|
||||
type userPlatformQuotaRepository struct {
|
||||
client *dbent.Client
|
||||
}
|
||||
|
||||
// NewUserPlatformQuotaRepository 创建 UserPlatformQuotaRepository 实现。
|
||||
func NewUserPlatformQuotaRepository(client *dbent.Client) UserPlatformQuotaRepository {
|
||||
return &userPlatformQuotaRepository{client: client}
|
||||
}
|
||||
|
||||
// BulkInsertInitial 用原生 SQL ON CONFLICT 实现幂等批量插入(带条件 limit 覆盖)。
|
||||
// 仅插入 limit_usd 与元数据,usage_usd 用 DB 默认 0,window_start 留 NULL。
|
||||
// FK 约束要求 user_id 在 users 表中存在,调用方负责保证。
|
||||
//
|
||||
// 冲突策略:CASE WHEN existing.*_limit_usd IS NULL THEN EXCLUDED.*_limit_usd ELSE existing ...
|
||||
// - 若 IncrementUsageWithReset 因时序问题已先建行(limit 全 NULL),
|
||||
// 此处会把注册时的默认 limit 写入,避免该用户在该平台永久无限额。
|
||||
// - 若管理员已通过 UpsertForUser 设置了非 NULL 个性化 limit,**保留不动**
|
||||
// —— 旧实现无条件 EXCLUDED 覆盖会丢失个性化配置。
|
||||
// - 不会改 usage_usd / window_start,保留累计的用量。
|
||||
// - 仅命中 deleted_at IS NULL 的活跃记录(partial unique index 作用域)。
|
||||
func (r *userPlatformQuotaRepository) BulkInsertInitial(ctx context.Context, records []UserPlatformQuotaRecord) error {
|
||||
if len(records) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
client := clientFromContext(ctx, r.client)
|
||||
|
||||
var sb strings.Builder
|
||||
_, _ = sb.WriteString("INSERT INTO user_platform_quotas (user_id, platform, daily_limit_usd, weekly_limit_usd, monthly_limit_usd, daily_usage_usd, weekly_usage_usd, monthly_usage_usd, created_at, updated_at) VALUES ")
|
||||
args := make([]any, 0, len(records)*6)
|
||||
// 统一时间戳:避免循环内多次 time.Now() 让同一批记录的 created_at/updated_at
|
||||
// 出现亚毫秒级偏差(与 UpsertForUser 的 now := time.Now() 风格一致)。
|
||||
now := time.Now()
|
||||
for i, rec := range records {
|
||||
base := i * 6
|
||||
if i > 0 {
|
||||
_, _ = sb.WriteString(",")
|
||||
}
|
||||
fmt.Fprintf(&sb, "($%d,$%d,$%d,$%d,$%d,0,0,0,$%d,$%d)",
|
||||
base+1, base+2, base+3, base+4, base+5, base+6, base+6)
|
||||
args = append(args,
|
||||
rec.UserID, rec.Platform,
|
||||
rec.DailyLimitUSD, rec.WeeklyLimitUSD, rec.MonthlyLimitUSD,
|
||||
now,
|
||||
)
|
||||
}
|
||||
// 精确命中 partial unique index(deleted_at IS NULL),避免对软删记录的歧义冲突。
|
||||
// 条件覆盖:仅在现有 limit 为 NULL 时才写入 EXCLUDED,否则保留现有非 NULL 值。
|
||||
// - 修复 IncrementUsageWithReset 已用 NULL limit 建行的场景(NULL → 注册默认)
|
||||
// - 保护管理员通过 UpsertForUser 设置的个性化 limit 不被静默覆盖
|
||||
_, _ = sb.WriteString(` ON CONFLICT (user_id, platform) WHERE deleted_at IS NULL
|
||||
DO UPDATE SET
|
||||
daily_limit_usd = COALESCE(user_platform_quotas.daily_limit_usd, EXCLUDED.daily_limit_usd),
|
||||
weekly_limit_usd = COALESCE(user_platform_quotas.weekly_limit_usd, EXCLUDED.weekly_limit_usd),
|
||||
monthly_limit_usd = COALESCE(user_platform_quotas.monthly_limit_usd, EXCLUDED.monthly_limit_usd),
|
||||
updated_at = EXCLUDED.updated_at`)
|
||||
|
||||
_, err := client.ExecContext(ctx, sb.String(), args...)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetByUserPlatform 通过 ent 查询单条配额(排除软删除)。未找到返回 (nil, nil)。
|
||||
func (r *userPlatformQuotaRepository) GetByUserPlatform(ctx context.Context, userID int64, platform string) (*UserPlatformQuotaRecord, error) {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
entity, err := client.UserPlatformQuota.Query().
|
||||
Where(
|
||||
userplatformquota.UserIDEQ(userID),
|
||||
userplatformquota.PlatformEQ(platform),
|
||||
userplatformquota.DeletedAtIsNil(),
|
||||
).
|
||||
Only(ctx)
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return entQuotaToRecord(entity), nil
|
||||
}
|
||||
|
||||
// ListByUser 查询用户的所有平台配额记录(排除软删除)。
|
||||
func (r *userPlatformQuotaRepository) ListByUser(ctx context.Context, userID int64) ([]UserPlatformQuotaRecord, error) {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
rows, err := client.UserPlatformQuota.Query().
|
||||
Where(
|
||||
userplatformquota.UserIDEQ(userID),
|
||||
userplatformquota.DeletedAtIsNil(),
|
||||
).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]UserPlatformQuotaRecord, 0, len(rows))
|
||||
for _, e := range rows {
|
||||
out = append(out, *entQuotaToRecord(e))
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// IncrementUsageWithReset 原子累加 cost 到 (user, platform) 三个窗口的 *_usage_usd。
|
||||
// 行为:
|
||||
// - 若记录存在:在事务内 SELECT FOR UPDATE,按 (prev_window_start vs current_window_start)
|
||||
// 判断是否需要重置(不同 = 重置为 cost;相同 = 累加 cost)
|
||||
// - 若记录不存在(fail-open create 分支):插入新记录,**limit 字段保留 nil(无限制)**
|
||||
// —— 这是预期行为:billing 链路不能因 quota 表缺失而阻断请求,未注册路径
|
||||
// 的用户 quota 默认放行,由调度层指标观测 + 后台对账补建 limit
|
||||
//
|
||||
// 上层正常路径(注册时 BulkInsertInitial)保证 limit 在记录创建时就被写入。
|
||||
func (r *userPlatformQuotaRepository) IncrementUsageWithReset(ctx context.Context, userID int64, platform string, cost float64, now time.Time) error {
|
||||
return r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
|
||||
existing, err := txClient.UserPlatformQuota.Query().
|
||||
Where(
|
||||
userplatformquota.UserIDEQ(userID),
|
||||
userplatformquota.PlatformEQ(platform),
|
||||
userplatformquota.DeletedAtIsNil(),
|
||||
).
|
||||
ForUpdate().
|
||||
Only(txCtx)
|
||||
if dbent.IsNotFound(err) {
|
||||
// fail-open 建行:limit_* 保留 NULL(无限额)。
|
||||
// 用 ON CONFLICT DO UPDATE 累加,而非裸 INSERT:并发下另一请求可能在本事务
|
||||
// SELECT FOR UPDATE 之后、INSERT 之前刚建行,裸 INSERT 会撞 partial unique index
|
||||
// 致事务回滚、本次 cost 丢失;DO UPDATE 把 cost 累加到既有 usage 上。
|
||||
// 写法与本文件 insertLimitsRow / BulkInsertInitial 的 ON CONFLICT 一致。
|
||||
const insertSQL = `INSERT INTO user_platform_quotas
|
||||
(user_id, platform, daily_usage_usd, weekly_usage_usd, monthly_usage_usd,
|
||||
daily_window_start, weekly_window_start, monthly_window_start, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $3, $3, $4, $5, $6, $7, $7)
|
||||
ON CONFLICT (user_id, platform) WHERE deleted_at IS NULL DO UPDATE SET
|
||||
daily_usage_usd = user_platform_quotas.daily_usage_usd + EXCLUDED.daily_usage_usd,
|
||||
weekly_usage_usd = user_platform_quotas.weekly_usage_usd + EXCLUDED.weekly_usage_usd,
|
||||
monthly_usage_usd = user_platform_quotas.monthly_usage_usd + EXCLUDED.monthly_usage_usd,
|
||||
updated_at = EXCLUDED.updated_at`
|
||||
// $6 = now:30 天滚动月度窗口以当前时刻为起始
|
||||
_, e := txClient.ExecContext(txCtx, insertSQL,
|
||||
userID, platform, cost,
|
||||
timezone.StartOfDay(now), timezone.StartOfWeek(now), now, now)
|
||||
return e
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newDaily := maybeReset(existing.DailyUsageUsd, existing.DailyWindowStart, timezone.StartOfDay(now), cost)
|
||||
newWeekly := maybeReset(existing.WeeklyUsageUsd, existing.WeeklyWindowStart, timezone.StartOfWeek(now), cost)
|
||||
// 30 天滚动月度窗口:过期时重置为 cost 并以 now 为新起始,否则累加保留原起始
|
||||
newMonthly, newMonthlyStart := monthlyMaybeReset(existing.MonthlyUsageUsd, existing.MonthlyWindowStart, cost, now)
|
||||
|
||||
_, e := existing.Update().
|
||||
SetDailyUsageUsd(newDaily).
|
||||
SetWeeklyUsageUsd(newWeekly).
|
||||
SetMonthlyUsageUsd(newMonthly).
|
||||
SetDailyWindowStart(timezone.StartOfDay(now)).
|
||||
SetWeeklyWindowStart(timezone.StartOfWeek(now)).
|
||||
SetMonthlyWindowStart(newMonthlyStart). // 30 天滚动:仅过期时更新起始
|
||||
Save(txCtx)
|
||||
return e
|
||||
})
|
||||
}
|
||||
|
||||
// ResetExpiredWindow 无条件重置指定窗口(daily/weekly/monthly)的用量与起始时间。
|
||||
//
|
||||
// ⚠️ 命名警告(NOT a "check-then-reset" helper):
|
||||
//
|
||||
// 名字里的 "Expired" 是历史遗留,**实现并不校验窗口是否真的过期**。
|
||||
// 任何调用都会无条件把对应窗口的 *_usage_usd 清零并重写 *_window_start。
|
||||
// 目前唯一合法 caller 是 admin POST /reset 接口(管理员强制归零)。
|
||||
//
|
||||
// 如果你想要"仅在窗口过期才重置"的语义,请直接使用 IncrementUsageWithReset
|
||||
// 的内部判断(maybeReset / monthlyMaybeReset),或新增独立函数;
|
||||
// 不要复用这里的函数,否则会出现"明明窗口未过期,用量却被清零"的隐蔽 bug。
|
||||
//
|
||||
// 未命中活跃记录时返回 ErrUserPlatformQuotaNotFound。
|
||||
func (r *userPlatformQuotaRepository) ResetExpiredWindow(ctx context.Context, userID int64, platform string, window string, newStart time.Time) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
upd := client.UserPlatformQuota.Update().
|
||||
Where(
|
||||
userplatformquota.UserIDEQ(userID),
|
||||
userplatformquota.PlatformEQ(platform),
|
||||
userplatformquota.DeletedAtIsNil(),
|
||||
)
|
||||
switch window {
|
||||
case "daily":
|
||||
upd = upd.SetDailyUsageUsd(0).SetDailyWindowStart(newStart)
|
||||
case "weekly":
|
||||
upd = upd.SetWeeklyUsageUsd(0).SetWeeklyWindowStart(newStart)
|
||||
case "monthly":
|
||||
upd = upd.SetMonthlyUsageUsd(0).SetMonthlyWindowStart(newStart)
|
||||
default:
|
||||
return fmt.Errorf("unknown window %q", window)
|
||||
}
|
||||
n, err := upd.Save(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n == 0 {
|
||||
return ErrUserPlatformQuotaNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// withTx 在事务中执行 fn,若 ctx 中已有事务则复用。
|
||||
func (r *userPlatformQuotaRepository) withTx(ctx context.Context, fn func(txCtx context.Context, txClient *dbent.Client) error) error {
|
||||
if tx := dbent.TxFromContext(ctx); tx != nil {
|
||||
return fn(ctx, tx.Client())
|
||||
}
|
||||
|
||||
tx, err := r.client.Tx(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("begin user_platform_quota transaction: %w", err)
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
txCtx := dbent.NewTxContext(ctx, tx)
|
||||
if err := fn(txCtx, tx.Client()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return fmt.Errorf("commit user_platform_quota transaction: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// entQuotaToRecord 将 ent entity 映射为 repository record。
|
||||
// 注意 ent 生成字段名为 DailyLimitUsd(非 DailyLimitUSD)。
|
||||
func entQuotaToRecord(e *dbent.UserPlatformQuota) *UserPlatformQuotaRecord {
|
||||
return &UserPlatformQuotaRecord{
|
||||
UserID: e.UserID,
|
||||
Platform: e.Platform,
|
||||
DailyLimitUSD: e.DailyLimitUsd,
|
||||
WeeklyLimitUSD: e.WeeklyLimitUsd,
|
||||
MonthlyLimitUSD: e.MonthlyLimitUsd,
|
||||
DailyUsageUSD: e.DailyUsageUsd,
|
||||
WeeklyUsageUSD: e.WeeklyUsageUsd,
|
||||
MonthlyUsageUSD: e.MonthlyUsageUsd,
|
||||
DailyWindowStart: e.DailyWindowStart,
|
||||
WeeklyWindowStart: e.WeeklyWindowStart,
|
||||
MonthlyWindowStart: e.MonthlyWindowStart,
|
||||
}
|
||||
}
|
||||
|
||||
// maybeReset 判断是否需要重置窗口用量:
|
||||
// - 若 prevStart 为 nil 或与 currStart 不同,表示窗口已过期,返回 cost(重置)
|
||||
// - 否则返回 prevUsage + cost(累加)
|
||||
func maybeReset(prevUsage float64, prevStart *time.Time, currStart time.Time, cost float64) float64 {
|
||||
if prevStart == nil || !prevStart.Equal(currStart) {
|
||||
return cost
|
||||
}
|
||||
return prevUsage + cost
|
||||
}
|
||||
|
||||
// monthlyMaybeReset 判断 30 天滚动月度窗口是否需要重置。
|
||||
// 过期条件:prevStart 为 nil 或 now - prevStart >= 30×24h(与订阅模式 NeedsMonthlyReset 语义一致)。
|
||||
// 过期时重置为 cost,否则累加。返回 (newUsage, newWindowStart)。
|
||||
func monthlyMaybeReset(prevUsage float64, prevStart *time.Time, cost float64, now time.Time) (float64, time.Time) {
|
||||
if prevStart == nil || now.Sub(*prevStart) >= 30*24*time.Hour {
|
||||
return cost, now
|
||||
}
|
||||
return prevUsage + cost, *prevStart
|
||||
}
|
||||
|
||||
// UpsertForUser 全量替换该用户的所有平台限额(事务内):
|
||||
// 1. 软删除未在 records 中出现的所有 active 行
|
||||
// 2. 对每条 record 尝试 UPDATE(含 deleted_at = NULL 兼容重激活);
|
||||
// UPDATE 行数为 0 时 INSERT 新行
|
||||
//
|
||||
// 仅改 *_limit_usd + deleted_at + updated_at,保留 *_usage_usd / *_window_start。
|
||||
func (r *userPlatformQuotaRepository) UpsertForUser(ctx context.Context, userID int64, records []UserPlatformQuotaRecord) error {
|
||||
return r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
|
||||
platforms := make([]string, 0, len(records))
|
||||
for _, rec := range records {
|
||||
platforms = append(platforms, rec.Platform)
|
||||
}
|
||||
now := time.Now()
|
||||
if err := softDeleteMissingPlatforms(txCtx, txClient, userID, platforms, now); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, rec := range records {
|
||||
affected, err := updateLimitsRow(txCtx, txClient, userID, rec, now)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected == 0 {
|
||||
if err := insertLimitsRow(txCtx, txClient, userID, rec, now); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// softDeleteMissingPlatforms 软删除该用户所有不在 keepPlatforms 中的 active 行。
|
||||
// keepPlatforms 为空时 → 软删用户所有 active 行。
|
||||
// now 由调用方传入,与 updateLimitsRow / insertLimitsRow 共享同一个 Go time.Now(),
|
||||
// 保证事务内所有时间戳一致(避免 Postgres NOW() 与 Go time.Now() 的微小偏差)。
|
||||
func softDeleteMissingPlatforms(ctx context.Context, client *dbent.Client, userID int64, keepPlatforms []string, now time.Time) error {
|
||||
var (
|
||||
query string
|
||||
args []any
|
||||
)
|
||||
if len(keepPlatforms) == 0 {
|
||||
query = `UPDATE user_platform_quotas SET deleted_at = $2, updated_at = $2
|
||||
WHERE user_id = $1 AND deleted_at IS NULL`
|
||||
args = []any{userID, now}
|
||||
} else {
|
||||
placeholders := make([]string, len(keepPlatforms))
|
||||
args = make([]any, 0, len(keepPlatforms)+2)
|
||||
args = append(args, userID, now)
|
||||
for i, p := range keepPlatforms {
|
||||
placeholders[i] = fmt.Sprintf("$%d", i+3)
|
||||
args = append(args, p)
|
||||
}
|
||||
query = fmt.Sprintf(`UPDATE user_platform_quotas SET deleted_at = $2, updated_at = $2
|
||||
WHERE user_id = $1 AND deleted_at IS NULL AND platform NOT IN (%s)`,
|
||||
strings.Join(placeholders, ","))
|
||||
}
|
||||
_, err := client.ExecContext(ctx, query, args...)
|
||||
return err
|
||||
}
|
||||
|
||||
// updateLimitsRow 尝试 UPDATE active 行(deleted_at IS NULL),返回受影响行数。
|
||||
// 仅更新 active 行:若存在多条历史软删记录,加 deleted_at IS NULL 守卫可避免
|
||||
// 批量重激活导致的 partial unique index(userplatformquota_user_id_platform_uq)冲突。
|
||||
// affected=0 时由调用方 UpsertForUser 走 insertLimitsRow 路径创建新行。
|
||||
func updateLimitsRow(ctx context.Context, client *dbent.Client, userID int64, rec UserPlatformQuotaRecord, now time.Time) (int64, error) {
|
||||
const query = `UPDATE user_platform_quotas
|
||||
SET daily_limit_usd = $1, weekly_limit_usd = $2, monthly_limit_usd = $3,
|
||||
deleted_at = NULL, updated_at = $4
|
||||
WHERE user_id = $5 AND platform = $6 AND deleted_at IS NULL`
|
||||
res, err := client.ExecContext(ctx, query,
|
||||
rec.DailyLimitUSD, rec.WeeklyLimitUSD, rec.MonthlyLimitUSD, now,
|
||||
userID, rec.Platform)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
// insertLimitsRow 插入新限额行(usage 默认 0,window_start 默认 NULL)。
|
||||
// 带 ON CONFLICT ... DO NOTHING 守卫:防止两个并发请求同时为同一 user/platform 新建行时
|
||||
// 触发 unique constraint 违反(userplatformquota_user_id_platform_uq 部分唯一索引)。
|
||||
// affected=0 时说明另一个并发请求刚完成 INSERT,fallback 到 updateLimitsRow 覆写 limits 值。
|
||||
func insertLimitsRow(ctx context.Context, client *dbent.Client, userID int64, rec UserPlatformQuotaRecord, now time.Time) error {
|
||||
const query = `INSERT INTO user_platform_quotas
|
||||
(user_id, platform, daily_limit_usd, weekly_limit_usd, monthly_limit_usd,
|
||||
daily_usage_usd, weekly_usage_usd, monthly_usage_usd, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, 0, 0, 0, $6, $6)
|
||||
ON CONFLICT (user_id, platform) WHERE deleted_at IS NULL DO NOTHING`
|
||||
res, err := client.ExecContext(ctx, query,
|
||||
userID, rec.Platform,
|
||||
rec.DailyLimitUSD, rec.WeeklyLimitUSD, rec.MonthlyLimitUSD,
|
||||
now)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected == 0 {
|
||||
// 并发情形:另一请求已插入该行,fallback 到 UPDATE 覆写 limits 值(last-writer-wins)。
|
||||
_, err = updateLimitsRow(ctx, client, userID, rec, now)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -0,0 +1,269 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// mustCreateUserForQuota 在指定 client 上创建测试用户(满足 FK 约束)。
|
||||
func mustCreateUserForQuota(t *testing.T, client *dbent.Client) int64 {
|
||||
t.Helper()
|
||||
u := mustCreateUser(t, client, &service.User{
|
||||
Email: fmt.Sprintf("quota-test-%d@example.com", time.Now().UnixNano()),
|
||||
})
|
||||
return u.ID
|
||||
}
|
||||
|
||||
func TestUserPlatformQuotaRepository_BulkInsertInitial_Idempotent(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
tx := testEntTx(t)
|
||||
txCtx := dbent.NewTxContext(ctx, tx)
|
||||
client := tx.Client()
|
||||
|
||||
userID := mustCreateUserForQuota(t, client)
|
||||
|
||||
repo := NewUserPlatformQuotaRepository(client)
|
||||
|
||||
daily := 5.0
|
||||
records := []UserPlatformQuotaRecord{
|
||||
{UserID: userID, Platform: "anthropic", DailyLimitUSD: &daily},
|
||||
{UserID: userID, Platform: "openai"},
|
||||
}
|
||||
|
||||
// 第一次插入
|
||||
require.NoError(t, repo.BulkInsertInitial(txCtx, records), "first insert")
|
||||
// 第二次插入应为 no-op(ON CONFLICT DO NOTHING)
|
||||
require.NoError(t, repo.BulkInsertInitial(txCtx, records), "second insert (idempotent)")
|
||||
|
||||
list, err := repo.ListByUser(txCtx, userID)
|
||||
require.NoError(t, err, "list")
|
||||
require.Len(t, list, 2, "expected 2 records after idempotent insert")
|
||||
|
||||
// 校验 daily_limit_usd 保留
|
||||
var anthropicRec *UserPlatformQuotaRecord
|
||||
for i := range list {
|
||||
if list[i].Platform == "anthropic" {
|
||||
anthropicRec = &list[i]
|
||||
}
|
||||
}
|
||||
require.NotNil(t, anthropicRec, "anthropic record should exist")
|
||||
require.NotNil(t, anthropicRec.DailyLimitUSD, "daily limit should be set")
|
||||
require.InDelta(t, 5.0, *anthropicRec.DailyLimitUSD, 1e-9, "daily limit should be 5.0")
|
||||
}
|
||||
|
||||
func TestUserPlatformQuotaRepository_BulkInsertInitial_Empty(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
tx := testEntTx(t)
|
||||
txCtx := dbent.NewTxContext(ctx, tx)
|
||||
client := tx.Client()
|
||||
|
||||
repo := NewUserPlatformQuotaRepository(client)
|
||||
// 空切片不应报错
|
||||
require.NoError(t, repo.BulkInsertInitial(txCtx, nil))
|
||||
require.NoError(t, repo.BulkInsertInitial(txCtx, []UserPlatformQuotaRecord{}))
|
||||
}
|
||||
|
||||
func TestUserPlatformQuotaRepository_GetByUserPlatform(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
tx := testEntTx(t)
|
||||
txCtx := dbent.NewTxContext(ctx, tx)
|
||||
client := tx.Client()
|
||||
|
||||
userID := mustCreateUserForQuota(t, client)
|
||||
|
||||
repo := NewUserPlatformQuotaRepository(client)
|
||||
|
||||
// 未插入时应返回 nil
|
||||
rec, err := repo.GetByUserPlatform(txCtx, userID, "anthropic")
|
||||
require.NoError(t, err, "get before insert should not error")
|
||||
require.Nil(t, rec, "get before insert should return nil")
|
||||
|
||||
// 插入后查询
|
||||
daily := 10.0
|
||||
require.NoError(t, repo.BulkInsertInitial(txCtx, []UserPlatformQuotaRecord{
|
||||
{UserID: userID, Platform: "anthropic", DailyLimitUSD: &daily},
|
||||
}))
|
||||
|
||||
rec, err = repo.GetByUserPlatform(txCtx, userID, "anthropic")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rec)
|
||||
require.Equal(t, userID, rec.UserID)
|
||||
require.Equal(t, "anthropic", rec.Platform)
|
||||
require.NotNil(t, rec.DailyLimitUSD)
|
||||
require.InDelta(t, 10.0, *rec.DailyLimitUSD, 1e-9)
|
||||
}
|
||||
|
||||
func TestUserPlatformQuotaRepository_IncrementUsageWithReset_SameWindow(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// IncrementUsageWithReset 内部自己开事务,使用独立 ent client 确保跨事务可见
|
||||
client := testEntClient(t)
|
||||
|
||||
userID := mustCreateUserForQuota(t, client)
|
||||
|
||||
repo := NewUserPlatformQuotaRepository(client)
|
||||
now := time.Date(2026, 5, 22, 10, 0, 0, 0, time.UTC) // 周五
|
||||
|
||||
// 首次调用:应新建记录
|
||||
require.NoError(t, repo.IncrementUsageWithReset(ctx, userID, "anthropic", 1.5, now))
|
||||
|
||||
rec, err := repo.GetByUserPlatform(ctx, userID, "anthropic")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rec)
|
||||
require.InDelta(t, 1.5, rec.DailyUsageUSD, 1e-9, "initial daily usage")
|
||||
require.InDelta(t, 1.5, rec.WeeklyUsageUSD, 1e-9, "initial weekly usage")
|
||||
require.InDelta(t, 1.5, rec.MonthlyUsageUSD, 1e-9, "initial monthly usage")
|
||||
|
||||
// 同一天再次调用:应累加
|
||||
require.NoError(t, repo.IncrementUsageWithReset(ctx, userID, "anthropic", 0.5, now))
|
||||
|
||||
rec, err = repo.GetByUserPlatform(ctx, userID, "anthropic")
|
||||
require.NoError(t, err)
|
||||
require.InDelta(t, 2.0, rec.DailyUsageUSD, 1e-9, "accumulated daily usage")
|
||||
require.InDelta(t, 2.0, rec.WeeklyUsageUSD, 1e-9, "accumulated weekly usage")
|
||||
require.InDelta(t, 2.0, rec.MonthlyUsageUSD, 1e-9, "accumulated monthly usage")
|
||||
}
|
||||
|
||||
func TestUserPlatformQuotaRepository_IncrementUsageWithReset_DailyReset(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
userID := mustCreateUserForQuota(t, client)
|
||||
|
||||
repo := NewUserPlatformQuotaRepository(client)
|
||||
|
||||
day1 := time.Date(2026, 5, 22, 10, 0, 0, 0, time.UTC) // 周五(同一周、同一月)
|
||||
day2 := time.Date(2026, 5, 23, 10, 0, 0, 0, time.UTC) // 周六(同一周、同一月)
|
||||
|
||||
require.NoError(t, repo.IncrementUsageWithReset(ctx, userID, "anthropic", 3.0, day1))
|
||||
require.NoError(t, repo.IncrementUsageWithReset(ctx, userID, "anthropic", 1.0, day2))
|
||||
|
||||
rec, err := repo.GetByUserPlatform(ctx, userID, "anthropic")
|
||||
require.NoError(t, err)
|
||||
require.InDelta(t, 1.0, rec.DailyUsageUSD, 1e-9, "daily should reset to 1.0")
|
||||
require.InDelta(t, 4.0, rec.WeeklyUsageUSD, 1e-9, "weekly should accumulate to 4.0 (same week)")
|
||||
require.InDelta(t, 4.0, rec.MonthlyUsageUSD, 1e-9, "monthly should accumulate to 4.0 (same month)")
|
||||
}
|
||||
|
||||
func TestUserPlatformQuotaRepository_IncrementUsageWithReset_WeeklyReset(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
userID := mustCreateUserForQuota(t, client)
|
||||
|
||||
repo := NewUserPlatformQuotaRepository(client)
|
||||
|
||||
// 5月22日(周五)和 5月25日(下周一),不同周
|
||||
fri := time.Date(2026, 5, 22, 10, 0, 0, 0, time.UTC)
|
||||
nextMon := time.Date(2026, 5, 25, 10, 0, 0, 0, time.UTC) // 下一周周一
|
||||
|
||||
require.NoError(t, repo.IncrementUsageWithReset(ctx, userID, "openai", 5.0, fri))
|
||||
require.NoError(t, repo.IncrementUsageWithReset(ctx, userID, "openai", 2.0, nextMon))
|
||||
|
||||
rec, err := repo.GetByUserPlatform(ctx, userID, "openai")
|
||||
require.NoError(t, err)
|
||||
require.InDelta(t, 2.0, rec.DailyUsageUSD, 1e-9, "daily resets to new cost")
|
||||
require.InDelta(t, 2.0, rec.WeeklyUsageUSD, 1e-9, "weekly resets (new week)")
|
||||
require.InDelta(t, 7.0, rec.MonthlyUsageUSD, 1e-9, "monthly accumulates (same month)")
|
||||
}
|
||||
|
||||
func TestUserPlatformQuotaRepository_ResetExpiredWindow(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
tx := testEntTx(t)
|
||||
txCtx := dbent.NewTxContext(ctx, tx)
|
||||
client := tx.Client()
|
||||
|
||||
userID := mustCreateUserForQuota(t, client)
|
||||
|
||||
repo := NewUserPlatformQuotaRepository(client)
|
||||
|
||||
// 先通过 ent 直接建一条记录
|
||||
_, err := client.UserPlatformQuota.Create().
|
||||
SetUserID(userID).
|
||||
SetPlatform("gemini").
|
||||
SetDailyUsageUsd(10.0).
|
||||
SetWeeklyUsageUsd(20.0).
|
||||
SetMonthlyUsageUsd(50.0).
|
||||
SetDailyWindowStart(time.Date(2026, 5, 21, 0, 0, 0, 0, time.UTC)).
|
||||
SetWeeklyWindowStart(time.Date(2026, 5, 18, 0, 0, 0, 0, time.UTC)).
|
||||
SetMonthlyWindowStart(time.Date(2026, 5, 1, 0, 0, 0, 0, time.UTC)).
|
||||
Save(txCtx)
|
||||
require.NoError(t, err)
|
||||
|
||||
newStart := time.Date(2026, 5, 22, 0, 0, 0, 0, time.UTC)
|
||||
require.NoError(t, repo.ResetExpiredWindow(txCtx, userID, "gemini", "daily", newStart))
|
||||
|
||||
rec, err := repo.GetByUserPlatform(txCtx, userID, "gemini")
|
||||
require.NoError(t, err)
|
||||
require.InDelta(t, 0.0, rec.DailyUsageUSD, 1e-9, "daily usage reset to 0")
|
||||
require.NotNil(t, rec.DailyWindowStart)
|
||||
require.True(t, rec.DailyWindowStart.Equal(newStart), "daily window start updated")
|
||||
// 其他窗口不变
|
||||
require.InDelta(t, 20.0, rec.WeeklyUsageUSD, 1e-9, "weekly usage unchanged")
|
||||
require.InDelta(t, 50.0, rec.MonthlyUsageUSD, 1e-9, "monthly usage unchanged")
|
||||
}
|
||||
|
||||
func TestUserPlatformQuotaRepository_ResetExpiredWindow_UnknownWindow(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
|
||||
repo := NewUserPlatformQuotaRepository(client)
|
||||
err := repo.ResetExpiredWindow(ctx, 999, "anthropic", "yearly", time.Now())
|
||||
require.Error(t, err, "unknown window should return error")
|
||||
}
|
||||
|
||||
func TestUserPlatformQuotaRepository_BulkInsertInitial_MultiRow(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
tx := testEntTx(t)
|
||||
txCtx := dbent.NewTxContext(ctx, tx)
|
||||
client := tx.Client()
|
||||
|
||||
userID := mustCreateUserForQuota(t, client)
|
||||
repo := NewUserPlatformQuotaRepository(client)
|
||||
|
||||
d1, d2, d3 := 5.0, 10.0, 15.0
|
||||
records := []UserPlatformQuotaRecord{
|
||||
{UserID: userID, Platform: "anthropic", DailyLimitUSD: &d1},
|
||||
{UserID: userID, Platform: "openai", DailyLimitUSD: &d2},
|
||||
{UserID: userID, Platform: "gemini", DailyLimitUSD: &d3},
|
||||
}
|
||||
require.NoError(t, repo.BulkInsertInitial(txCtx, records), "multi-row insert failed")
|
||||
|
||||
list, err := repo.ListByUser(txCtx, userID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, list, 3, "expected 3 rows, got %d", len(list))
|
||||
|
||||
// 验证 limit 值与传入一致(防占位符串位)
|
||||
byPlatform := map[string]*UserPlatformQuotaRecord{}
|
||||
for i := range list {
|
||||
byPlatform[list[i].Platform] = &list[i]
|
||||
}
|
||||
require.NotNil(t, byPlatform["anthropic"], "anthropic record should exist")
|
||||
require.NotNil(t, byPlatform["anthropic"].DailyLimitUSD, "anthropic daily limit should be set")
|
||||
require.InDelta(t, 5.0, *byPlatform["anthropic"].DailyLimitUSD, 1e-9, "anthropic daily_limit = want 5.0")
|
||||
|
||||
require.NotNil(t, byPlatform["openai"], "openai record should exist")
|
||||
require.NotNil(t, byPlatform["openai"].DailyLimitUSD, "openai daily limit should be set")
|
||||
require.InDelta(t, 10.0, *byPlatform["openai"].DailyLimitUSD, 1e-9, "openai daily_limit = want 10.0")
|
||||
|
||||
require.NotNil(t, byPlatform["gemini"], "gemini record should exist")
|
||||
require.NotNil(t, byPlatform["gemini"].DailyLimitUSD, "gemini daily limit should be set")
|
||||
require.InDelta(t, 15.0, *byPlatform["gemini"].DailyLimitUSD, 1e-9, "gemini daily_limit = want 15.0")
|
||||
}
|
||||
|
||||
func TestUserPlatformQuotaRepository_ResetExpiredWindow_NotFoundReturnsSentinel(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := NewUserPlatformQuotaRepository(client)
|
||||
|
||||
err := repo.ResetExpiredWindow(ctx, 99999, "anthropic", "daily", time.Now())
|
||||
require.True(t, errors.Is(err, ErrUserPlatformQuotaNotFound),
|
||||
"expected ErrUserPlatformQuotaNotFound, got %v", err)
|
||||
}
|
||||
103
backend/internal/repository/user_platform_quota_repo_test.go
Normal file
103
backend/internal/repository/user_platform_quota_repo_test.go
Normal file
@ -0,0 +1,103 @@
|
||||
//go:build unit
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestMaybeReset(t *testing.T) {
|
||||
start := time.Date(2026, 5, 22, 0, 0, 0, 0, time.UTC)
|
||||
other := start.AddDate(0, 0, -1)
|
||||
cases := []struct {
|
||||
name string
|
||||
prevUsage float64
|
||||
prevStart *time.Time
|
||||
currStart time.Time
|
||||
cost float64
|
||||
want float64
|
||||
}{
|
||||
{"nil prev start resets", 10, nil, start, 1.5, 1.5},
|
||||
{"different start resets", 10, &other, start, 1.5, 1.5},
|
||||
{"same start accumulates", 10, &start, start, 1.5, 11.5},
|
||||
}
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
if got := maybeReset(c.prevUsage, c.prevStart, c.currStart, c.cost); got != c.want {
|
||||
t.Errorf("maybeReset = %v, want %v", got, c.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestMonthlyMaybeReset_NilStart 验证 prevStart=nil 时重置。
|
||||
func TestMonthlyMaybeReset_NilStart(t *testing.T) {
|
||||
now := time.Date(2026, 5, 22, 12, 0, 0, 0, time.UTC)
|
||||
usage, start := monthlyMaybeReset(10.0, nil, 1.5, now)
|
||||
if usage != 1.5 {
|
||||
t.Errorf("usage = %v, want 1.5", usage)
|
||||
}
|
||||
if !start.Equal(now) {
|
||||
t.Errorf("start = %v, want %v", start, now)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMonthlyMaybeReset_Expired 验证窗口满 30 天时重置(30 天恰好到期)。
|
||||
func TestMonthlyMaybeReset_Expired(t *testing.T) {
|
||||
windowStart := time.Date(2026, 4, 22, 12, 0, 0, 0, time.UTC)
|
||||
// now = windowStart + 30d(刚好到期)
|
||||
now := windowStart.Add(30 * 24 * time.Hour)
|
||||
usage, start := monthlyMaybeReset(8.0, &windowStart, 2.0, now)
|
||||
if usage != 2.0 {
|
||||
t.Errorf("usage = %v, want 2.0 (reset)", usage)
|
||||
}
|
||||
if !start.Equal(now) {
|
||||
t.Errorf("start = %v, want %v (new window)", start, now)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMonthlyMaybeReset_CrossMonthBoundary 验证跨自然月时也使用 30 天滚动(不提前重置)。
|
||||
// 旧行为:5 月 1 日跨月立即重置;新行为:窗口起始 4 月 20 日,5 月 1 日仅过了 11 天,应累加。
|
||||
func TestMonthlyMaybeReset_CrossMonthBoundary(t *testing.T) {
|
||||
windowStart := time.Date(2026, 4, 20, 0, 0, 0, 0, time.UTC)
|
||||
// 5 月 1 日:距起始 11 天,不足 30 天,应累加
|
||||
now := time.Date(2026, 5, 1, 0, 0, 0, 0, time.UTC)
|
||||
usage, start := monthlyMaybeReset(5.0, &windowStart, 1.0, now)
|
||||
if usage != 6.0 {
|
||||
t.Errorf("usage = %v, want 6.0 (accumulate, not reset at month boundary)", usage)
|
||||
}
|
||||
if !start.Equal(windowStart) {
|
||||
t.Errorf("start = %v, want %v (preserved)", start, windowStart)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMonthlyMaybeReset_Active 验证窗口内正常累加。
|
||||
func TestMonthlyMaybeReset_Active(t *testing.T) {
|
||||
windowStart := time.Date(2026, 5, 1, 0, 0, 0, 0, time.UTC)
|
||||
// 15 天内,窗口有效
|
||||
now := windowStart.Add(15 * 24 * time.Hour)
|
||||
usage, start := monthlyMaybeReset(3.0, &windowStart, 0.5, now)
|
||||
if usage != 3.5 {
|
||||
t.Errorf("usage = %v, want 3.5", usage)
|
||||
}
|
||||
if !start.Equal(windowStart) {
|
||||
t.Errorf("start = %v, want %v", start, windowStart)
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateLimitsRowQuery_HasDeletedAtGuard 通过读取源文件验证 updateLimitsRow
|
||||
// 的 SQL WHERE 子句包含 deleted_at IS NULL 守卫(I-NEW-1)。
|
||||
// 此防回归测试可在无 DB 的 CI 环境中运行,防止意外删除该守卫。
|
||||
func TestUpdateLimitsRowQuery_HasDeletedAtGuard(t *testing.T) {
|
||||
src, err := os.ReadFile("user_platform_quota_repo.go")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read source file: %v", err)
|
||||
}
|
||||
const guard = "AND deleted_at IS NULL"
|
||||
if !strings.Contains(string(src), guard) {
|
||||
t.Errorf("updateLimitsRow SQL must contain %q to prevent bulk reactivation of soft-deleted rows (I-NEW-1)", guard)
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,206 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// userPlatformQuotaServiceAdapter 将 repository 层的 userPlatformQuotaRepository
|
||||
// 适配为 service.UserPlatformQuotaRepository 接口(返回 *service.UserPlatformQuotaRecord)。
|
||||
type userPlatformQuotaServiceAdapter struct {
|
||||
inner *userPlatformQuotaRepository
|
||||
}
|
||||
|
||||
// NewUserPlatformQuotaServiceAdapter 将 UserPlatformQuotaRepository 实现包装为
|
||||
// 满足 service.UserPlatformQuotaRepository 接口的适配器。
|
||||
func NewUserPlatformQuotaServiceAdapter(repo UserPlatformQuotaRepository) service.UserPlatformQuotaRepository {
|
||||
impl, ok := repo.(*userPlatformQuotaRepository)
|
||||
if !ok {
|
||||
// 非标准实现(如测试 fake),通过通用适配器包装
|
||||
return &genericUserPlatformQuotaAdapter{inner: repo}
|
||||
}
|
||||
return &userPlatformQuotaServiceAdapter{inner: impl}
|
||||
}
|
||||
|
||||
func (a *userPlatformQuotaServiceAdapter) GetByUserPlatform(ctx context.Context, userID int64, platform string) (*service.UserPlatformQuotaRecord, error) {
|
||||
rec, err := a.inner.GetByUserPlatform(ctx, userID, platform)
|
||||
if err != nil || rec == nil {
|
||||
return nil, err
|
||||
}
|
||||
return toServiceRecord(rec), nil
|
||||
}
|
||||
|
||||
// IncrementUsageWithReset 原子累加 cost 到 (user, platform) 三个窗口的用量。
|
||||
func (a *userPlatformQuotaServiceAdapter) IncrementUsageWithReset(ctx context.Context, userID int64, platform string, cost float64, now time.Time) error {
|
||||
return a.inner.IncrementUsageWithReset(ctx, userID, platform, cost, now)
|
||||
}
|
||||
|
||||
// ListByUser 查询用户的所有平台配额记录。
|
||||
func (a *userPlatformQuotaServiceAdapter) ListByUser(ctx context.Context, userID int64) ([]service.UserPlatformQuotaRecord, error) {
|
||||
rows, err := a.inner.ListByUser(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]service.UserPlatformQuotaRecord, len(rows))
|
||||
for i, r := range rows {
|
||||
out[i] = service.UserPlatformQuotaRecord{
|
||||
UserID: r.UserID,
|
||||
Platform: r.Platform,
|
||||
DailyLimitUSD: r.DailyLimitUSD,
|
||||
WeeklyLimitUSD: r.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: r.MonthlyLimitUSD,
|
||||
DailyUsageUSD: r.DailyUsageUSD,
|
||||
WeeklyUsageUSD: r.WeeklyUsageUSD,
|
||||
MonthlyUsageUSD: r.MonthlyUsageUSD,
|
||||
DailyWindowStart: r.DailyWindowStart,
|
||||
WeeklyWindowStart: r.WeeklyWindowStart,
|
||||
MonthlyWindowStart: r.MonthlyWindowStart,
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// BulkInsertInitial 将 service.UserPlatformQuotaRecord 切片转换后调用底层 repo。
|
||||
func (a *userPlatformQuotaServiceAdapter) BulkInsertInitial(ctx context.Context, records []service.UserPlatformQuotaRecord) error {
|
||||
repoRecords := make([]UserPlatformQuotaRecord, len(records))
|
||||
for i, r := range records {
|
||||
repoRecords[i] = UserPlatformQuotaRecord{
|
||||
UserID: r.UserID,
|
||||
Platform: r.Platform,
|
||||
DailyLimitUSD: r.DailyLimitUSD,
|
||||
WeeklyLimitUSD: r.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: r.MonthlyLimitUSD,
|
||||
}
|
||||
}
|
||||
return a.inner.BulkInsertInitial(ctx, repoRecords)
|
||||
}
|
||||
|
||||
// UpsertForUser 全量替换该用户所有平台限额。
|
||||
func (a *userPlatformQuotaServiceAdapter) UpsertForUser(ctx context.Context, userID int64, records []service.UserPlatformQuotaRecord) error {
|
||||
repoRecords := toRepoRecords(records)
|
||||
return a.inner.UpsertForUser(ctx, userID, repoRecords)
|
||||
}
|
||||
|
||||
// ResetExpiredWindow 转发至 repository.ResetExpiredWindow,并将 repository sentinel 包装为 service sentinel。
|
||||
func (a *userPlatformQuotaServiceAdapter) ResetExpiredWindow(ctx context.Context, userID int64, platform string, window string, newStart time.Time) error {
|
||||
err := a.inner.ResetExpiredWindow(ctx, userID, platform, window, newStart)
|
||||
if errors.Is(err, ErrUserPlatformQuotaNotFound) {
|
||||
return fmt.Errorf("%w: %w", service.ErrUserPlatformQuotaNotFound, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// genericUserPlatformQuotaAdapter 通过通用接口适配(用于测试 fake 或非标准实现)。
|
||||
type genericUserPlatformQuotaAdapter struct {
|
||||
inner UserPlatformQuotaRepository
|
||||
}
|
||||
|
||||
func (a *genericUserPlatformQuotaAdapter) GetByUserPlatform(ctx context.Context, userID int64, platform string) (*service.UserPlatformQuotaRecord, error) {
|
||||
rec, err := a.inner.GetByUserPlatform(ctx, userID, platform)
|
||||
if err != nil || rec == nil {
|
||||
return nil, err
|
||||
}
|
||||
return toServiceRecord(rec), nil
|
||||
}
|
||||
|
||||
// IncrementUsageWithReset 原子累加 cost(通用 adapter 实现)。
|
||||
func (a *genericUserPlatformQuotaAdapter) IncrementUsageWithReset(ctx context.Context, userID int64, platform string, cost float64, now time.Time) error {
|
||||
return a.inner.IncrementUsageWithReset(ctx, userID, platform, cost, now)
|
||||
}
|
||||
|
||||
// ListByUser 查询用户的所有平台配额记录(通用 adapter 实现)。
|
||||
func (a *genericUserPlatformQuotaAdapter) ListByUser(ctx context.Context, userID int64) ([]service.UserPlatformQuotaRecord, error) {
|
||||
rows, err := a.inner.ListByUser(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]service.UserPlatformQuotaRecord, len(rows))
|
||||
for i, r := range rows {
|
||||
out[i] = service.UserPlatformQuotaRecord{
|
||||
UserID: r.UserID,
|
||||
Platform: r.Platform,
|
||||
DailyLimitUSD: r.DailyLimitUSD,
|
||||
WeeklyLimitUSD: r.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: r.MonthlyLimitUSD,
|
||||
DailyUsageUSD: r.DailyUsageUSD,
|
||||
WeeklyUsageUSD: r.WeeklyUsageUSD,
|
||||
MonthlyUsageUSD: r.MonthlyUsageUSD,
|
||||
DailyWindowStart: r.DailyWindowStart,
|
||||
WeeklyWindowStart: r.WeeklyWindowStart,
|
||||
MonthlyWindowStart: r.MonthlyWindowStart,
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// BulkInsertInitial 将 service.UserPlatformQuotaRecord 切片转换后调用底层 generic repo。
|
||||
func (a *genericUserPlatformQuotaAdapter) BulkInsertInitial(ctx context.Context, records []service.UserPlatformQuotaRecord) error {
|
||||
repoRecords := make([]UserPlatformQuotaRecord, len(records))
|
||||
for i, r := range records {
|
||||
repoRecords[i] = UserPlatformQuotaRecord{
|
||||
UserID: r.UserID,
|
||||
Platform: r.Platform,
|
||||
DailyLimitUSD: r.DailyLimitUSD,
|
||||
WeeklyLimitUSD: r.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: r.MonthlyLimitUSD,
|
||||
}
|
||||
}
|
||||
return a.inner.BulkInsertInitial(ctx, repoRecords)
|
||||
}
|
||||
|
||||
// UpsertForUser 全量替换(通用 adapter 实现)。
|
||||
func (a *genericUserPlatformQuotaAdapter) UpsertForUser(ctx context.Context, userID int64, records []service.UserPlatformQuotaRecord) error {
|
||||
repoRecords := toRepoRecords(records)
|
||||
return a.inner.UpsertForUser(ctx, userID, repoRecords)
|
||||
}
|
||||
|
||||
// ResetExpiredWindow 转发至 repository.ResetExpiredWindow(通用 adapter),并包装 sentinel。
|
||||
func (a *genericUserPlatformQuotaAdapter) ResetExpiredWindow(ctx context.Context, userID int64, platform string, window string, newStart time.Time) error {
|
||||
err := a.inner.ResetExpiredWindow(ctx, userID, platform, window, newStart)
|
||||
if errors.Is(err, ErrUserPlatformQuotaNotFound) {
|
||||
return fmt.Errorf("%w: %w", service.ErrUserPlatformQuotaNotFound, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// toServiceRecord 将 repository.UserPlatformQuotaRecord 转换为 service.UserPlatformQuotaRecord。
|
||||
func toServiceRecord(rec *UserPlatformQuotaRecord) *service.UserPlatformQuotaRecord {
|
||||
return &service.UserPlatformQuotaRecord{
|
||||
UserID: rec.UserID,
|
||||
Platform: rec.Platform,
|
||||
DailyLimitUSD: rec.DailyLimitUSD,
|
||||
WeeklyLimitUSD: rec.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: rec.MonthlyLimitUSD,
|
||||
DailyUsageUSD: rec.DailyUsageUSD,
|
||||
WeeklyUsageUSD: rec.WeeklyUsageUSD,
|
||||
MonthlyUsageUSD: rec.MonthlyUsageUSD,
|
||||
DailyWindowStart: rec.DailyWindowStart,
|
||||
WeeklyWindowStart: rec.WeeklyWindowStart,
|
||||
MonthlyWindowStart: rec.MonthlyWindowStart,
|
||||
}
|
||||
}
|
||||
|
||||
// toRepoRecords 将 service.UserPlatformQuotaRecord 切片转换为 repository.UserPlatformQuotaRecord(含 limit 字段,含 usage/window_start)。
|
||||
func toRepoRecords(records []service.UserPlatformQuotaRecord) []UserPlatformQuotaRecord {
|
||||
out := make([]UserPlatformQuotaRecord, len(records))
|
||||
for i, r := range records {
|
||||
out[i] = UserPlatformQuotaRecord{
|
||||
UserID: r.UserID,
|
||||
Platform: r.Platform,
|
||||
DailyLimitUSD: r.DailyLimitUSD,
|
||||
WeeklyLimitUSD: r.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: r.MonthlyLimitUSD,
|
||||
DailyUsageUSD: r.DailyUsageUSD,
|
||||
WeeklyUsageUSD: r.WeeklyUsageUSD,
|
||||
MonthlyUsageUSD: r.MonthlyUsageUSD,
|
||||
DailyWindowStart: r.DailyWindowStart,
|
||||
WeeklyWindowStart: r.WeeklyWindowStart,
|
||||
MonthlyWindowStart: r.MonthlyWindowStart,
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
148
backend/internal/repository/user_platform_quota_upsert_test.go
Normal file
148
backend/internal/repository/user_platform_quota_upsert_test.go
Normal file
@ -0,0 +1,148 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/ent/userplatformquota"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestUpsertForUser_NewUserInsertsAllRecords(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
userID := mustCreateUserForQuota(t, client)
|
||||
repo := NewUserPlatformQuotaRepository(client)
|
||||
|
||||
daily := 10.0
|
||||
weekly := 50.0
|
||||
monthly := 200.0
|
||||
records := []UserPlatformQuotaRecord{
|
||||
{UserID: userID, Platform: "anthropic", DailyLimitUSD: &daily, WeeklyLimitUSD: &weekly, MonthlyLimitUSD: &monthly},
|
||||
{UserID: userID, Platform: "openai", DailyLimitUSD: &daily},
|
||||
}
|
||||
require.NoError(t, repo.UpsertForUser(ctx, userID, records))
|
||||
|
||||
got, err := repo.ListByUser(ctx, userID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, got, 2)
|
||||
}
|
||||
|
||||
func TestUpsertForUser_PartialUpdateSoftDeletesMissingPlatforms(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
userID := mustCreateUserForQuota(t, client)
|
||||
repo := NewUserPlatformQuotaRepository(client)
|
||||
|
||||
d1 := 10.0
|
||||
d2 := 20.0
|
||||
require.NoError(t, repo.UpsertForUser(ctx, userID, []UserPlatformQuotaRecord{
|
||||
{UserID: userID, Platform: "anthropic", DailyLimitUSD: &d1},
|
||||
{UserID: userID, Platform: "openai", DailyLimitUSD: &d1},
|
||||
}))
|
||||
require.NoError(t, repo.UpsertForUser(ctx, userID, []UserPlatformQuotaRecord{
|
||||
{UserID: userID, Platform: "anthropic", DailyLimitUSD: &d2},
|
||||
{UserID: userID, Platform: "gemini", DailyLimitUSD: &d1},
|
||||
}))
|
||||
|
||||
active, err := repo.ListByUser(ctx, userID)
|
||||
require.NoError(t, err)
|
||||
platforms := map[string]float64{}
|
||||
for _, r := range active {
|
||||
require.NotNil(t, r.DailyLimitUSD)
|
||||
platforms[r.Platform] = *r.DailyLimitUSD
|
||||
}
|
||||
require.Len(t, platforms, 2)
|
||||
require.InDelta(t, 20.0, platforms["anthropic"], 1e-9)
|
||||
require.InDelta(t, 10.0, platforms["gemini"], 1e-9)
|
||||
_, openaiActive := platforms["openai"]
|
||||
require.False(t, openaiActive, "openai should be soft-deleted")
|
||||
}
|
||||
|
||||
func TestUpsertForUser_PreservesUsageAndWindowStart(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
userID := mustCreateUserForQuota(t, client)
|
||||
repo := NewUserPlatformQuotaRepository(client)
|
||||
|
||||
d := 10.0
|
||||
require.NoError(t, repo.UpsertForUser(ctx, userID, []UserPlatformQuotaRecord{
|
||||
{UserID: userID, Platform: "anthropic", DailyLimitUSD: &d},
|
||||
}))
|
||||
|
||||
now := time.Date(2026, 5, 22, 10, 0, 0, 0, time.UTC)
|
||||
require.NoError(t, repo.IncrementUsageWithReset(ctx, userID, "anthropic", 3.5, now))
|
||||
|
||||
newD := 50.0
|
||||
require.NoError(t, repo.UpsertForUser(ctx, userID, []UserPlatformQuotaRecord{
|
||||
{UserID: userID, Platform: "anthropic", DailyLimitUSD: &newD},
|
||||
}))
|
||||
|
||||
rec, err := repo.GetByUserPlatform(ctx, userID, "anthropic")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rec)
|
||||
require.InDelta(t, 50.0, *rec.DailyLimitUSD, 1e-9, "limit should update")
|
||||
require.InDelta(t, 3.5, rec.DailyUsageUSD, 1e-9, "usage must be preserved")
|
||||
require.NotNil(t, rec.DailyWindowStart, "window_start must be preserved")
|
||||
}
|
||||
|
||||
func TestUpsertForUser_ReactivatesSoftDeleted(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
userID := mustCreateUserForQuota(t, client)
|
||||
repo := NewUserPlatformQuotaRepository(client)
|
||||
|
||||
d := 10.0
|
||||
require.NoError(t, repo.UpsertForUser(ctx, userID, []UserPlatformQuotaRecord{
|
||||
{UserID: userID, Platform: "anthropic", DailyLimitUSD: &d},
|
||||
}))
|
||||
require.NoError(t, repo.UpsertForUser(ctx, userID, []UserPlatformQuotaRecord{}))
|
||||
|
||||
gone, err := repo.GetByUserPlatform(ctx, userID, "anthropic")
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, gone, "anthropic should be soft-deleted (not active)")
|
||||
|
||||
d2 := 20.0
|
||||
require.NoError(t, repo.UpsertForUser(ctx, userID, []UserPlatformQuotaRecord{
|
||||
{UserID: userID, Platform: "anthropic", DailyLimitUSD: &d2},
|
||||
}))
|
||||
|
||||
back, err := repo.GetByUserPlatform(ctx, userID, "anthropic")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, back, "anthropic should be active again")
|
||||
require.InDelta(t, 20.0, *back.DailyLimitUSD, 1e-9)
|
||||
|
||||
allRows, err := client.UserPlatformQuota.Query().
|
||||
Where(userplatformquota.UserIDEQ(userID), userplatformquota.PlatformEQ("anthropic")).
|
||||
All(ctx)
|
||||
require.NoError(t, err)
|
||||
activeCount := 0
|
||||
for _, r := range allRows {
|
||||
if r.DeletedAt == nil {
|
||||
activeCount++
|
||||
}
|
||||
}
|
||||
require.Equal(t, 1, activeCount, "should have exactly one active row")
|
||||
}
|
||||
|
||||
func TestUpsertForUser_EmptyClearsAll(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
userID := mustCreateUserForQuota(t, client)
|
||||
repo := NewUserPlatformQuotaRepository(client)
|
||||
|
||||
d := 10.0
|
||||
require.NoError(t, repo.UpsertForUser(ctx, userID, []UserPlatformQuotaRecord{
|
||||
{UserID: userID, Platform: "anthropic", DailyLimitUSD: &d},
|
||||
{UserID: userID, Platform: "openai", DailyLimitUSD: &d},
|
||||
}))
|
||||
|
||||
require.NoError(t, repo.UpsertForUser(ctx, userID, []UserPlatformQuotaRecord{}))
|
||||
|
||||
got, err := repo.ListByUser(ctx, userID)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, got)
|
||||
}
|
||||
@ -93,6 +93,8 @@ var ProviderSet = wire.NewSet(
|
||||
NewChannelMonitorRequestTemplateRepository,
|
||||
NewContentModerationRepository,
|
||||
NewAffiliateRepository,
|
||||
NewUserPlatformQuotaRepository, // T14: user × platform quota
|
||||
NewUserPlatformQuotaServiceAdapter, // T14: adapter → service.UserPlatformQuotaRepository
|
||||
|
||||
// Cache implementations
|
||||
NewGatewayCache,
|
||||
|
||||
@ -798,6 +798,14 @@ func TestAPIContracts(t *testing.T) {
|
||||
"force_email_on_third_party_signup": false,
|
||||
"default_concurrency": 5,
|
||||
"default_balance": 1.25,
|
||||
"default_platform_quotas": {"anthropic":{"daily":null,"weekly":null,"monthly":null},"antigravity":{"daily":null,"weekly":null,"monthly":null},"gemini":{"daily":null,"weekly":null,"monthly":null},"openai":{"daily":null,"weekly":null,"monthly":null}},
|
||||
"auth_source_default_email_platform_quotas": null,
|
||||
"auth_source_default_github_platform_quotas": null,
|
||||
"auth_source_default_google_platform_quotas": null,
|
||||
"auth_source_default_linuxdo_platform_quotas": null,
|
||||
"auth_source_default_oidc_platform_quotas": null,
|
||||
"auth_source_default_wechat_platform_quotas": null,
|
||||
"auth_source_default_dingtalk_platform_quotas": null,
|
||||
"affiliate_rebate_rate": 20,
|
||||
"affiliate_rebate_freeze_hours": 0,
|
||||
"affiliate_rebate_duration_days": 0,
|
||||
@ -1025,6 +1033,14 @@ func TestAPIContracts(t *testing.T) {
|
||||
"purchase_subscription_url": "",
|
||||
"table_default_page_size": 20,
|
||||
"table_page_size_options": [10, 20, 50],
|
||||
"default_platform_quotas": {"anthropic":{"daily":null,"weekly":null,"monthly":null},"antigravity":{"daily":null,"weekly":null,"monthly":null},"gemini":{"daily":null,"weekly":null,"monthly":null},"openai":{"daily":null,"weekly":null,"monthly":null}},
|
||||
"auth_source_default_email_platform_quotas": null,
|
||||
"auth_source_default_github_platform_quotas": null,
|
||||
"auth_source_default_google_platform_quotas": null,
|
||||
"auth_source_default_linuxdo_platform_quotas": null,
|
||||
"auth_source_default_oidc_platform_quotas": null,
|
||||
"auth_source_default_wechat_platform_quotas": null,
|
||||
"auth_source_default_dingtalk_platform_quotas": null,
|
||||
"custom_menu_items": [],
|
||||
"custom_endpoints": [],
|
||||
"default_concurrency": 0,
|
||||
|
||||
@ -20,7 +20,7 @@ func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
cfg := &config.Config{JWT: config.JWTConfig{Secret: "test-secret", ExpireHour: 1}}
|
||||
authService := service.NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil)
|
||||
authService := service.NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
admin := &service.User{
|
||||
ID: 1,
|
||||
|
||||
@ -60,7 +60,7 @@ func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthSer
|
||||
cfg.JWT.AccessTokenExpireMinutes = 60
|
||||
|
||||
userRepo := &stubJWTUserRepo{users: users}
|
||||
authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil)
|
||||
authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
userSvc := service.NewUserService(userRepo, nil, nil, nil)
|
||||
mw := NewJWTAuthMiddleware(authSvc, userSvc)
|
||||
|
||||
@ -143,7 +143,7 @@ func TestJWTAuth_ValidToken_TouchesLastActive(t *testing.T) {
|
||||
cfg.JWT.AccessTokenExpireMinutes = 60
|
||||
|
||||
userRepo := &stubJWTUserRepo{users: map[int64]*service.User{1: user}}
|
||||
authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil)
|
||||
authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
userSvc := service.NewUserService(userRepo, nil, nil, nil)
|
||||
toucher := &recordingActivityToucher{}
|
||||
|
||||
|
||||
@ -241,6 +241,9 @@ func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
users.POST("/:id/replace-group", h.Admin.User.ReplaceGroup)
|
||||
users.GET("/:id/rpm-status", h.Admin.User.GetUserRPMStatus)
|
||||
users.POST("/batch-concurrency", h.Admin.User.BatchUpdateConcurrency)
|
||||
users.GET("/:id/platform-quotas", h.Admin.User.GetUserPlatformQuotas)
|
||||
users.PUT("/:id/platform-quotas", h.Admin.User.UpdateUserPlatformQuotas)
|
||||
users.POST("/:id/platform-quotas/reset", h.Admin.User.ResetUserPlatformQuotaWindow)
|
||||
|
||||
// User attribute values
|
||||
users.GET("/:id/attributes", h.Admin.UserAttribute.GetUserAttributes)
|
||||
|
||||
@ -32,6 +32,7 @@ func RegisterUserRoutes(
|
||||
user.DELETE("/account-bindings/:provider", h.User.UnbindIdentity)
|
||||
user.POST("/auth-identities/bind/start", h.User.StartIdentityBinding)
|
||||
user.GET("/api-keys/:id/usage/daily", h.Usage.GetMyAPIKeyDailyUsage)
|
||||
user.GET("/platform-quotas", h.User.GetMyPlatformQuotas)
|
||||
|
||||
// 通知邮箱管理
|
||||
notifyEmail := user.Group("/notify-email")
|
||||
|
||||
@ -459,6 +459,22 @@ func (s *billingCacheStub) InvalidateAPIKeyRateLimit(ctx context.Context, keyID
|
||||
panic("unexpected InvalidateAPIKeyRateLimit call")
|
||||
}
|
||||
|
||||
func (s *billingCacheStub) GetUserPlatformQuotaCache(ctx context.Context, userID int64, platform string) (*UserPlatformQuotaCacheEntry, bool, error) {
|
||||
panic("unexpected GetUserPlatformQuotaCache call")
|
||||
}
|
||||
|
||||
func (s *billingCacheStub) SetUserPlatformQuotaCache(ctx context.Context, userID int64, platform string, entry *UserPlatformQuotaCacheEntry, ttl time.Duration) error {
|
||||
panic("unexpected SetUserPlatformQuotaCache call")
|
||||
}
|
||||
|
||||
func (s *billingCacheStub) DeleteUserPlatformQuotaCache(ctx context.Context, userID int64, platform string) error {
|
||||
panic("unexpected DeleteUserPlatformQuotaCache call")
|
||||
}
|
||||
|
||||
func (s *billingCacheStub) IncrUserPlatformQuotaUsageCache(ctx context.Context, userID int64, platform string, cost float64, ttl time.Duration) error {
|
||||
panic("unexpected IncrUserPlatformQuotaUsageCache call")
|
||||
}
|
||||
|
||||
func waitForInvalidations(t *testing.T, ch <-chan subscriptionInvalidateCall, expected int) []subscriptionInvalidateCall {
|
||||
t.Helper()
|
||||
calls := make([]subscriptionInvalidateCall, 0, expected)
|
||||
|
||||
@ -189,6 +189,8 @@ func (s *AuthService) createEmailOAuthUser(ctx context.Context, email, username,
|
||||
}
|
||||
s.postAuthUserBootstrap(ctx, user, providerType, false)
|
||||
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
|
||||
// snapshot user × platform quota(fail-open)
|
||||
_ = s.snapshotPlatformQuotaDefaults(ctx, user.ID, &grantPlan)
|
||||
s.bindOAuthAffiliate(ctx, user.ID, affiliateCode)
|
||||
if invitationRedeemCode != nil {
|
||||
if err := s.useOAuthRegistrationInvitation(ctx, invitationRedeemCode.ID, user.ID); err != nil {
|
||||
|
||||
88
backend/internal/service/auth_email_oauth_auto_test.go
Normal file
88
backend/internal/service/auth_email_oauth_auto_test.go
Normal file
@ -0,0 +1,88 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newEmailOAuthAutoAuthService(
|
||||
userRepo UserRepository,
|
||||
settings map[string]string,
|
||||
quotaRepo UserPlatformQuotaRepository,
|
||||
) *AuthService {
|
||||
cfg := &config.Config{
|
||||
JWT: config.JWTConfig{
|
||||
Secret: "test-secret",
|
||||
ExpireHour: 1,
|
||||
AccessTokenExpireMinutes: 60,
|
||||
RefreshTokenExpireDays: 7,
|
||||
},
|
||||
Default: config.DefaultConfig{
|
||||
UserBalance: 3.5,
|
||||
UserConcurrency: 2,
|
||||
},
|
||||
}
|
||||
|
||||
settingService := NewSettingService(&settingRepoStub{values: settings}, cfg)
|
||||
|
||||
return NewAuthService(
|
||||
nil, // entClient — nil, updateUserSignupSource early return
|
||||
userRepo,
|
||||
nil, // redeemRepo — invitationCode="" 时不触发
|
||||
&refreshTokenCacheStub{},
|
||||
cfg,
|
||||
settingService,
|
||||
nil, // emailService
|
||||
nil, // turnstileService
|
||||
nil, // emailQueueService
|
||||
nil, // promoService
|
||||
nil, // defaultSubAssigner — nil, assignSubscriptions early return
|
||||
nil, // affiliateService — nil, bindOAuthAffiliate early return
|
||||
quotaRepo,
|
||||
)
|
||||
}
|
||||
|
||||
func TestEmailOAuthAuto_SnapshotsPlatformQuotaDefaults(t *testing.T) {
|
||||
userRepo := &userRepoStub{nextID: 88}
|
||||
quotaRepo := &userPlatformQuotaRepoStub{}
|
||||
|
||||
svc := newEmailOAuthAutoAuthService(
|
||||
userRepo,
|
||||
map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
SettingKeyDefaultPlatformQuotas: `{"gemini": {"monthly": 100.0}}`,
|
||||
},
|
||||
quotaRepo,
|
||||
)
|
||||
|
||||
user, err := svc.createEmailOAuthUser(
|
||||
context.Background(),
|
||||
"newoauth@example.com",
|
||||
"newoauth",
|
||||
"github",
|
||||
"", // invitationCode
|
||||
"", // affiliateCode
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, user)
|
||||
require.Equal(t, int64(88), user.ID)
|
||||
|
||||
require.Len(t, quotaRepo.bulkInsertCalls, 1, "createEmailOAuthUser must snapshot platform quotas via BulkInsertInitial")
|
||||
|
||||
records := quotaRepo.bulkInsertCalls[0]
|
||||
var geminiRecord *UserPlatformQuotaRecord
|
||||
for i := range records {
|
||||
if records[i].Platform == "gemini" {
|
||||
geminiRecord = &records[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotNil(t, geminiRecord, "expected gemini platform record")
|
||||
require.NotNil(t, geminiRecord.MonthlyLimitUSD)
|
||||
require.InDelta(t, 100.0, *geminiRecord.MonthlyLimitUSD, 0.0001)
|
||||
}
|
||||
@ -283,6 +283,8 @@ func (s *AuthService) FinalizeOAuthEmailAccount(
|
||||
s.updateOAuthSignupSource(ctx, user.ID, signupSource)
|
||||
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
|
||||
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
|
||||
// snapshot user × platform quota(fail-open)
|
||||
_ = s.snapshotPlatformQuotaDefaults(ctx, user.ID, &grantPlan)
|
||||
s.bindOAuthAffiliate(ctx, user.ID, affiliateCode)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -112,6 +112,7 @@ func newOAuthEmailFlowAuthService(
|
||||
refreshTokenCache RefreshTokenCache,
|
||||
settings map[string]string,
|
||||
emailCache EmailCache,
|
||||
quotaRepo UserPlatformQuotaRepository, // 新增
|
||||
) *AuthService {
|
||||
cfg := &config.Config{
|
||||
JWT: config.JWTConfig{
|
||||
@ -142,6 +143,7 @@ func newOAuthEmailFlowAuthService(
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
quotaRepo, // 替换原来的 nil
|
||||
)
|
||||
}
|
||||
|
||||
@ -175,6 +177,7 @@ func TestRegisterOAuthEmailAccountRollsBackCreatedUserWhenTokenPairGenerationFai
|
||||
SettingKeyEmailVerifyEnabled: "true",
|
||||
},
|
||||
emailCache,
|
||||
nil,
|
||||
)
|
||||
|
||||
tokenPair, user, err := authService.RegisterOAuthEmailAccount(
|
||||
@ -215,6 +218,7 @@ func TestRegisterOAuthEmailAccountSetsNormalizedSignupSourceOnCreatedUser(t *tes
|
||||
SettingKeyEmailVerifyEnabled: "true",
|
||||
},
|
||||
emailCache,
|
||||
nil,
|
||||
)
|
||||
|
||||
tokenPair, user, err := authService.RegisterOAuthEmailAccount(
|
||||
@ -274,6 +278,7 @@ func TestRegisterOAuthEmailAccountKeepsGitHubAndGoogleSignupSource(t *testing.T)
|
||||
SettingKeyEmailVerifyEnabled: "true",
|
||||
},
|
||||
emailCache,
|
||||
nil,
|
||||
)
|
||||
|
||||
tokenPair, user, err := authService.RegisterOAuthEmailAccount(
|
||||
@ -313,6 +318,7 @@ func TestRegisterOAuthEmailAccountFallsBackUnknownSignupSourceToEmail(t *testing
|
||||
SettingKeyEmailVerifyEnabled: "true",
|
||||
},
|
||||
emailCache,
|
||||
nil,
|
||||
)
|
||||
|
||||
tokenPair, user, err := authService.RegisterOAuthEmailAccount(
|
||||
@ -360,6 +366,7 @@ func TestRollbackOAuthEmailAccountCreationRestoresInvitationUsage(t *testing.T)
|
||||
SettingKeyInvitationCodeEnabled: "true",
|
||||
},
|
||||
&emailCacheStub{},
|
||||
nil,
|
||||
)
|
||||
|
||||
err := authService.RollbackOAuthEmailAccountCreation(context.Background(), 42, "INVITE123")
|
||||
@ -382,6 +389,7 @@ func TestRollbackOAuthEmailAccountCreationPropagatesDeleteError(t *testing.T) {
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
},
|
||||
&emailCacheStub{},
|
||||
nil,
|
||||
)
|
||||
|
||||
err := authService.RollbackOAuthEmailAccountCreation(context.Background(), 42, "")
|
||||
@ -389,3 +397,54 @@ func TestRollbackOAuthEmailAccountCreationPropagatesDeleteError(t *testing.T) {
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "delete created oauth user")
|
||||
}
|
||||
|
||||
func TestFinalizeOAuthEmailAccount_SnapshotsPlatformQuotaDefaults(t *testing.T) {
|
||||
userRepo := &userRepoStub{nextID: 99}
|
||||
quotaRepo := &userPlatformQuotaRepoStub{}
|
||||
|
||||
authService := newOAuthEmailFlowAuthService(
|
||||
userRepo,
|
||||
nil,
|
||||
&refreshTokenCacheStub{},
|
||||
map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
SettingKeyEmailVerifyEnabled: "true",
|
||||
SettingKeyDefaultPlatformQuotas: `{"anthropic": {"daily": 5.5}}`,
|
||||
},
|
||||
&emailCacheStub{},
|
||||
quotaRepo,
|
||||
)
|
||||
|
||||
user := &User{
|
||||
ID: 99,
|
||||
Email: "newuser@example.com",
|
||||
Role: RoleUser,
|
||||
Status: StatusActive,
|
||||
SignupSource: "oidc",
|
||||
}
|
||||
|
||||
err := authService.FinalizeOAuthEmailAccount(
|
||||
context.Background(),
|
||||
user,
|
||||
"",
|
||||
"oidc",
|
||||
"",
|
||||
)
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, quotaRepo.bulkInsertCalls, 1, "snapshotPlatformQuotaDefaults must call BulkInsertInitial once on successful OAuth signup")
|
||||
|
||||
records := quotaRepo.bulkInsertCalls[0]
|
||||
var anthropicRecord *UserPlatformQuotaRecord
|
||||
for i := range records {
|
||||
if records[i].Platform == "anthropic" {
|
||||
anthropicRecord = &records[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotNil(t, anthropicRecord, "expected anthropic platform record")
|
||||
require.Equal(t, int64(99), anthropicRecord.UserID)
|
||||
require.NotNil(t, anthropicRecord.DailyLimitUSD)
|
||||
require.InDelta(t, 5.5, *anthropicRecord.DailyLimitUSD, 0.0001)
|
||||
}
|
||||
|
||||
@ -62,18 +62,19 @@ type JWTClaims struct {
|
||||
|
||||
// AuthService 认证服务
|
||||
type AuthService struct {
|
||||
entClient *dbent.Client
|
||||
userRepo UserRepository
|
||||
redeemRepo RedeemCodeRepository
|
||||
refreshTokenCache RefreshTokenCache
|
||||
cfg *config.Config
|
||||
settingService *SettingService
|
||||
emailService *EmailService
|
||||
turnstileService *TurnstileService
|
||||
emailQueueService *EmailQueueService
|
||||
promoService *PromoService
|
||||
affiliateService *AffiliateService
|
||||
defaultSubAssigner DefaultSubscriptionAssigner
|
||||
entClient *dbent.Client
|
||||
userRepo UserRepository
|
||||
redeemRepo RedeemCodeRepository
|
||||
refreshTokenCache RefreshTokenCache
|
||||
cfg *config.Config
|
||||
settingService *SettingService
|
||||
emailService *EmailService
|
||||
turnstileService *TurnstileService
|
||||
emailQueueService *EmailQueueService
|
||||
promoService *PromoService
|
||||
affiliateService *AffiliateService
|
||||
defaultSubAssigner DefaultSubscriptionAssigner
|
||||
userPlatformQuotaRepo UserPlatformQuotaRepository
|
||||
}
|
||||
|
||||
type DefaultSubscriptionAssigner interface {
|
||||
@ -81,9 +82,10 @@ type DefaultSubscriptionAssigner interface {
|
||||
}
|
||||
|
||||
type signupGrantPlan struct {
|
||||
Balance float64
|
||||
Concurrency int
|
||||
Subscriptions []DefaultSubscriptionSetting
|
||||
Balance float64
|
||||
Concurrency int
|
||||
Subscriptions []DefaultSubscriptionSetting
|
||||
PlatformQuotas map[string]*DefaultPlatformQuotaSetting
|
||||
}
|
||||
|
||||
// NewAuthService 创建认证服务实例
|
||||
@ -100,20 +102,22 @@ func NewAuthService(
|
||||
promoService *PromoService,
|
||||
defaultSubAssigner DefaultSubscriptionAssigner,
|
||||
affiliateService *AffiliateService,
|
||||
userPlatformQuotaRepo UserPlatformQuotaRepository,
|
||||
) *AuthService {
|
||||
return &AuthService{
|
||||
entClient: entClient,
|
||||
userRepo: userRepo,
|
||||
redeemRepo: redeemRepo,
|
||||
refreshTokenCache: refreshTokenCache,
|
||||
cfg: cfg,
|
||||
settingService: settingService,
|
||||
emailService: emailService,
|
||||
turnstileService: turnstileService,
|
||||
emailQueueService: emailQueueService,
|
||||
promoService: promoService,
|
||||
affiliateService: affiliateService,
|
||||
defaultSubAssigner: defaultSubAssigner,
|
||||
entClient: entClient,
|
||||
userRepo: userRepo,
|
||||
redeemRepo: redeemRepo,
|
||||
refreshTokenCache: refreshTokenCache,
|
||||
cfg: cfg,
|
||||
settingService: settingService,
|
||||
emailService: emailService,
|
||||
turnstileService: turnstileService,
|
||||
emailQueueService: emailQueueService,
|
||||
promoService: promoService,
|
||||
affiliateService: affiliateService,
|
||||
defaultSubAssigner: defaultSubAssigner,
|
||||
userPlatformQuotaRepo: userPlatformQuotaRepo,
|
||||
}
|
||||
}
|
||||
|
||||
@ -226,6 +230,8 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
|
||||
}
|
||||
s.postAuthUserBootstrap(ctx, user, "email", true)
|
||||
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
|
||||
// snapshot user × platform quota(fail-open)
|
||||
_ = s.snapshotPlatformQuotaDefaults(ctx, user.ID, &grantPlan)
|
||||
if s.affiliateService != nil {
|
||||
if _, err := s.affiliateService.EnsureUserAffiliate(ctx, user.ID); err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to initialize affiliate profile for user %d: %v", user.ID, err)
|
||||
@ -535,6 +541,8 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
|
||||
user = newUser
|
||||
s.postAuthUserBootstrap(ctx, user, signupSource, false)
|
||||
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
|
||||
// snapshot user × platform quota(fail-open)
|
||||
_ = s.snapshotPlatformQuotaDefaults(ctx, user.ID, &grantPlan)
|
||||
}
|
||||
} else {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Database error during oauth login: %v", err)
|
||||
@ -685,6 +693,8 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
|
||||
user = newUser
|
||||
s.postAuthUserBootstrap(ctx, user, signupSource, false)
|
||||
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
|
||||
// snapshot user × platform quota(fail-open)
|
||||
_ = s.snapshotPlatformQuotaDefaults(ctx, user.ID, &grantPlan)
|
||||
s.bindOAuthAffiliate(ctx, user.ID, affiliateCode)
|
||||
}
|
||||
} else {
|
||||
@ -703,6 +713,8 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
|
||||
user = newUser
|
||||
s.postAuthUserBootstrap(ctx, user, signupSource, false)
|
||||
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
|
||||
// snapshot user × platform quota(fail-open)
|
||||
_ = s.snapshotPlatformQuotaDefaults(ctx, user.ID, &grantPlan)
|
||||
s.bindOAuthAffiliate(ctx, user.ID, affiliateCode)
|
||||
if invitationRedeemCode != nil {
|
||||
if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
|
||||
@ -764,18 +776,39 @@ func (s *AuthService) resolveSignupGrantPlan(ctx context.Context, signupSource s
|
||||
plan.Concurrency = s.settingService.GetDefaultConcurrency(ctx)
|
||||
plan.Subscriptions = s.settingService.GetDefaultSubscriptions(ctx)
|
||||
|
||||
// ============ 全局 quota 装载(必须在 ResolveAuthSourceGrantSettings 之前) ============
|
||||
// 无论 auth source 是否 enabled,全局层都要先装载,确保 !enabled 早退路径也携带全局 quota。
|
||||
if quotas, err := s.settingService.GetDefaultPlatformQuotas(ctx); err == nil {
|
||||
plan.PlatformQuotas = quotas
|
||||
} else {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Warning: load default platform quotas failed: %v (fail-open)", err)
|
||||
}
|
||||
// ============================================================================================
|
||||
|
||||
resolved, enabled, err := s.settingService.ResolveAuthSourceGrantSettings(ctx, signupSource, false)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to load auth source signup defaults for %s: %v", signupSource, err)
|
||||
return plan
|
||||
}
|
||||
if !enabled {
|
||||
return plan
|
||||
return plan // plan.PlatformQuotas 已含全局层
|
||||
}
|
||||
|
||||
plan.Balance = resolved.Balance
|
||||
plan.Concurrency = resolved.Concurrency
|
||||
plan.Subscriptions = resolved.Subscriptions
|
||||
|
||||
// ============ auth source quota merge(仅在 enabled 分支内) ============
|
||||
asQuotas := s.settingService.GetAuthSourcePlatformQuotas(ctx, signupSource)
|
||||
if plan.PlatformQuotas != nil {
|
||||
for platform, patch := range asQuotas {
|
||||
if dst := plan.PlatformQuotas[platform]; dst != nil {
|
||||
mergePlatformQuotaDefaults(dst, patch)
|
||||
}
|
||||
}
|
||||
}
|
||||
// ==============================================================================
|
||||
|
||||
return plan
|
||||
}
|
||||
|
||||
@ -1586,3 +1619,29 @@ func resolvedTokenVersion(user *User) int64 {
|
||||
fingerprint := int64(binary.BigEndian.Uint64(sum[:8]) & 0x7fffffffffffffff)
|
||||
return user.TokenVersion ^ fingerprint
|
||||
}
|
||||
|
||||
// snapshotPlatformQuotaDefaults 把 plan.PlatformQuotas(4 platform × 3 window)以
|
||||
// BulkInsertInitial 形式写入 user_platform_quotas 表。失败 fail-open(仅 warn log)。
|
||||
func (s *AuthService) snapshotPlatformQuotaDefaults(ctx context.Context, userID int64, plan *signupGrantPlan) error {
|
||||
if s.userPlatformQuotaRepo == nil || plan == nil || len(plan.PlatformQuotas) == 0 {
|
||||
return nil
|
||||
}
|
||||
records := make([]UserPlatformQuotaRecord, 0, len(plan.PlatformQuotas))
|
||||
for platform, q := range plan.PlatformQuotas {
|
||||
rec := UserPlatformQuotaRecord{
|
||||
UserID: userID,
|
||||
Platform: platform,
|
||||
}
|
||||
if q != nil {
|
||||
rec.DailyLimitUSD = q.DailyLimitUSD
|
||||
rec.WeeklyLimitUSD = q.WeeklyLimitUSD
|
||||
rec.MonthlyLimitUSD = q.MonthlyLimitUSD
|
||||
}
|
||||
records = append(records, rec)
|
||||
}
|
||||
if err := s.userPlatformQuotaRepo.BulkInsertInitial(ctx, records); err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Warning: snapshot platform quota failed user=%d: %v (fail-open)", userID, err)
|
||||
return nil // fail-open:返回 nil,让调用方继续
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -110,7 +110,7 @@ CREATE TABLE IF NOT EXISTS user_provider_default_grants (
|
||||
emailSvc = service.NewEmailService(settingRepo, emailCache)
|
||||
}
|
||||
|
||||
svc := service.NewAuthService(client, repo, nil, refreshTokenCache, cfg, settingSvc, emailSvc, nil, nil, nil, defaultSubAssigner, nil)
|
||||
svc := service.NewAuthService(client, repo, nil, refreshTokenCache, cfg, settingSvc, emailSvc, nil, nil, nil, defaultSubAssigner, nil, nil)
|
||||
return svc, repo, client
|
||||
}
|
||||
|
||||
@ -467,7 +467,7 @@ func TestAuthServiceBindEmailIdentity_RevokesExistingAccessAndRefreshTokens(t *t
|
||||
},
|
||||
}
|
||||
emailService := service.NewEmailService(nil, cache)
|
||||
svc := service.NewAuthService(nil, userRepo, nil, refreshTokenCache, cfg, nil, emailService, nil, nil, nil, nil, nil)
|
||||
svc := service.NewAuthService(nil, userRepo, nil, refreshTokenCache, cfg, nil, emailService, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
oldTokenPair, err := svc.GenerateTokenPair(ctx, &service.User{
|
||||
ID: 41,
|
||||
@ -810,8 +810,8 @@ func (s *emailBindUserRepoStub) UpdateUserLastActiveAt(context.Context, int64, t
|
||||
}
|
||||
|
||||
func (s *emailBindUserRepoStub) UpdateBalance(context.Context, int64, float64) error { return nil }
|
||||
func (s *emailBindUserRepoStub) DeductBalance(context.Context, int64, float64) error { return nil }
|
||||
func (s *emailBindUserRepoStub) UpdateConcurrency(context.Context, int64, int) error { return nil }
|
||||
func (s *emailBindUserRepoStub) DeductBalance(context.Context, int64, float64) error { return nil }
|
||||
func (s *emailBindUserRepoStub) UpdateConcurrency(context.Context, int64, int) error { return nil }
|
||||
|
||||
func (s *emailBindUserRepoStub) ExistsByEmail(_ context.Context, email string) (bool, error) {
|
||||
s.mu.Lock()
|
||||
@ -820,8 +820,12 @@ func (s *emailBindUserRepoStub) ExistsByEmail(_ context.Context, email string) (
|
||||
return ok, nil
|
||||
}
|
||||
|
||||
func (s *emailBindUserRepoStub) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
|
||||
func (s *emailBindUserRepoStub) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil }
|
||||
func (s *emailBindUserRepoStub) BatchSetConcurrency(context.Context, []int64, int) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (s *emailBindUserRepoStub) BatchAddConcurrency(context.Context, []int64, int) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *emailBindUserRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
|
||||
return 0, nil
|
||||
|
||||
@ -137,7 +137,7 @@ CREATE TABLE IF NOT EXISTS user_provider_default_grants (
|
||||
values: settings,
|
||||
}, cfg)
|
||||
|
||||
svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, nil, nil, nil, nil, defaultSubAssigner, nil)
|
||||
svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, nil, nil, nil, nil, defaultSubAssigner, nil, nil)
|
||||
return svc, repo, client
|
||||
}
|
||||
|
||||
|
||||
157
backend/internal/service/auth_service_platform_quota_test.go
Normal file
157
backend/internal/service/auth_service_platform_quota_test.go
Normal file
@ -0,0 +1,157 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// fakeInsertRecorder 记录 BulkInsertInitial 调用,实现 UserPlatformQuotaRepository port。
|
||||
type fakeInsertRecorder struct {
|
||||
records []UserPlatformQuotaRecord
|
||||
err error
|
||||
}
|
||||
|
||||
func (f *fakeInsertRecorder) GetByUserPlatform(_ context.Context, _ int64, _ string) (*UserPlatformQuotaRecord, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f *fakeInsertRecorder) BulkInsertInitial(_ context.Context, recs []UserPlatformQuotaRecord) error {
|
||||
if f.err != nil {
|
||||
return f.err
|
||||
}
|
||||
f.records = append(f.records, recs...)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeInsertRecorder) IncrementUsageWithReset(_ context.Context, _ int64, _ string, _ float64, _ time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeInsertRecorder) ListByUser(_ context.Context, _ int64) ([]UserPlatformQuotaRecord, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f *fakeInsertRecorder) UpsertForUser(_ context.Context, _ int64, _ []UserPlatformQuotaRecord) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeInsertRecorder) ResetExpiredWindow(_ context.Context, _ int64, _ string, _ string, _ time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestSnapshotPlatformQuotaDefaults_PassesToRepoBulkInsert(t *testing.T) {
|
||||
fakeRepo := &fakeInsertRecorder{}
|
||||
s := &AuthService{userPlatformQuotaRepo: fakeRepo}
|
||||
|
||||
five := 5.0
|
||||
plan := &signupGrantPlan{
|
||||
PlatformQuotas: map[string]*DefaultPlatformQuotaSetting{
|
||||
"anthropic": {DailyLimitUSD: &five},
|
||||
"openai": {},
|
||||
"gemini": {},
|
||||
"antigravity": {},
|
||||
},
|
||||
}
|
||||
if err := s.snapshotPlatformQuotaDefaults(context.Background(), 999, plan); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(fakeRepo.records) != 4 {
|
||||
t.Errorf("expected 4 records, got %d", len(fakeRepo.records))
|
||||
}
|
||||
found := false
|
||||
for _, r := range fakeRepo.records {
|
||||
if r.UserID == 999 && r.Platform == "anthropic" && r.DailyLimitUSD != nil && *r.DailyLimitUSD == 5 {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("anthropic daily = 5 not snapshotted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSnapshotPlatformQuotaDefaults_NilPlanIsNoop(t *testing.T) {
|
||||
fakeRepo := &fakeInsertRecorder{}
|
||||
s := &AuthService{userPlatformQuotaRepo: fakeRepo}
|
||||
if err := s.snapshotPlatformQuotaDefaults(context.Background(), 1, nil); err != nil {
|
||||
t.Errorf("nil plan should be noop, got %v", err)
|
||||
}
|
||||
if len(fakeRepo.records) != 0 {
|
||||
t.Errorf("expected no records, got %d", len(fakeRepo.records))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSnapshotPlatformQuotaDefaults_RepoErrorFailsOpen(t *testing.T) {
|
||||
fakeRepo := &fakeInsertRecorder{err: fmt.Errorf("db down")}
|
||||
s := &AuthService{userPlatformQuotaRepo: fakeRepo}
|
||||
five := 5.0
|
||||
plan := &signupGrantPlan{
|
||||
PlatformQuotas: map[string]*DefaultPlatformQuotaSetting{
|
||||
"anthropic": {DailyLimitUSD: &five},
|
||||
},
|
||||
}
|
||||
if err := s.snapshotPlatformQuotaDefaults(context.Background(), 1, plan); err != nil {
|
||||
t.Errorf("fail-open: expected nil even on repo error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSnapshotPlatformQuotaDefaults_NilRepoIsNoop(t *testing.T) {
|
||||
s := &AuthService{userPlatformQuotaRepo: nil}
|
||||
five := 5.0
|
||||
plan := &signupGrantPlan{
|
||||
PlatformQuotas: map[string]*DefaultPlatformQuotaSetting{"a": {DailyLimitUSD: &five}},
|
||||
}
|
||||
if err := s.snapshotPlatformQuotaDefaults(context.Background(), 1, plan); err != nil {
|
||||
t.Errorf("nil repo should be noop, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// resolveSignupGrantPlan 测试:依赖完整的 AuthService 构造,需要 SettingService(含 settingRepoStub)。
|
||||
// settingRepoStub 已在 auth_service_register_test.go 中定义,同 package 可直接使用。
|
||||
func TestResolveSignupGrantPlan_GlobalQuotaLoadedBeforeAuthSource(t *testing.T) {
|
||||
// 全局 quota JSON key(新格式)
|
||||
settings := map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
SettingKeyDefaultPlatformQuotas: `{
|
||||
"anthropic": {"daily": 10, "weekly": 50, "monthly": 200},
|
||||
"openai": {"daily": 5, "weekly": 25, "monthly": 100},
|
||||
"gemini": {"daily": 5, "weekly": 25, "monthly": 100},
|
||||
"antigravity": {"daily": 5, "weekly": 25, "monthly": 100}
|
||||
}`,
|
||||
}
|
||||
svc := newAuthService(nil, settings, nil, nil)
|
||||
plan := svc.resolveSignupGrantPlan(context.Background(), "email")
|
||||
if plan.PlatformQuotas == nil {
|
||||
t.Fatal("expected PlatformQuotas to be non-nil after loading global quota KVs")
|
||||
}
|
||||
q := plan.PlatformQuotas["anthropic"]
|
||||
if q == nil {
|
||||
t.Fatal("expected anthropic quota to be set")
|
||||
}
|
||||
if q.DailyLimitUSD == nil || *q.DailyLimitUSD != 10 {
|
||||
t.Errorf("expected anthropic daily=10, got %v", q.DailyLimitUSD)
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolveSignupGrantPlan_DisabledAuthSourceStillCarriesGlobalQuota 验证 P1 约束:
|
||||
// !enabled 早退路径仍携带全局 quota(GetDefaultPlatformQuotas 在 ResolveAuthSourceGrantSettings 之前)。
|
||||
func TestResolveSignupGrantPlan_DisabledAuthSourceStillCarriesGlobalQuota(t *testing.T) {
|
||||
settings := map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
// auth source 不配置(=> !enabled 路径)
|
||||
SettingKeyDefaultPlatformQuotas: `{"anthropic": {"daily": 10, "weekly": 50, "monthly": 200}}`,
|
||||
}
|
||||
svc := newAuthService(nil, settings, nil, nil)
|
||||
plan := svc.resolveSignupGrantPlan(context.Background(), "email")
|
||||
// !enabled 路径:plan.PlatformQuotas 应已含全局层(不是 nil)
|
||||
if plan.PlatformQuotas == nil {
|
||||
t.Fatal("P1 violated: PlatformQuotas is nil even with global quota KVs set")
|
||||
}
|
||||
// P1 核心断言:disabled auth source 路径不能丢失全局 quota
|
||||
if _, ok := plan.PlatformQuotas["anthropic"]; !ok {
|
||||
t.Error("P1 violated: disabled auth source path dropped global platform quota")
|
||||
}
|
||||
}
|
||||
@ -73,6 +73,38 @@ type defaultSubscriptionAssignerStub struct {
|
||||
|
||||
type refreshTokenCacheStub struct{}
|
||||
|
||||
type userPlatformQuotaRepoStub struct {
|
||||
bulkInsertCalls [][]UserPlatformQuotaRecord
|
||||
bulkInsertErr error
|
||||
}
|
||||
|
||||
func (s *userPlatformQuotaRepoStub) BulkInsertInitial(_ context.Context, records []UserPlatformQuotaRecord) error {
|
||||
cloned := make([]UserPlatformQuotaRecord, len(records))
|
||||
copy(cloned, records)
|
||||
s.bulkInsertCalls = append(s.bulkInsertCalls, cloned)
|
||||
return s.bulkInsertErr
|
||||
}
|
||||
|
||||
func (s *userPlatformQuotaRepoStub) GetByUserPlatform(context.Context, int64, string) (*UserPlatformQuotaRecord, error) {
|
||||
panic("unexpected GetByUserPlatform call")
|
||||
}
|
||||
|
||||
func (s *userPlatformQuotaRepoStub) ListByUser(context.Context, int64) ([]UserPlatformQuotaRecord, error) {
|
||||
panic("unexpected ListByUser call")
|
||||
}
|
||||
|
||||
func (s *userPlatformQuotaRepoStub) IncrementUsageWithReset(context.Context, int64, string, float64, time.Time) error {
|
||||
panic("unexpected IncrementUsageWithReset call")
|
||||
}
|
||||
|
||||
func (s *userPlatformQuotaRepoStub) UpsertForUser(context.Context, int64, []UserPlatformQuotaRecord) error {
|
||||
panic("unexpected UpsertForUser call")
|
||||
}
|
||||
|
||||
func (s *userPlatformQuotaRepoStub) ResetExpiredWindow(context.Context, int64, string, string, time.Time) error {
|
||||
panic("unexpected ResetExpiredWindow call")
|
||||
}
|
||||
|
||||
func (s *defaultSubscriptionAssignerStub) AssignOrExtendSubscription(_ context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error) {
|
||||
if input != nil {
|
||||
s.calls = append(s.calls, *input)
|
||||
@ -178,7 +210,7 @@ func (s *emailCacheStub) IncrNotifyCodeUserRate(ctx context.Context, userID int6
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func newAuthService(repo *userRepoStub, settings map[string]string, emailCache EmailCache) *AuthService {
|
||||
func newAuthService(repo *userRepoStub, settings map[string]string, emailCache EmailCache, quotaRepo UserPlatformQuotaRepository) *AuthService {
|
||||
cfg := &config.Config{
|
||||
JWT: config.JWTConfig{
|
||||
Secret: "test-secret",
|
||||
@ -213,6 +245,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E
|
||||
nil, // promoService
|
||||
nil, // defaultSubAssigner
|
||||
nil, // affiliateService
|
||||
quotaRepo,
|
||||
)
|
||||
}
|
||||
|
||||
@ -220,7 +253,7 @@ func TestAuthService_Register_Disabled(t *testing.T) {
|
||||
repo := &userRepoStub{}
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "false",
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
|
||||
_, _, err := service.Register(context.Background(), "user@test.com", "password")
|
||||
require.ErrorIs(t, err, ErrRegDisabled)
|
||||
@ -229,19 +262,62 @@ func TestAuthService_Register_Disabled(t *testing.T) {
|
||||
func TestAuthService_Register_DisabledByDefault(t *testing.T) {
|
||||
// 当 settings 为 nil(设置项不存在)时,注册应该默认关闭
|
||||
repo := &userRepoStub{}
|
||||
service := newAuthService(repo, nil, nil)
|
||||
service := newAuthService(repo, nil, nil, nil)
|
||||
|
||||
_, _, err := service.Register(context.Background(), "user@test.com", "password")
|
||||
require.ErrorIs(t, err, ErrRegDisabled)
|
||||
}
|
||||
|
||||
func TestAuthService_Register_SnapshotsPlatformQuotaDefaults(t *testing.T) {
|
||||
repo := &userRepoStub{nextID: 77}
|
||||
quotaRepo := &userPlatformQuotaRepoStub{}
|
||||
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
SettingKeyDefaultPlatformQuotas: `{"openai": {"weekly": 12.34}}`,
|
||||
}, nil, quotaRepo)
|
||||
|
||||
_, user, err := service.Register(context.Background(), "newuser@test.com", "password")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, user)
|
||||
|
||||
require.Len(t, quotaRepo.bulkInsertCalls, 1)
|
||||
|
||||
records := quotaRepo.bulkInsertCalls[0]
|
||||
var openaiRecord *UserPlatformQuotaRecord
|
||||
for i := range records {
|
||||
if records[i].Platform == "openai" {
|
||||
openaiRecord = &records[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotNil(t, openaiRecord, "expected openai platform record")
|
||||
require.Equal(t, int64(77), openaiRecord.UserID)
|
||||
require.NotNil(t, openaiRecord.WeeklyLimitUSD)
|
||||
require.InDelta(t, 12.34, *openaiRecord.WeeklyLimitUSD, 0.0001)
|
||||
}
|
||||
|
||||
func TestAuthService_Register_DoesNotSnapshotOnDisabled(t *testing.T) {
|
||||
repo := &userRepoStub{}
|
||||
quotaRepo := &userPlatformQuotaRepoStub{}
|
||||
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "false",
|
||||
}, nil, quotaRepo)
|
||||
|
||||
_, _, err := service.Register(context.Background(), "user@test.com", "password")
|
||||
require.ErrorIs(t, err, ErrRegDisabled)
|
||||
|
||||
require.Empty(t, quotaRepo.bulkInsertCalls, "registration rejected before user creation must not snapshot")
|
||||
}
|
||||
|
||||
func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testing.T) {
|
||||
repo := &userRepoStub{}
|
||||
// 邮件验证开启但 emailCache 为 nil(emailService 未配置)
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
SettingKeyEmailVerifyEnabled: "true",
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
|
||||
// 应返回服务不可用错误,而不是允许绕过验证
|
||||
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code", "", "", "")
|
||||
@ -254,7 +330,7 @@ func TestAuthService_Register_EmailVerifyRequired(t *testing.T) {
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
SettingKeyEmailVerifyEnabled: "true",
|
||||
}, cache)
|
||||
}, cache, nil)
|
||||
|
||||
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "", "", "", "")
|
||||
require.ErrorIs(t, err, ErrEmailVerifyRequired)
|
||||
@ -268,7 +344,7 @@ func TestAuthService_Register_EmailVerifyInvalid(t *testing.T) {
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
SettingKeyEmailVerifyEnabled: "true",
|
||||
}, cache)
|
||||
}, cache, nil)
|
||||
|
||||
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong", "", "", "")
|
||||
require.ErrorIs(t, err, ErrInvalidVerifyCode)
|
||||
@ -279,7 +355,7 @@ func TestAuthService_Register_EmailExists(t *testing.T) {
|
||||
repo := &userRepoStub{exists: true}
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
|
||||
_, _, err := service.Register(context.Background(), "user@test.com", "password")
|
||||
require.ErrorIs(t, err, ErrEmailExists)
|
||||
@ -289,7 +365,7 @@ func TestAuthService_Register_CheckEmailError(t *testing.T) {
|
||||
repo := &userRepoStub{existsErr: errors.New("db down")}
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
|
||||
_, _, err := service.Register(context.Background(), "user@test.com", "password")
|
||||
require.ErrorIs(t, err, ErrServiceUnavailable)
|
||||
@ -299,7 +375,7 @@ func TestAuthService_Register_ReservedEmail(t *testing.T) {
|
||||
repo := &userRepoStub{}
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
|
||||
_, _, err := service.Register(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "password")
|
||||
require.ErrorIs(t, err, ErrEmailReserved)
|
||||
@ -310,7 +386,7 @@ func TestAuthService_Register_EmailSuffixNotAllowed(t *testing.T) {
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
SettingKeyRegistrationEmailSuffixWhitelist: `["@example.com","@company.com"]`,
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
|
||||
_, _, err := service.Register(context.Background(), "user@other.com", "password")
|
||||
require.ErrorIs(t, err, ErrEmailSuffixNotAllowed)
|
||||
@ -327,7 +403,7 @@ func TestAuthService_Register_EmailSuffixAllowed(t *testing.T) {
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
SettingKeyRegistrationEmailSuffixWhitelist: `["example.com"]`,
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
|
||||
_, user, err := service.Register(context.Background(), "user@example.com", "password")
|
||||
require.NoError(t, err)
|
||||
@ -340,7 +416,7 @@ func TestAuthService_SendVerifyCode_EmailSuffixNotAllowed(t *testing.T) {
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
SettingKeyRegistrationEmailSuffixWhitelist: `["@example.com","@company.com"]`,
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
|
||||
err := service.SendVerifyCode(context.Background(), "user@other.com")
|
||||
require.ErrorIs(t, err, ErrEmailSuffixNotAllowed)
|
||||
@ -354,7 +430,7 @@ func TestAuthService_Register_CreateError(t *testing.T) {
|
||||
repo := &userRepoStub{createErr: errors.New("create failed")}
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
|
||||
_, _, err := service.Register(context.Background(), "user@test.com", "password")
|
||||
require.ErrorIs(t, err, ErrServiceUnavailable)
|
||||
@ -365,7 +441,7 @@ func TestAuthService_Register_CreateEmailExistsRace(t *testing.T) {
|
||||
repo := &userRepoStub{createErr: ErrEmailExists}
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
|
||||
_, _, err := service.Register(context.Background(), "user@test.com", "password")
|
||||
require.ErrorIs(t, err, ErrEmailExists)
|
||||
@ -376,7 +452,7 @@ func TestAuthService_Register_Success(t *testing.T) {
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false",
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
|
||||
token, user, err := service.Register(context.Background(), "user@test.com", "password")
|
||||
require.NoError(t, err)
|
||||
@ -394,7 +470,7 @@ func TestAuthService_Register_Success(t *testing.T) {
|
||||
|
||||
func TestAuthService_ValidateToken_ExpiredReturnsClaimsWithError(t *testing.T) {
|
||||
repo := &userRepoStub{}
|
||||
service := newAuthService(repo, nil, nil)
|
||||
service := newAuthService(repo, nil, nil, nil)
|
||||
|
||||
// 创建用户并生成 token
|
||||
user := &User{
|
||||
@ -436,7 +512,7 @@ func TestAuthService_RefreshToken_ExpiredTokenNoPanic(t *testing.T) {
|
||||
TokenVersion: 1,
|
||||
}
|
||||
repo := &userRepoStub{user: user}
|
||||
service := newAuthService(repo, nil, nil)
|
||||
service := newAuthService(repo, nil, nil, nil)
|
||||
|
||||
// 创建过期 token
|
||||
service.cfg.JWT.ExpireHour = -1
|
||||
@ -453,7 +529,7 @@ func TestAuthService_RefreshToken_ExpiredTokenNoPanic(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAuthService_GetAccessTokenExpiresIn_FallbackToExpireHour(t *testing.T) {
|
||||
service := newAuthService(&userRepoStub{}, nil, nil)
|
||||
service := newAuthService(&userRepoStub{}, nil, nil, nil)
|
||||
service.cfg.JWT.ExpireHour = 24
|
||||
service.cfg.JWT.AccessTokenExpireMinutes = 0
|
||||
|
||||
@ -461,7 +537,7 @@ func TestAuthService_GetAccessTokenExpiresIn_FallbackToExpireHour(t *testing.T)
|
||||
}
|
||||
|
||||
func TestAuthService_GetAccessTokenExpiresIn_MinutesHasPriority(t *testing.T) {
|
||||
service := newAuthService(&userRepoStub{}, nil, nil)
|
||||
service := newAuthService(&userRepoStub{}, nil, nil, nil)
|
||||
service.cfg.JWT.ExpireHour = 24
|
||||
service.cfg.JWT.AccessTokenExpireMinutes = 90
|
||||
|
||||
@ -469,7 +545,7 @@ func TestAuthService_GetAccessTokenExpiresIn_MinutesHasPriority(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAuthService_GenerateToken_UsesExpireHourWhenMinutesZero(t *testing.T) {
|
||||
service := newAuthService(&userRepoStub{}, nil, nil)
|
||||
service := newAuthService(&userRepoStub{}, nil, nil, nil)
|
||||
service.cfg.JWT.ExpireHour = 24
|
||||
service.cfg.JWT.AccessTokenExpireMinutes = 0
|
||||
|
||||
@ -494,7 +570,7 @@ func TestAuthService_GenerateToken_UsesExpireHourWhenMinutesZero(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAuthService_GenerateToken_UsesMinutesWhenConfigured(t *testing.T) {
|
||||
service := newAuthService(&userRepoStub{}, nil, nil)
|
||||
service := newAuthService(&userRepoStub{}, nil, nil, nil)
|
||||
service.cfg.JWT.ExpireHour = 24
|
||||
service.cfg.JWT.AccessTokenExpireMinutes = 90
|
||||
|
||||
@ -525,7 +601,7 @@ func TestAuthService_Register_AssignsDefaultSubscriptions(t *testing.T) {
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
SettingKeyDefaultSubscriptions: `[{"group_id":11,"validity_days":30},{"group_id":12,"validity_days":7}]`,
|
||||
SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false",
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
service.defaultSubAssigner = assigner
|
||||
|
||||
_, user, err := service.Register(context.Background(), "default-sub@test.com", "password")
|
||||
@ -549,7 +625,7 @@ func TestAuthService_Register_UsesEmailAuthSourceDefaultsWhenGrantEnabled(t *tes
|
||||
SettingKeyAuthSourceDefaultEmailConcurrency: "7",
|
||||
SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
|
||||
SettingKeyAuthSourceDefaultEmailGrantOnSignup: "true",
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
service.defaultSubAssigner = assigner
|
||||
|
||||
_, user, err := service.Register(context.Background(), "email-defaults@test.com", "password")
|
||||
@ -572,7 +648,7 @@ func TestAuthService_Register_GrantOnSignupFalseFallsBackToGlobalDefaults(t *tes
|
||||
SettingKeyAuthSourceDefaultEmailConcurrency: "88",
|
||||
SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":32,"validity_days":9}]`,
|
||||
SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false",
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
service.defaultSubAssigner = assigner
|
||||
|
||||
_, user, err := service.Register(context.Background(), "email-global@test.com", "password")
|
||||
@ -595,7 +671,7 @@ func TestAuthService_Register_GrantOnSignupMergesSourceOverridesWithGlobalDefaul
|
||||
SettingKeyAuthSourceDefaultEmailConcurrency: "5",
|
||||
SettingKeyAuthSourceDefaultEmailSubscriptions: `[]`,
|
||||
SettingKeyAuthSourceDefaultEmailGrantOnSignup: "true",
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
service.defaultSubAssigner = assigner
|
||||
|
||||
_, user, err := service.Register(context.Background(), "email-merged@test.com", "password")
|
||||
@ -618,7 +694,7 @@ func TestAuthService_LoginOrRegisterOAuthWithTokenPair_UsesLinuxDoAuthSourceDefa
|
||||
SettingKeyAuthSourceDefaultLinuxDoConcurrency: "9",
|
||||
SettingKeyAuthSourceDefaultLinuxDoSubscriptions: `[{"group_id":22,"validity_days":14}]`,
|
||||
SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup: "true",
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
service.defaultSubAssigner = assigner
|
||||
service.refreshTokenCache = &refreshTokenCacheStub{}
|
||||
|
||||
@ -654,7 +730,7 @@ func TestAuthService_LoginOrRegisterOAuthWithTokenPair_ExistingUserDoesNotGrantA
|
||||
SettingKeyAuthSourceDefaultLinuxDoConcurrency: "9",
|
||||
SettingKeyAuthSourceDefaultLinuxDoSubscriptions: `[{"group_id":22,"validity_days":14}]`,
|
||||
SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup: "true",
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
service.defaultSubAssigner = assigner
|
||||
service.refreshTokenCache = &refreshTokenCacheStub{}
|
||||
|
||||
@ -677,7 +753,7 @@ func newAuthServiceWithDingTalkCfg(settings map[string]string, dtCfg config.Ding
|
||||
DingTalk: dtCfg,
|
||||
}
|
||||
settingService := NewSettingService(&settingRepoStub{values: settings}, cfg)
|
||||
return NewAuthService(nil, nil, nil, nil, cfg, settingService, nil, nil, nil, nil, nil, nil)
|
||||
return NewAuthService(nil, nil, nil, nil, cfg, settingService, nil, nil, nil, nil, nil, nil, nil)
|
||||
}
|
||||
|
||||
// minDingTalkURLs 返回一个包含必填字段的基础 DingTalkConnectConfig(不设 Enabled/BypassRegistration/Policy)。
|
||||
|
||||
@ -55,6 +55,7 @@ func newAuthServiceForRegisterTurnstileTest(settings map[string]string, verifier
|
||||
nil, // promoService
|
||||
nil, // defaultSubAssigner
|
||||
nil, // affiliateService
|
||||
nil, // userPlatformQuotaRepo
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@ -11,18 +11,31 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
// 错误定义
|
||||
// 注:ErrInsufficientBalance在redeem_service.go中定义
|
||||
// 注:ErrDailyLimitExceeded/ErrWeeklyLimitExceeded/ErrMonthlyLimitExceeded在subscription_service.go中定义
|
||||
// errBillingCacheUnavailable 内部哨兵:用于 quota 校验路径在 cache==nil 时
|
||||
// 与"Redis 故障"走同一条 fail-open + DB 一次性检查的分支。
|
||||
var errBillingCacheUnavailable = fmt.Errorf("billing cache unavailable")
|
||||
|
||||
var (
|
||||
ErrSubscriptionInvalid = infraerrors.Forbidden("SUBSCRIPTION_INVALID", "subscription is invalid or expired")
|
||||
ErrBillingServiceUnavailable = infraerrors.ServiceUnavailable("BILLING_SERVICE_ERROR", "Billing service temporarily unavailable. Please retry later.")
|
||||
// RPM 超限错误。gateway_handler 负责映射为 HTTP 429。
|
||||
ErrGroupRPMExceeded = infraerrors.TooManyRequests("GROUP_RPM_EXCEEDED", "group requests-per-minute limit exceeded")
|
||||
ErrUserRPMExceeded = infraerrors.TooManyRequests("USER_RPM_EXCEEDED", "user requests-per-minute limit exceeded")
|
||||
|
||||
// user × platform quota(HTTP 429 Too Many Requests + Retry-After header)。
|
||||
// 选用 429 而非 403:限额耗尽属于"暂时性资源用尽,重试可恢复"的场景(RFC 6585),
|
||||
// 大量 SDK(如 OpenAI 兼容客户端)只对 429 触发自动退避并读取 Retry-After,
|
||||
// 用 403 会被视为"权限不足,重试无意义"导致客户端直接报错且不退避。
|
||||
ErrUserPlatformDailyQuotaExhausted = infraerrors.TooManyRequests("USER_PLATFORM_DAILY_QUOTA_EXHAUSTED", "Daily usage quota exhausted for this platform.")
|
||||
ErrUserPlatformWeeklyQuotaExhausted = infraerrors.TooManyRequests("USER_PLATFORM_WEEKLY_QUOTA_EXHAUSTED", "Weekly usage quota exhausted for this platform.")
|
||||
ErrUserPlatformMonthlyQuotaExhausted = infraerrors.TooManyRequests("USER_PLATFORM_MONTHLY_QUOTA_EXHAUSTED", "Monthly usage quota exhausted for this platform.")
|
||||
)
|
||||
|
||||
// subscriptionCacheData 订阅缓存数据结构(内部使用)
|
||||
@ -94,6 +107,7 @@ type BillingCacheService struct {
|
||||
userGroupRateRepo UserGroupRateRepository
|
||||
cfg *config.Config
|
||||
circuitBreaker *billingCircuitBreaker
|
||||
userPlatformQuotaRepo UserPlatformQuotaRepository
|
||||
|
||||
cacheWriteChan chan cacheWriteTask
|
||||
cacheWriteWg sync.WaitGroup
|
||||
@ -101,6 +115,7 @@ type BillingCacheService struct {
|
||||
cacheWriteMu sync.RWMutex
|
||||
stopped atomic.Bool
|
||||
balanceLoadSF singleflight.Group
|
||||
quotaLoadSF singleflight.Group
|
||||
// 丢弃日志节流计数器(减少高负载下日志噪音)
|
||||
cacheWriteDropFullCount uint64
|
||||
cacheWriteDropFullLastLog int64
|
||||
@ -117,6 +132,7 @@ func NewBillingCacheService(
|
||||
userRPMCache UserRPMCache,
|
||||
userGroupRateRepo UserGroupRateRepository,
|
||||
cfg *config.Config,
|
||||
userPlatformQuotaRepo UserPlatformQuotaRepository,
|
||||
) *BillingCacheService {
|
||||
svc := &BillingCacheService{
|
||||
cache: cache,
|
||||
@ -126,6 +142,7 @@ func NewBillingCacheService(
|
||||
userRPMCache: userRPMCache,
|
||||
userGroupRateRepo: userGroupRateRepo,
|
||||
cfg: cfg,
|
||||
userPlatformQuotaRepo: userPlatformQuotaRepo,
|
||||
}
|
||||
svc.circuitBreaker = newBillingCircuitBreaker(cfg.Billing.CircuitBreaker)
|
||||
svc.startCacheWriteWorkers()
|
||||
@ -655,6 +672,30 @@ func (s *BillingCacheService) QueueUpdateAPIKeyRateLimitUsage(apiKeyID int64, co
|
||||
})
|
||||
}
|
||||
|
||||
// IncrementUserPlatformQuotaUsage 同步累加 user × platform usage 到 Redis 缓存。
|
||||
//
|
||||
// 设计:同步写入而非异步入队。同步写确保下次 preflight 立即看到最新 usage,
|
||||
// 把 TOCTOU 超支窗口限制在并发 in-flight 请求数量内(而非随时间无限累积)。
|
||||
// 写延迟通常 < 1ms(本地 Redis),换取 quota 视图实时性的取舍合理。
|
||||
//
|
||||
// Redis 写失败用 ALERT 级 log;DB 持久化由 caller 单独 goroutine 兜底(gateway_service.go)。
|
||||
func (s *BillingCacheService) IncrementUserPlatformQuotaUsage(userID int64, platform string, cost float64) {
|
||||
if s.cache == nil {
|
||||
return
|
||||
}
|
||||
if platform == "" || cost <= 0 {
|
||||
return
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout)
|
||||
defer cancel()
|
||||
ttl := time.Duration(s.cfg.Billing.UserPlatformQuotaCacheTTLSeconds) * time.Second
|
||||
if err := s.cache.IncrUserPlatformQuotaUsageCache(ctx, userID, platform, cost, ttl); err != nil {
|
||||
logger.LegacyPrintf("service.billing_cache",
|
||||
"ALERT: incr user platform quota cache failed user=%d platform=%s cost=%f: %v",
|
||||
userID, platform, cost, err)
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 统一检查方法
|
||||
// ============================================
|
||||
@ -662,7 +703,8 @@ func (s *BillingCacheService) QueueUpdateAPIKeyRateLimitUsage(apiKeyID int64, co
|
||||
// CheckBillingEligibility 检查用户是否有资格发起请求
|
||||
// 余额模式:检查缓存余额 > 0
|
||||
// 订阅模式:检查缓存用量未超过限额(Group限额从参数传入)
|
||||
func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *User, apiKey *APIKey, group *Group, subscription *UserSubscription) error {
|
||||
// platform 为请求的目标平台(如 "anthropic"),传空串 "" 时跳过 user × platform quota 检查。
|
||||
func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *User, apiKey *APIKey, group *Group, subscription *UserSubscription, platform string) error {
|
||||
// 简易模式:跳过所有计费检查
|
||||
if s.cfg.RunMode == config.RunModeSimple {
|
||||
return nil
|
||||
@ -684,6 +726,13 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user
|
||||
}
|
||||
}
|
||||
|
||||
// user × platform quota 仅在 standard(余额)模式生效;订阅模式豁免
|
||||
if !isSubscriptionMode {
|
||||
if err := s.checkUserPlatformQuotaEligibility(ctx, user.ID, platform); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Check API Key rate limits (applies to both billing modes)
|
||||
if apiKey != nil && apiKey.HasRateLimits() {
|
||||
if err := s.checkAPIKeyRateLimits(ctx, apiKey); err != nil {
|
||||
@ -975,3 +1024,257 @@ func circuitStateString(state billingCircuitBreakerState) string {
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// checkUserPlatformQuotaEligibility 在 standard 模式下检查 user × platform 日/周/月 quota。
|
||||
// 返回 nil = 允许;返回 ErrUserPlatform{Daily/Weekly/Monthly}QuotaExhausted = 拒绝(带 window_resets_at metadata)。
|
||||
// checkUserPlatformQuotaEligibility 检查用户在指定平台的 USD 配额。
|
||||
//
|
||||
// 流程(Redis-first / DB-fallback):
|
||||
// 1. 先读 Redis cache;若命中且 SchemaVersion==1,直接用 entry 中的 limits 和 window_start 做校验,
|
||||
// 免除 DB 查询。
|
||||
// 2. cache MISS 或旧版 entry(SchemaVersion==0)→ 查 DB 回填完整 entry(含 limits/window_start)。
|
||||
// 3. Redis 故障(err != nil)→ fail-open,查 DB 做一次性检查,不回填。
|
||||
func (s *BillingCacheService) checkUserPlatformQuotaEligibility(
|
||||
ctx context.Context,
|
||||
userID int64,
|
||||
platform string,
|
||||
) error {
|
||||
if platform == "" || s.userPlatformQuotaRepo == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// cache 未配置(如简化部署 / 单测路径)→ 直接走 DB 查询,避免 nil panic。
|
||||
// 其他 check* 方法(balance/subscription/rate-limit)也有类似守卫。
|
||||
var (
|
||||
entry *UserPlatformQuotaCacheEntry
|
||||
ok bool
|
||||
cacheErr error
|
||||
)
|
||||
if s.cache != nil {
|
||||
entry, ok, cacheErr = s.cache.GetUserPlatformQuotaCache(ctx, userID, platform)
|
||||
} else {
|
||||
// 标记为"cache 故障"分支:跳过 HIT 路径、不回填、走 DB 一次性检查
|
||||
cacheErr = errBillingCacheUnavailable
|
||||
}
|
||||
|
||||
// --- cache HIT with current schema → 直接用 entry,不查 DB ---
|
||||
if cacheErr == nil && ok && entry != nil && entry.SchemaVersion == UserPlatformQuotaCacheSchemaV1 {
|
||||
now := time.Now()
|
||||
dailyUsage := entry.DailyUsageUSD
|
||||
weeklyUsage := entry.WeeklyUsageUSD
|
||||
monthlyUsage := entry.MonthlyUsageUSD
|
||||
// 若窗口已更新(DB 已重置但 cache 尚未失效),将对应 usage 清零再做比较,
|
||||
// 同时记录新窗口起点用于后续刷新 cache entry。
|
||||
// 本次请求用本地清零值继续判断;DB 层 IncrementUsageWithReset 已有窗口自愈能力,
|
||||
// 持久化数据始终正确。
|
||||
windowExpired := false
|
||||
newDailyStart := entry.DailyWindowStart
|
||||
newWeeklyStart := entry.WeeklyWindowStart
|
||||
newMonthlyStart := entry.MonthlyWindowStart
|
||||
if quotaWindowExpired(entry.DailyWindowStart, timezone.StartOfDay(now)) {
|
||||
dailyUsage = 0
|
||||
windowExpired = true
|
||||
dayStart := timezone.StartOfDay(now)
|
||||
newDailyStart = &dayStart
|
||||
}
|
||||
if quotaWindowExpired(entry.WeeklyWindowStart, timezone.StartOfWeek(now)) {
|
||||
weeklyUsage = 0
|
||||
windowExpired = true
|
||||
weekStart := timezone.StartOfWeek(now)
|
||||
newWeeklyStart = &weekStart
|
||||
}
|
||||
if monthlyQuotaWindowExpired(entry.MonthlyWindowStart, now) {
|
||||
monthlyUsage = 0
|
||||
windowExpired = true
|
||||
monthStart := now
|
||||
newMonthlyStart = &monthStart
|
||||
}
|
||||
// 检测到任意窗口过期:用 reset 后的 entry 覆盖 Redis(而非 Delete)。
|
||||
// 旧实现 Delete 后,期间到达的 IncrUserPlatformQuotaUsage 调用让 Lua 看到
|
||||
// EXISTS=0 直接 return 0,并发请求的 cost 永久丢失,直到下次 cache MISS 回填。
|
||||
// 改为 SetCache 原子覆盖:key 不断链,Lua INCR 可在新窗口 entry 上正确累加。
|
||||
// 超时 50ms:覆盖正常路径与可接受抖动;Redis 异常时 hot path 不阻塞超过此值。
|
||||
// 用 context.Background()+短超时,避免请求 ctx 取消导致刷新丢失。
|
||||
// 显式 setCancel()(而非 defer):缩短 context 生命周期,避免 defer 延迟到函数返回。
|
||||
if windowExpired && s.cache != nil {
|
||||
refreshed := &UserPlatformQuotaCacheEntry{
|
||||
DailyUsageUSD: dailyUsage,
|
||||
WeeklyUsageUSD: weeklyUsage,
|
||||
MonthlyUsageUSD: monthlyUsage,
|
||||
SchemaVersion: UserPlatformQuotaCacheSchemaV1,
|
||||
DailyLimitUSD: entry.DailyLimitUSD,
|
||||
WeeklyLimitUSD: entry.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: entry.MonthlyLimitUSD,
|
||||
DailyWindowStart: newDailyStart,
|
||||
WeeklyWindowStart: newWeeklyStart,
|
||||
MonthlyWindowStart: newMonthlyStart,
|
||||
}
|
||||
ttl := time.Duration(s.cfg.Billing.UserPlatformQuotaCacheTTLSeconds) * time.Second
|
||||
setCtx, setCancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||
if setErr := s.cache.SetUserPlatformQuotaCache(setCtx, userID, platform, refreshed, ttl); setErr != nil {
|
||||
logger.LegacyPrintf("service.billing_cache",
|
||||
"Warning: refresh expired user platform quota cache failed user=%d platform=%s: %v",
|
||||
userID, platform, setErr)
|
||||
}
|
||||
setCancel()
|
||||
}
|
||||
if entry.DailyLimitUSD != nil && dailyUsage >= *entry.DailyLimitUSD {
|
||||
return withWindowResetsMetadata(ErrUserPlatformDailyQuotaExhausted, nextDailyReset(now))
|
||||
}
|
||||
if entry.WeeklyLimitUSD != nil && weeklyUsage >= *entry.WeeklyLimitUSD {
|
||||
return withWindowResetsMetadata(ErrUserPlatformWeeklyQuotaExhausted, nextWeeklyReset(now))
|
||||
}
|
||||
if entry.MonthlyLimitUSD != nil && monthlyUsage >= *entry.MonthlyLimitUSD {
|
||||
return withWindowResetsMetadata(ErrUserPlatformMonthlyQuotaExhausted, nextMonthlyResetFrom(entry.MonthlyWindowStart, now))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- cache MISS、旧版 entry 或 Redis 故障 → 查 DB(singleflight 合并并发回源)---
|
||||
// 使用 DoChan 而非 Do:avoid sharing the first caller's ctx among all dedupe followers.
|
||||
// 若第一个 caller 的 ctx 被取消(客户端断连),后续 caller 不受影响,仍由各自 ctx 控制超时。
|
||||
sfKey := strconv.FormatInt(userID, 10) + ":" + platform
|
||||
ch := s.quotaLoadSF.DoChan(sfKey, func() (any, error) {
|
||||
// 子查询用 detached context + 短超时,独立于任何 caller 的请求 ctx,
|
||||
// 防止"第一个 caller ctx 取消"使所有 follower 一起 fail。
|
||||
bgCtx, bgCancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer bgCancel()
|
||||
return s.userPlatformQuotaRepo.GetByUserPlatform(bgCtx, userID, platform)
|
||||
})
|
||||
var (
|
||||
v any
|
||||
dbErr error
|
||||
)
|
||||
select {
|
||||
case res := <-ch:
|
||||
v, dbErr = res.Val, res.Err
|
||||
case <-ctx.Done():
|
||||
// 当前 caller 的 ctx 被取消:fail-open,不阻断 (此请求已无意义)。
|
||||
logger.LegacyPrintf("service.billing_cache", "Warning: user platform quota check ctx cancelled user=%d platform=%s: %v (fail-open)", userID, platform, ctx.Err())
|
||||
return nil
|
||||
}
|
||||
if dbErr != nil {
|
||||
logger.LegacyPrintf("service.billing_cache", "Warning: load user platform quota failed user=%d platform=%s: %v (fail-open)", userID, platform, dbErr)
|
||||
return nil
|
||||
}
|
||||
rec, _ := v.(*UserPlatformQuotaRecord)
|
||||
if rec == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
dailyUsage := rec.DailyUsageUSD
|
||||
weeklyUsage := rec.WeeklyUsageUSD
|
||||
monthlyUsage := rec.MonthlyUsageUSD
|
||||
if quotaWindowExpired(rec.DailyWindowStart, timezone.StartOfDay(now)) {
|
||||
dailyUsage = 0
|
||||
}
|
||||
if quotaWindowExpired(rec.WeeklyWindowStart, timezone.StartOfWeek(now)) {
|
||||
weeklyUsage = 0
|
||||
}
|
||||
if monthlyQuotaWindowExpired(rec.MonthlyWindowStart, now) {
|
||||
monthlyUsage = 0
|
||||
}
|
||||
|
||||
// Redis 故障时 fail-open:不回填,直接用 DB 数据做一次性检查
|
||||
if cacheErr != nil {
|
||||
if rec.DailyLimitUSD != nil && dailyUsage >= *rec.DailyLimitUSD {
|
||||
return withWindowResetsMetadata(ErrUserPlatformDailyQuotaExhausted, nextDailyReset(now))
|
||||
}
|
||||
if rec.WeeklyLimitUSD != nil && weeklyUsage >= *rec.WeeklyLimitUSD {
|
||||
return withWindowResetsMetadata(ErrUserPlatformWeeklyQuotaExhausted, nextWeeklyReset(now))
|
||||
}
|
||||
if rec.MonthlyLimitUSD != nil && monthlyUsage >= *rec.MonthlyLimitUSD {
|
||||
return withWindowResetsMetadata(ErrUserPlatformMonthlyQuotaExhausted, nextMonthlyResetFrom(rec.MonthlyWindowStart, now))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// cache MISS 或旧版 entry → 回填完整 entry(含 limits 和 window_start)
|
||||
newEntry := &UserPlatformQuotaCacheEntry{
|
||||
DailyUsageUSD: dailyUsage,
|
||||
WeeklyUsageUSD: weeklyUsage,
|
||||
MonthlyUsageUSD: monthlyUsage,
|
||||
SchemaVersion: UserPlatformQuotaCacheSchemaV1,
|
||||
DailyLimitUSD: rec.DailyLimitUSD,
|
||||
WeeklyLimitUSD: rec.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: rec.MonthlyLimitUSD,
|
||||
DailyWindowStart: rec.DailyWindowStart,
|
||||
WeeklyWindowStart: rec.WeeklyWindowStart,
|
||||
MonthlyWindowStart: rec.MonthlyWindowStart,
|
||||
}
|
||||
if s.cache != nil {
|
||||
ttl := time.Duration(s.cfg.Billing.UserPlatformQuotaCacheTTLSeconds) * time.Second
|
||||
// 与 HIT 过期回填路径(上文 SetCache 调用)保持一致:用 context.Background()+50ms,
|
||||
// 避免请求 ctx 提前取消(客户端断连/上游超时)导致 cache 回填失败,
|
||||
// 让下一次 preflight 仍然 MISS 并击穿到 DB(高并发下增大 DB 压力)。
|
||||
setCtx, setCancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||
if setErr := s.cache.SetUserPlatformQuotaCache(setCtx, userID, platform, newEntry, ttl); setErr != nil {
|
||||
logger.LegacyPrintf("service.billing_cache", "Warning: set user platform quota cache failed user=%d platform=%s: %v", userID, platform, setErr)
|
||||
}
|
||||
setCancel()
|
||||
}
|
||||
|
||||
if rec.DailyLimitUSD != nil && dailyUsage >= *rec.DailyLimitUSD {
|
||||
return withWindowResetsMetadata(ErrUserPlatformDailyQuotaExhausted, nextDailyReset(now))
|
||||
}
|
||||
if rec.WeeklyLimitUSD != nil && weeklyUsage >= *rec.WeeklyLimitUSD {
|
||||
return withWindowResetsMetadata(ErrUserPlatformWeeklyQuotaExhausted, nextWeeklyReset(now))
|
||||
}
|
||||
if rec.MonthlyLimitUSD != nil && monthlyUsage >= *rec.MonthlyLimitUSD {
|
||||
return withWindowResetsMetadata(ErrUserPlatformMonthlyQuotaExhausted, nextMonthlyResetFrom(rec.MonthlyWindowStart, now))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// withWindowResetsMetadata 给 quota error 附加 window_resets_at metadata(RFC3339)。
|
||||
func withWindowResetsMetadata(err error, resetAt time.Time) error {
|
||||
appErr, ok := err.(*infraerrors.ApplicationError)
|
||||
if !ok || appErr == nil {
|
||||
return err
|
||||
}
|
||||
return appErr.WithMetadata(map[string]string{
|
||||
"window_resets_at": resetAt.Format(time.RFC3339),
|
||||
})
|
||||
}
|
||||
|
||||
// nextDailyReset 计算下一个日窗口起点(次日全局时区 0 点)。
|
||||
// 必须与 timezone.StartOfDay 同口径,否则 Retry-After 会偏差。
|
||||
func nextDailyReset(now time.Time) time.Time {
|
||||
return timezone.StartOfDay(now).AddDate(0, 0, 1)
|
||||
}
|
||||
|
||||
// nextWeeklyReset 计算下一个周窗口起点(下周一全局时区 0 点)。
|
||||
// 必须与 timezone.StartOfWeek 同口径,否则 Retry-After 会偏差。
|
||||
func nextWeeklyReset(now time.Time) time.Time {
|
||||
return timezone.StartOfWeek(now).AddDate(0, 0, 7)
|
||||
}
|
||||
|
||||
// nextMonthlyResetFrom 返回 30 天滚动窗口的下次重置时间(start + 30d)。
|
||||
// start 为 nil(未初始化)或已过期(now-start >= 30d,与 monthlyQuotaWindowExpired 同口径)时
|
||||
// 退化为 now+30d:过期窗口会在下次 increment 时重置为 now,下次重置即 now+30d;
|
||||
// 否则按 start 计算会得到一个过去的时间,使 Retry-After 落回 fallback 并触发客户端紧凑重试。
|
||||
func nextMonthlyResetFrom(start *time.Time, now time.Time) time.Time {
|
||||
if start == nil || now.Sub(*start) >= 30*24*time.Hour {
|
||||
return now.Add(30 * 24 * time.Hour)
|
||||
}
|
||||
return start.Add(30 * 24 * time.Hour)
|
||||
}
|
||||
|
||||
// quotaWindowExpired 判断窗口是否已过期:start 为 nil(未初始化)或在 currWindowStart 之前视为已过期。
|
||||
func quotaWindowExpired(start *time.Time, currWindowStart time.Time) bool {
|
||||
if start == nil {
|
||||
return true
|
||||
}
|
||||
return start.Before(currWindowStart)
|
||||
}
|
||||
|
||||
// monthlyQuotaWindowExpired 判断 30 天滚动月度窗口是否已过期。
|
||||
// 过期条件:now - start >= 30×24h(与订阅模式 NeedsMonthlyReset 语义一致)。
|
||||
// start 为 nil 时视为已过期(未初始化窗口)。
|
||||
func monthlyQuotaWindowExpired(start *time.Time, now time.Time) bool {
|
||||
if start == nil {
|
||||
return true
|
||||
}
|
||||
return now.Sub(*start) >= 30*24*time.Hour
|
||||
}
|
||||
|
||||
@ -74,7 +74,7 @@ func newBillingServiceForRPM(t *testing.T, cache UserRPMCache, rateRepo UserGrou
|
||||
t.Helper()
|
||||
// 用 nil BillingCache 走 "无缓存" 分支,避免 CheckBillingEligibility 副作用。
|
||||
// 我们只直接测 checkRPM。
|
||||
svc := NewBillingCacheService(nil, nil, nil, nil, cache, rateRepo, &config.Config{})
|
||||
svc := NewBillingCacheService(nil, nil, nil, nil, cache, rateRepo, &config.Config{}, nil)
|
||||
t.Cleanup(svc.Stop)
|
||||
return svc
|
||||
}
|
||||
|
||||
@ -67,6 +67,22 @@ func (s *billingCacheMissStub) InvalidateAPIKeyRateLimit(ctx context.Context, ke
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *billingCacheMissStub) GetUserPlatformQuotaCache(ctx context.Context, userID int64, platform string) (*UserPlatformQuotaCacheEntry, bool, error) {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
func (s *billingCacheMissStub) SetUserPlatformQuotaCache(ctx context.Context, userID int64, platform string, entry *UserPlatformQuotaCacheEntry, ttl time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *billingCacheMissStub) DeleteUserPlatformQuotaCache(ctx context.Context, userID int64, platform string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *billingCacheMissStub) IncrUserPlatformQuotaUsageCache(ctx context.Context, userID int64, platform string, cost float64, ttl time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type balanceLoadUserRepoStub struct {
|
||||
mockUserRepo
|
||||
calls atomic.Int64
|
||||
@ -100,7 +116,7 @@ func TestBillingCacheServiceGetUserBalance_Singleflight(t *testing.T) {
|
||||
delay: 80 * time.Millisecond,
|
||||
balance: 12.34,
|
||||
}
|
||||
svc := NewBillingCacheService(cache, userRepo, nil, nil, nil, nil, &config.Config{})
|
||||
svc := NewBillingCacheService(cache, userRepo, nil, nil, nil, nil, &config.Config{}, nil)
|
||||
t.Cleanup(svc.Stop)
|
||||
|
||||
const goroutines = 16
|
||||
|
||||
@ -68,9 +68,25 @@ func (b *billingCacheWorkerStub) InvalidateAPIKeyRateLimit(ctx context.Context,
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *billingCacheWorkerStub) GetUserPlatformQuotaCache(ctx context.Context, userID int64, platform string) (*UserPlatformQuotaCacheEntry, bool, error) {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
func (b *billingCacheWorkerStub) SetUserPlatformQuotaCache(ctx context.Context, userID int64, platform string, entry *UserPlatformQuotaCacheEntry, ttl time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *billingCacheWorkerStub) DeleteUserPlatformQuotaCache(ctx context.Context, userID int64, platform string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *billingCacheWorkerStub) IncrUserPlatformQuotaUsageCache(ctx context.Context, userID int64, platform string, cost float64, ttl time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestBillingCacheServiceQueueHighLoad(t *testing.T) {
|
||||
cache := &billingCacheWorkerStub{}
|
||||
svc := NewBillingCacheService(cache, nil, nil, nil, nil, nil, &config.Config{})
|
||||
svc := NewBillingCacheService(cache, nil, nil, nil, nil, nil, &config.Config{}, nil)
|
||||
t.Cleanup(svc.Stop)
|
||||
|
||||
start := time.Now()
|
||||
@ -92,7 +108,7 @@ func TestBillingCacheServiceQueueHighLoad(t *testing.T) {
|
||||
|
||||
func TestBillingCacheServiceEnqueueAfterStopReturnsFalse(t *testing.T) {
|
||||
cache := &billingCacheWorkerStub{}
|
||||
svc := NewBillingCacheService(cache, nil, nil, nil, nil, nil, &config.Config{})
|
||||
svc := NewBillingCacheService(cache, nil, nil, nil, nil, nil, &config.Config{}, nil)
|
||||
svc.Stop()
|
||||
|
||||
enqueued := svc.enqueueCacheWrite(cacheWriteTask{
|
||||
|
||||
@ -0,0 +1,595 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
)
|
||||
|
||||
// fakeIncrCache 仅记录 IncrUserPlatformQuotaUsageCache 被调用的参数。
|
||||
type fakeIncrCache struct {
|
||||
BillingCache
|
||||
calls []incrCall
|
||||
}
|
||||
|
||||
type incrCall struct {
|
||||
userID int64
|
||||
platform string
|
||||
cost float64
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
func (f *fakeIncrCache) IncrUserPlatformQuotaUsageCache(ctx context.Context, userID int64, platform string, cost float64, ttl time.Duration) error {
|
||||
f.calls = append(f.calls, incrCall{userID, platform, cost, ttl})
|
||||
return nil
|
||||
}
|
||||
|
||||
// IncrementUserPlatformQuotaUsage 已改为同步直写,不再走 worker。
|
||||
// 测试验证:同步调用立即调到 cache.IncrUserPlatformQuotaUsageCache。
|
||||
func TestIncrementUserPlatformQuotaUsage_SyncCallsCache(t *testing.T) {
|
||||
fake := &fakeIncrCache{}
|
||||
cfg := &config.Config{}
|
||||
cfg.Billing.UserPlatformQuotaCacheTTLSeconds = 120
|
||||
|
||||
s := &BillingCacheService{
|
||||
cache: fake,
|
||||
cfg: cfg,
|
||||
}
|
||||
|
||||
s.IncrementUserPlatformQuotaUsage(101, "anthropic", 0.25)
|
||||
s.IncrementUserPlatformQuotaUsage(101, "openai", 0.50)
|
||||
|
||||
if len(fake.calls) != 2 {
|
||||
t.Fatalf("expected 2 incr calls, got %d", len(fake.calls))
|
||||
}
|
||||
if fake.calls[0] != (incrCall{101, "anthropic", 0.25, 120 * time.Second}) {
|
||||
t.Errorf("call[0] = %+v", fake.calls[0])
|
||||
}
|
||||
if fake.calls[1] != (incrCall{101, "openai", 0.50, 120 * time.Second}) {
|
||||
t.Errorf("call[1] = %+v", fake.calls[1])
|
||||
}
|
||||
}
|
||||
|
||||
// ── T6 tests: checkUserPlatformQuotaEligibility ──────────────────────────────
|
||||
|
||||
// fakeQuotaRepo 实现 UserPlatformQuotaRepository 最小子集
|
||||
type fakeQuotaRepo struct {
|
||||
rec *UserPlatformQuotaRecord
|
||||
}
|
||||
|
||||
func (f *fakeQuotaRepo) GetByUserPlatform(_ context.Context, _ int64, _ string) (*UserPlatformQuotaRecord, error) {
|
||||
return f.rec, nil
|
||||
}
|
||||
|
||||
func (f *fakeQuotaRepo) BulkInsertInitial(_ context.Context, _ []UserPlatformQuotaRecord) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeQuotaRepo) IncrementUsageWithReset(_ context.Context, _ int64, _ string, _ float64, _ time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeQuotaRepo) ListByUser(_ context.Context, _ int64) ([]UserPlatformQuotaRecord, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f *fakeQuotaRepo) UpsertForUser(_ context.Context, _ int64, _ []UserPlatformQuotaRecord) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeQuotaRepo) ResetExpiredWindow(_ context.Context, _ int64, _ string, _ string, _ time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// fakeFullCache 同时支持 Get + Set + Incr + Delete。
|
||||
// mu 保护 entry 和 deleteCalls,防止异步 goroutine 与主 goroutine 之间的 data race。
|
||||
type fakeFullCache struct {
|
||||
BillingCache
|
||||
mu sync.Mutex
|
||||
entry *UserPlatformQuotaCacheEntry
|
||||
deleteCalls int
|
||||
}
|
||||
|
||||
// getDeleteCalls 线程安全地读取 deleteCalls。
|
||||
func (f *fakeFullCache) getDeleteCalls() int {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
return f.deleteCalls
|
||||
}
|
||||
|
||||
// getEntry 线程安全地读取 entry。
|
||||
func (f *fakeFullCache) getEntry() *UserPlatformQuotaCacheEntry {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
return f.entry
|
||||
}
|
||||
|
||||
func (f *fakeFullCache) GetUserPlatformQuotaCache(_ context.Context, _ int64, _ string) (*UserPlatformQuotaCacheEntry, bool, error) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
if f.entry == nil {
|
||||
return nil, false, nil
|
||||
}
|
||||
return f.entry, true, nil
|
||||
}
|
||||
|
||||
func (f *fakeFullCache) SetUserPlatformQuotaCache(_ context.Context, _ int64, _ string, e *UserPlatformQuotaCacheEntry, _ time.Duration) error {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.entry = e
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeFullCache) DeleteUserPlatformQuotaCache(_ context.Context, _ int64, _ string) error {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.deleteCalls++
|
||||
f.entry = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func newServiceForPreflight(t *testing.T, repo UserPlatformQuotaRepository, cache BillingCache) *BillingCacheService {
|
||||
t.Helper()
|
||||
cfg := &config.Config{}
|
||||
cfg.Billing.UserPlatformQuotaCacheTTLSeconds = 60
|
||||
return &BillingCacheService{
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
userPlatformQuotaRepo: repo,
|
||||
}
|
||||
}
|
||||
|
||||
// currentDayStart 返回全局时区当天 0 点(与生产 timezone.StartOfDay 同口径,确保窗口有效)。
|
||||
func currentDayStart() *time.Time {
|
||||
s := timezone.StartOfDay(time.Now())
|
||||
return &s
|
||||
}
|
||||
|
||||
func TestCheckUserPlatformQuotaEligibility_AllowsWhenUnderLimit(t *testing.T) {
|
||||
daily := 10.0
|
||||
repo := &fakeQuotaRepo{rec: &UserPlatformQuotaRecord{
|
||||
UserID: 1, Platform: "anthropic", DailyLimitUSD: &daily,
|
||||
}}
|
||||
cache := &fakeFullCache{entry: &UserPlatformQuotaCacheEntry{
|
||||
DailyUsageUSD: 4.5,
|
||||
DailyLimitUSD: &daily,
|
||||
DailyWindowStart: currentDayStart(),
|
||||
SchemaVersion: UserPlatformQuotaCacheSchemaV1,
|
||||
}}
|
||||
s := newServiceForPreflight(t, repo, cache)
|
||||
if err := s.checkUserPlatformQuotaEligibility(context.Background(), 1, "anthropic"); err != nil {
|
||||
t.Errorf("expected nil, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckUserPlatformQuotaEligibility_DailyExhausted(t *testing.T) {
|
||||
daily := 5.0
|
||||
repo := &fakeQuotaRepo{rec: &UserPlatformQuotaRecord{
|
||||
UserID: 1, Platform: "anthropic", DailyLimitUSD: &daily,
|
||||
}}
|
||||
cache := &fakeFullCache{entry: &UserPlatformQuotaCacheEntry{
|
||||
DailyUsageUSD: 5.0,
|
||||
DailyLimitUSD: &daily,
|
||||
DailyWindowStart: currentDayStart(),
|
||||
SchemaVersion: UserPlatformQuotaCacheSchemaV1,
|
||||
}}
|
||||
s := newServiceForPreflight(t, repo, cache)
|
||||
err := s.checkUserPlatformQuotaEligibility(context.Background(), 1, "anthropic")
|
||||
if !errors.Is(err, ErrUserPlatformDailyQuotaExhausted) {
|
||||
t.Errorf("expected ErrUserPlatformDailyQuotaExhausted, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckUserPlatformQuotaEligibility_NilLimitMeansUnlimited(t *testing.T) {
|
||||
repo := &fakeQuotaRepo{rec: &UserPlatformQuotaRecord{
|
||||
UserID: 1, Platform: "anthropic",
|
||||
}}
|
||||
cache := &fakeFullCache{entry: &UserPlatformQuotaCacheEntry{
|
||||
DailyUsageUSD: 999,
|
||||
DailyWindowStart: currentDayStart(),
|
||||
SchemaVersion: UserPlatformQuotaCacheSchemaV1,
|
||||
// DailyLimitUSD nil → 无限额
|
||||
}}
|
||||
s := newServiceForPreflight(t, repo, cache)
|
||||
if err := s.checkUserPlatformQuotaEligibility(context.Background(), 1, "anthropic"); err != nil {
|
||||
t.Errorf("nil limits should be unlimited, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckUserPlatformQuotaEligibility_ZeroLimitImmediateBlock(t *testing.T) {
|
||||
zero := 0.0
|
||||
repo := &fakeQuotaRepo{rec: &UserPlatformQuotaRecord{
|
||||
UserID: 1, Platform: "anthropic", DailyLimitUSD: &zero,
|
||||
}}
|
||||
cache := &fakeFullCache{entry: &UserPlatformQuotaCacheEntry{
|
||||
DailyUsageUSD: 0,
|
||||
DailyLimitUSD: &zero,
|
||||
DailyWindowStart: currentDayStart(),
|
||||
SchemaVersion: UserPlatformQuotaCacheSchemaV1,
|
||||
}}
|
||||
s := newServiceForPreflight(t, repo, cache)
|
||||
err := s.checkUserPlatformQuotaEligibility(context.Background(), 1, "anthropic")
|
||||
if !errors.Is(err, ErrUserPlatformDailyQuotaExhausted) {
|
||||
t.Errorf("expected daily exhausted for limit=0, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckUserPlatformQuotaEligibility_NoRecordMeansUnlimited(t *testing.T) {
|
||||
repo := &fakeQuotaRepo{rec: nil}
|
||||
cache := &fakeFullCache{}
|
||||
s := newServiceForPreflight(t, repo, cache)
|
||||
if err := s.checkUserPlatformQuotaEligibility(context.Background(), 1, "anthropic"); err != nil {
|
||||
t.Errorf("no record = unlimited, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCheckUserPlatformQuotaEligibility_OldSchemaCacheMissTriggersDB 验证旧版 entry(SchemaVersion=0)
|
||||
// 触发 DB 回退路径,并在 DB 数据判断配额是否超限。
|
||||
// DB record 需设置有效的 window_start,否则 quotaWindowExpired 会将 usage 归零(nil 窗口视为已过期)。
|
||||
func TestCheckUserPlatformQuotaEligibility_OldSchemaCacheMissTriggersDB(t *testing.T) {
|
||||
daily := 5.0
|
||||
dayStart := currentDayStart()
|
||||
repo := &fakeQuotaRepo{rec: &UserPlatformQuotaRecord{
|
||||
UserID: 1, Platform: "anthropic", DailyLimitUSD: &daily, DailyUsageUSD: 6.0,
|
||||
DailyWindowStart: dayStart,
|
||||
}}
|
||||
// SchemaVersion=0(旧 entry),应走 DB 路径
|
||||
cache := &fakeFullCache{entry: &UserPlatformQuotaCacheEntry{DailyUsageUSD: 1.0}}
|
||||
s := newServiceForPreflight(t, repo, cache)
|
||||
err := s.checkUserPlatformQuotaEligibility(context.Background(), 1, "anthropic")
|
||||
if !errors.Is(err, ErrUserPlatformDailyQuotaExhausted) {
|
||||
t.Errorf("旧版 entry 应走 DB 路径并报 daily exhausted, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCheckUserPlatformQuotaEligibility_WindowExpiredInCache 验证 cache HIT 时若窗口已过期,usage 归零,用户放行。
|
||||
func TestCheckUserPlatformQuotaEligibility_WindowExpiredInCache(t *testing.T) {
|
||||
daily := 5.0
|
||||
past := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) // 远古窗口起始,肯定已过期
|
||||
repo := &fakeQuotaRepo{rec: &UserPlatformQuotaRecord{
|
||||
UserID: 1, Platform: "anthropic", DailyLimitUSD: &daily,
|
||||
}}
|
||||
cache := &fakeFullCache{entry: &UserPlatformQuotaCacheEntry{
|
||||
DailyUsageUSD: 10.0, // 超限,但窗口已过期
|
||||
DailyLimitUSD: &daily,
|
||||
DailyWindowStart: &past,
|
||||
SchemaVersion: UserPlatformQuotaCacheSchemaV1,
|
||||
}}
|
||||
s := newServiceForPreflight(t, repo, cache)
|
||||
err := s.checkUserPlatformQuotaEligibility(context.Background(), 1, "anthropic")
|
||||
if err != nil {
|
||||
t.Errorf("过期窗口应归零放行, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCheckUserPlatformQuotaEligibility_WindowExpiredRefreshesCache 验证:
|
||||
// V1 HIT 路径检测到窗口过期时,用 reset 后的 entry 同步覆盖 Redis(而非 Delete):
|
||||
// 1. 当前请求以本地清零值判断 → 放行
|
||||
// 2. cache entry 被替换为新 entry: usage 清零 + window_start 更新到当前窗口
|
||||
// limit 保留;这样并发 IncrUserPlatformQuotaUsage 的 Lua INCR 可正确累加到新窗口。
|
||||
func TestCheckUserPlatformQuotaEligibility_WindowExpiredRefreshesCache(t *testing.T) {
|
||||
daily := 5.0
|
||||
// 远古窗口起始,确保 quotaWindowExpired 返回 true
|
||||
past := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
repo := &fakeQuotaRepo{rec: &UserPlatformQuotaRecord{
|
||||
UserID: 1, Platform: "anthropic", DailyLimitUSD: &daily,
|
||||
}}
|
||||
cache := &fakeFullCache{entry: &UserPlatformQuotaCacheEntry{
|
||||
DailyUsageUSD: 10.0, // 超限,但窗口已过期 → 应被本地清零后放行
|
||||
DailyLimitUSD: &daily,
|
||||
DailyWindowStart: &past,
|
||||
SchemaVersion: UserPlatformQuotaCacheSchemaV1,
|
||||
}}
|
||||
s := newServiceForPreflight(t, repo, cache)
|
||||
|
||||
// 本次 check 应放行(本地清零后 usage=0 < limit=5)
|
||||
err := s.checkUserPlatformQuotaEligibility(context.Background(), 1, "anthropic")
|
||||
if err != nil {
|
||||
t.Errorf("过期窗口应归零放行, got %v", err)
|
||||
}
|
||||
|
||||
// 验证 cache entry 已被刷新:usage 清零、limit 保留、window_start 更新到当前窗口
|
||||
refreshed := cache.getEntry()
|
||||
if refreshed == nil {
|
||||
t.Fatal("窗口过期后 cache entry 不应为 nil(应被 SetCache 覆盖,而非 Delete)")
|
||||
}
|
||||
if refreshed.DailyUsageUSD != 0 {
|
||||
t.Errorf("刷新后 DailyUsageUSD = %v, want 0", refreshed.DailyUsageUSD)
|
||||
}
|
||||
if refreshed.DailyLimitUSD == nil || *refreshed.DailyLimitUSD != daily {
|
||||
t.Errorf("刷新后 DailyLimitUSD = %v, want %v(保留)", refreshed.DailyLimitUSD, daily)
|
||||
}
|
||||
if refreshed.SchemaVersion != UserPlatformQuotaCacheSchemaV1 {
|
||||
t.Errorf("刷新后 SchemaVersion = %d, want V1", refreshed.SchemaVersion)
|
||||
}
|
||||
if refreshed.DailyWindowStart == nil || refreshed.DailyWindowStart.Equal(past) {
|
||||
t.Errorf("刷新后 DailyWindowStart = %v, 应更新到当前窗口而非保留 past=%v", refreshed.DailyWindowStart, past)
|
||||
}
|
||||
}
|
||||
|
||||
// ── T5 tests: QueueUpdateUserPlatformQuotaUsage ───────────────────────────────
|
||||
|
||||
// ── C-NEW-1: monthlyQuotaWindowExpired 30 天滚动测试 ─────────────────────────
|
||||
|
||||
func TestMonthlyQuotaWindowExpired_NilStart(t *testing.T) {
|
||||
if !monthlyQuotaWindowExpired(nil, time.Now().UTC()) {
|
||||
t.Error("nil start should be considered expired")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMonthlyQuotaWindowExpired_Expired(t *testing.T) {
|
||||
start := time.Now().UTC().Add(-30 * 24 * time.Hour)
|
||||
if !monthlyQuotaWindowExpired(&start, time.Now().UTC()) {
|
||||
t.Error("start exactly 30 days ago should be expired")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMonthlyQuotaWindowExpired_Active(t *testing.T) {
|
||||
start := time.Now().UTC().Add(-29 * 24 * time.Hour)
|
||||
if monthlyQuotaWindowExpired(&start, time.Now().UTC()) {
|
||||
t.Error("start 29 days ago should NOT be expired")
|
||||
}
|
||||
}
|
||||
|
||||
// TestMonthlyQuotaWindowExpired_CrossMonthBoundary 验证跨自然月时 30 天未满不视为过期。
|
||||
func TestMonthlyQuotaWindowExpired_CrossMonthBoundary(t *testing.T) {
|
||||
// 窗口起始 4 月 20 日;5 月 1 日只过了 11 天,不足 30 天
|
||||
start := time.Date(2026, 4, 20, 0, 0, 0, 0, time.UTC)
|
||||
now := time.Date(2026, 5, 1, 0, 0, 0, 0, time.UTC)
|
||||
if monthlyQuotaWindowExpired(&start, now) {
|
||||
t.Error("11 days into window should NOT be expired (30-day rolling, not calendar month)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNextMonthlyResetFrom 验证 30 天滚动重置时间计算。
|
||||
func TestNextMonthlyResetFrom_WithStart(t *testing.T) {
|
||||
start := time.Date(2026, 5, 1, 10, 0, 0, 0, time.UTC)
|
||||
want := start.Add(30 * 24 * time.Hour)
|
||||
now := time.Date(2026, 5, 22, 0, 0, 0, 0, time.UTC)
|
||||
got := nextMonthlyResetFrom(&start, now)
|
||||
if !got.Equal(want) {
|
||||
t.Errorf("nextMonthlyResetFrom = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNextMonthlyResetFrom_NilStart(t *testing.T) {
|
||||
now := time.Date(2026, 5, 22, 0, 0, 0, 0, time.UTC)
|
||||
got := nextMonthlyResetFrom(nil, now)
|
||||
want := now.Add(30 * 24 * time.Hour)
|
||||
if !got.Equal(want) {
|
||||
t.Errorf("nextMonthlyResetFrom(nil) = %v, want now+30d=%v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNextMonthlyResetFrom_NilStart_NotEqualToNow(t *testing.T) {
|
||||
now := time.Date(2026, 5, 22, 12, 0, 0, 0, time.UTC)
|
||||
got := nextMonthlyResetFrom(nil, now)
|
||||
want := now.Add(30 * 24 * time.Hour)
|
||||
if !got.Equal(want) {
|
||||
t.Errorf("nextMonthlyResetFrom(nil) = %v, want %v (now+30d)", got, want)
|
||||
}
|
||||
if got.Equal(now) {
|
||||
t.Error("nextMonthlyResetFrom(nil) must not return now (should be now+30d)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNextMonthlyResetFrom_ExpiredStart 验证窗口已过期(now-start >= 30d)时,
|
||||
// 下次重置时间为 now+30d,而非 start+30d(后者已是过去时间,会让 Retry-After 落回
|
||||
// fallback 并触发客户端紧凑重试)。
|
||||
func TestNextMonthlyResetFrom_ExpiredStart(t *testing.T) {
|
||||
start := time.Date(2026, 3, 1, 0, 0, 0, 0, time.UTC)
|
||||
now := time.Date(2026, 5, 1, 0, 0, 0, 0, time.UTC) // 距 start 61 天,已过期
|
||||
got := nextMonthlyResetFrom(&start, now)
|
||||
want := now.Add(30 * 24 * time.Hour)
|
||||
if !got.Equal(want) {
|
||||
t.Errorf("nextMonthlyResetFrom(expired) = %v, want now+30d=%v", got, want)
|
||||
}
|
||||
if !got.After(now) {
|
||||
t.Error("expired window 的下次重置必须在 now 之后,不能是过去时间")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIncrementUserPlatformQuotaUsage_GuardsAgainstEmpty(t *testing.T) {
|
||||
fake := &fakeIncrCache{}
|
||||
cfg := &config.Config{}
|
||||
cfg.Billing.UserPlatformQuotaCacheTTLSeconds = 60
|
||||
s := &BillingCacheService{
|
||||
cache: fake,
|
||||
cfg: cfg,
|
||||
}
|
||||
|
||||
s.IncrementUserPlatformQuotaUsage(1, "", 0.5) // empty platform → noop
|
||||
s.IncrementUserPlatformQuotaUsage(1, "openai", 0) // zero cost → noop
|
||||
s.IncrementUserPlatformQuotaUsage(1, "openai", -0.1) // negative → noop
|
||||
|
||||
if len(fake.calls) != 0 {
|
||||
t.Errorf("expected 0 calls (all guarded), got %d", len(fake.calls))
|
||||
}
|
||||
}
|
||||
|
||||
// ── C-NEW-2: 订阅模式豁免 user×platform quota 检查 ──────────────────────────
|
||||
// 通过直接调用 checkUserPlatformQuotaEligibility 验证:
|
||||
// 1. standard 模式下 limit=0 → 拦截
|
||||
// 2. 订阅模式豁免通过 isSubscriptionMode 守卫体现 — 逻辑已在 CheckBillingEligibility 里加 !isSubscriptionMode 条件
|
||||
// 此处用单元测试直接验证底层 checkUserPlatformQuotaEligibility 的行为(quota 超限确实拦截),
|
||||
// 而 subscription bypass 逻辑则在 CheckBillingEligibility 中通过条件判断保证,不绕过 sub eligibility 内部复杂依赖。
|
||||
|
||||
// fakeZeroQuotaCache 模拟 cache 命中且 daily limit=0(quota 耗尽)。
|
||||
type fakeZeroQuotaCache struct {
|
||||
BillingCache
|
||||
called bool
|
||||
}
|
||||
|
||||
func (f *fakeZeroQuotaCache) GetUserPlatformQuotaCache(_ context.Context, _ int64, _ string) (*UserPlatformQuotaCacheEntry, bool, error) {
|
||||
f.called = true
|
||||
daily := 0.0
|
||||
entry := &UserPlatformQuotaCacheEntry{
|
||||
DailyUsageUSD: 0,
|
||||
DailyLimitUSD: &daily,
|
||||
DailyWindowStart: func() *time.Time { t := time.Now().UTC(); return &t }(),
|
||||
SchemaVersion: UserPlatformQuotaCacheSchemaV1,
|
||||
}
|
||||
return entry, true, nil
|
||||
}
|
||||
|
||||
func (f *fakeZeroQuotaCache) DeleteUserPlatformQuotaCache(_ context.Context, _ int64, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetUserPlatformQuotaCache 在 weekly/monthly window_start 为 nil 时,checkUserPlatform...
|
||||
// 会触发"窗口过期 → SetCache 刷新"分支。fake 用 noop 避免 panic。
|
||||
func (f *fakeZeroQuotaCache) SetUserPlatformQuotaCache(_ context.Context, _ int64, _ string, _ *UserPlatformQuotaCacheEntry, _ time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetSubscriptionCache 返回有效订阅(active、未过期、usage 远低于 limit),
|
||||
// 用于支持 checkSubscriptionEligibility 通过,以便验证 quota 检查不被触发。
|
||||
func (f *fakeZeroQuotaCache) GetSubscriptionCache(_ context.Context, _ int64, _ int64) (*SubscriptionCacheData, error) {
|
||||
return &SubscriptionCacheData{
|
||||
Status: SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(30 * 24 * time.Hour),
|
||||
DailyUsage: 0,
|
||||
WeeklyUsage: 0,
|
||||
MonthlyUsage: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (f *fakeZeroQuotaCache) GetUserBalanceCache(_ context.Context, _ int64) (float64, bool, error) {
|
||||
return 100.0, true, nil
|
||||
}
|
||||
|
||||
// TestCheckUserPlatformQuotaEligibility_StandardMode_BlocksWhenLimitZero 验证:
|
||||
// standard 模式下 limit=0 的 platform quota 确实会被拦截(守卫底层逻辑正确)。
|
||||
func TestCheckUserPlatformQuotaEligibility_StandardMode_BlocksWhenLimitZero(t *testing.T) {
|
||||
fake := &fakeZeroQuotaCache{}
|
||||
cfg := &config.Config{}
|
||||
cfg.Billing.UserPlatformQuotaCacheTTLSeconds = 60
|
||||
s := &BillingCacheService{
|
||||
cache: fake,
|
||||
cfg: cfg,
|
||||
userPlatformQuotaRepo: &fakeQuotaRepo{},
|
||||
}
|
||||
err := s.checkUserPlatformQuotaEligibility(context.Background(), 1, "anthropic")
|
||||
if !errors.Is(err, ErrUserPlatformDailyQuotaExhausted) {
|
||||
t.Errorf("standard mode with limit=0 should return ErrUserPlatformDailyQuotaExhausted, got: %v", err)
|
||||
}
|
||||
if !fake.called {
|
||||
t.Error("GetUserPlatformQuotaCache should have been called in standard mode")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCheckBillingEligibility_SubscriptionMode_BypassesPlatformQuota 验证(C-NEW-2):
|
||||
// 订阅模式用户不受 user×platform quota 拦截,GetUserPlatformQuotaCache 不应被调用。
|
||||
func TestCheckBillingEligibility_SubscriptionMode_BypassesPlatformQuota(t *testing.T) {
|
||||
fake := &fakeZeroQuotaCache{} // GetUserPlatformQuotaCache 返回 limit=0,若被调用则拦截
|
||||
cfg := &config.Config{}
|
||||
cfg.Billing.UserPlatformQuotaCacheTTLSeconds = 60
|
||||
s := &BillingCacheService{
|
||||
cache: fake,
|
||||
cfg: cfg,
|
||||
userPlatformQuotaRepo: &fakeQuotaRepo{},
|
||||
}
|
||||
|
||||
subGroup := &Group{
|
||||
ID: 10,
|
||||
SubscriptionType: "subscription",
|
||||
Status: "active",
|
||||
// 无 DailyLimitUSD → checkSubscriptionEligibility 不会因超限失败
|
||||
}
|
||||
sub := &UserSubscription{Status: "active"}
|
||||
user := &User{ID: 42}
|
||||
|
||||
err := s.CheckBillingEligibility(context.Background(), user, nil, subGroup, sub, "anthropic")
|
||||
// 订阅模式下不应收到任何 user×platform quota 错误
|
||||
if errors.Is(err, ErrUserPlatformDailyQuotaExhausted) ||
|
||||
errors.Is(err, ErrUserPlatformWeeklyQuotaExhausted) ||
|
||||
errors.Is(err, ErrUserPlatformMonthlyQuotaExhausted) {
|
||||
t.Errorf("subscription mode should bypass user×platform quota, got: %v", err)
|
||||
}
|
||||
// GetUserPlatformQuotaCache 不应被调用
|
||||
if fake.called {
|
||||
t.Error("GetUserPlatformQuotaCache must NOT be called in subscription mode (C-NEW-2)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCheckBillingEligibility_NonSubscriptionGroup_AppliesQuota 验证:
|
||||
// 非订阅模式(group=nil)用户 platform quota 超限时被拦截,quota cache 被查询。
|
||||
func TestCheckBillingEligibility_NonSubscriptionGroup_AppliesQuota(t *testing.T) {
|
||||
called := &fakeZeroQuotaCache{}
|
||||
cfg := &config.Config{}
|
||||
cfg.Billing.UserPlatformQuotaCacheTTLSeconds = 60
|
||||
s := &BillingCacheService{
|
||||
cache: called,
|
||||
cfg: cfg,
|
||||
userPlatformQuotaRepo: &fakeQuotaRepo{},
|
||||
}
|
||||
err := s.checkUserPlatformQuotaEligibility(context.Background(), 99, "openai")
|
||||
if !errors.Is(err, ErrUserPlatformDailyQuotaExhausted) {
|
||||
t.Errorf("non-subscription mode quota check should block, got: %v", err)
|
||||
}
|
||||
if !called.called {
|
||||
t.Error("GetUserPlatformQuotaCache should be consulted in non-subscription mode")
|
||||
}
|
||||
}
|
||||
|
||||
// ── B-3: monthlyQuotaWindowExpired 30 天边界表驱动测试 ────────────────────────
|
||||
// 覆盖 4 个必须场景:
|
||||
// 1. 恰好 30 天 → expired
|
||||
// 2. 30*24h - 1ns → not expired
|
||||
// 3. 跨月末(2024-02-28 → 2024-03-29T00:00:01Z)→ expired
|
||||
// 4. 跨年(2024-12-15 → 2025-01-14T00:00:01Z)→ expired
|
||||
//
|
||||
// repo 层 monthlyMaybeReset 不可导出,通过 service 层 monthlyQuotaWindowExpired 间接覆盖。
|
||||
func TestMonthlyQuotaWindowExpired_BoundaryTable(t *testing.T) {
|
||||
const thirtyDays = 30 * 24 * time.Hour
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
start time.Time
|
||||
now time.Time
|
||||
expired bool
|
||||
}{
|
||||
{
|
||||
name: "exactly 30 days → expired",
|
||||
start: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
now: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC).Add(thirtyDays),
|
||||
expired: true,
|
||||
},
|
||||
{
|
||||
name: "30d minus 1ns → not expired",
|
||||
start: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
now: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC).Add(thirtyDays - 1),
|
||||
expired: false,
|
||||
},
|
||||
{
|
||||
name: "cross month-end (Feb→Mar, 29d+1s) → expired",
|
||||
start: time.Date(2024, 2, 28, 0, 0, 0, 0, time.UTC),
|
||||
now: time.Date(2024, 3, 29, 0, 0, 1, 0, time.UTC),
|
||||
expired: true,
|
||||
},
|
||||
{
|
||||
name: "cross year boundary (Dec→Jan, 30d+1s) → expired",
|
||||
start: time.Date(2024, 12, 15, 0, 0, 0, 0, time.UTC),
|
||||
now: time.Date(2025, 1, 14, 0, 0, 1, 0, time.UTC),
|
||||
expired: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
tc := tc // capture range variable
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := monthlyQuotaWindowExpired(&tc.start, tc.now)
|
||||
if got != tc.expired {
|
||||
t.Errorf("monthlyQuotaWindowExpired(start=%v, now=%v) = %v, want %v",
|
||||
tc.start, tc.now, got, tc.expired)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
)
|
||||
@ -20,6 +21,32 @@ type APIKeyRateLimitCacheData struct {
|
||||
Window7d int64 `json:"window_7d"`
|
||||
}
|
||||
|
||||
// UserPlatformQuotaCacheEntry Redis hash 反序列化结果。
|
||||
//
|
||||
// SchemaVersion 用于向后兼容:
|
||||
// - 0(旧 entry,无 SchemaVersion 字段)→ 视为 cache MISS,强制 refresh
|
||||
// - 1(当前版本)→ 包含 limits 和 window_start,可免 DB 查询
|
||||
//
|
||||
// limit 字段为 nil 表示"无限额"(DB 中对应列为 NULL)。
|
||||
const UserPlatformQuotaCacheSchemaV1 = int64(1)
|
||||
|
||||
type UserPlatformQuotaCacheEntry struct {
|
||||
DailyUsageUSD float64
|
||||
WeeklyUsageUSD float64
|
||||
MonthlyUsageUSD float64
|
||||
Version int64
|
||||
SchemaVersion int64
|
||||
|
||||
// 以下字段仅在 SchemaVersion >= 1 时有效
|
||||
DailyLimitUSD *float64
|
||||
WeeklyLimitUSD *float64
|
||||
MonthlyLimitUSD *float64
|
||||
|
||||
DailyWindowStart *time.Time
|
||||
WeeklyWindowStart *time.Time
|
||||
MonthlyWindowStart *time.Time
|
||||
}
|
||||
|
||||
// BillingCache defines cache operations for billing service
|
||||
type BillingCache interface {
|
||||
// Balance operations
|
||||
@ -39,6 +66,13 @@ type BillingCache interface {
|
||||
SetAPIKeyRateLimit(ctx context.Context, keyID int64, data *APIKeyRateLimitCacheData) error
|
||||
UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error
|
||||
InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error
|
||||
|
||||
// user × platform quota 缓存
|
||||
GetUserPlatformQuotaCache(ctx context.Context, userID int64, platform string) (*UserPlatformQuotaCacheEntry, bool, error)
|
||||
SetUserPlatformQuotaCache(ctx context.Context, userID int64, platform string, entry *UserPlatformQuotaCacheEntry, ttl time.Duration) error
|
||||
DeleteUserPlatformQuotaCache(ctx context.Context, userID int64, platform string) error
|
||||
// IncrUserPlatformQuotaUsageCache 在缓存命中时累加用量;缓存未命中(key 不存在)静默返回 nil。
|
||||
IncrUserPlatformQuotaUsageCache(ctx context.Context, userID int64, platform string, cost float64, ttl time.Duration) error
|
||||
}
|
||||
|
||||
// ModelPricing 模型价格配置(per-token价格,与LiteLLM格式一致)
|
||||
|
||||
@ -1,6 +1,10 @@
|
||||
package service
|
||||
|
||||
import "github.com/Wei-Shaw/sub2api/internal/domain"
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/domain"
|
||||
)
|
||||
|
||||
// Status constants
|
||||
const (
|
||||
@ -39,6 +43,26 @@ const (
|
||||
PlatformAntigravity = domain.PlatformAntigravity
|
||||
)
|
||||
|
||||
// AllowedQuotaPlatforms 是允许设置 user × platform quota 的平台列表(单一权威来源)。
|
||||
// ent/schema/user_platform_quota.go 的 Validate 函数独立维护(构建期约束),
|
||||
// 若新增平台需同步修改该 schema。
|
||||
var AllowedQuotaPlatforms = []string{
|
||||
PlatformAnthropic,
|
||||
PlatformOpenAI,
|
||||
PlatformGemini,
|
||||
PlatformAntigravity,
|
||||
}
|
||||
|
||||
// IsAllowedQuotaPlatform 报告 s 是否为合法的 quota platform 标识。
|
||||
func IsAllowedQuotaPlatform(s string) bool {
|
||||
for _, p := range AllowedQuotaPlatforms {
|
||||
if p == s {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Account type constants
|
||||
const (
|
||||
AccountTypeOAuth = domain.AccountTypeOAuth // OAuth类型账号(full scope: profile + inference)
|
||||
@ -424,5 +448,15 @@ const (
|
||||
SettingKeyWebSearchEmulationConfig = "web_search_emulation_config" // JSON 配置
|
||||
)
|
||||
|
||||
// SettingKeyDefaultPlatformQuotas —— 系统全局:每用户 × 平台日/周/月 USD 上限(JSON)。
|
||||
// 值为 map[platform]{daily,weekly,monthly},null/缺省 = 不限制;0 = 禁用;>0 = USD 上限。
|
||||
const SettingKeyDefaultPlatformQuotas = "default_platform_quotas"
|
||||
|
||||
// SettingKeyAuthSourcePlatformQuotas 返回某 auth source 的 platform quota JSON key。
|
||||
// 形如 auth_source_default_{source}_platform_quotas
|
||||
func SettingKeyAuthSourcePlatformQuotas(source string) string {
|
||||
return fmt.Sprintf("auth_source_default_%s_platform_quotas", source)
|
||||
}
|
||||
|
||||
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
|
||||
const AdminAPIKeyPrefix = "admin-"
|
||||
|
||||
23
backend/internal/service/domain_constants_test.go
Normal file
23
backend/internal/service/domain_constants_test.go
Normal file
@ -0,0 +1,23 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import "testing"
|
||||
|
||||
// TestSettingKeyDefaultPlatformQuotas 验证新的系统层 JSON key 常量值正确。
|
||||
func TestSettingKeyDefaultPlatformQuotas(t *testing.T) {
|
||||
if SettingKeyDefaultPlatformQuotas != "default_platform_quotas" {
|
||||
t.Errorf("SettingKeyDefaultPlatformQuotas = %q, want %q",
|
||||
SettingKeyDefaultPlatformQuotas, "default_platform_quotas")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSettingKeyAuthSourcePlatformQuotas 验证新的 auth-source JSON key 函数返回值正确。
|
||||
func TestSettingKeyAuthSourcePlatformQuotas(t *testing.T) {
|
||||
if got := SettingKeyAuthSourcePlatformQuotas("email"); got != "auth_source_default_email_platform_quotas" {
|
||||
t.Fatalf("got %q, want %q", got, "auth_source_default_email_platform_quotas")
|
||||
}
|
||||
if got := SettingKeyAuthSourcePlatformQuotas("dingtalk"); got != "auth_source_default_dingtalk_platform_quotas" {
|
||||
t.Fatalf("got %q, want %q", got, "auth_source_default_dingtalk_platform_quotas")
|
||||
}
|
||||
}
|
||||
@ -44,6 +44,7 @@ func newGatewayRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil, // userPlatformQuotaRepo
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@ -95,6 +95,16 @@ var (
|
||||
modelsListCacheHitTotal atomic.Int64
|
||||
modelsListCacheMissTotal atomic.Int64
|
||||
modelsListCacheStoreTotal atomic.Int64
|
||||
|
||||
// userPlatformQuotaDBIncrErrorTotal 统计 finalizePostUsageBilling 异步 goroutine
|
||||
// 中 IncrementUsageWithReset 失败次数。Redis 已成功累加 + DB 写失败意味着
|
||||
// Redis cache TTL 过期或被清后该笔 cost 会丢失(与实际消费偏差)。
|
||||
// oncall 通过 GatewayUserPlatformQuotaIncrStats() 暴露给 ops 面板做阈值告警。
|
||||
userPlatformQuotaDBIncrErrorTotal atomic.Int64
|
||||
// userPlatformQuotaDBIncrLegacyErrorTotal 统计 legacy postUsageBilling
|
||||
// (applyUsageBilling 在 repo==nil 时 fallback)路径下的失败次数;
|
||||
// 与 DB Incr 失败分开计数,便于区分"主路径暂时故障"vs"基础设施长期未配齐"。
|
||||
userPlatformQuotaDBIncrLegacyErrorTotal atomic.Int64
|
||||
)
|
||||
|
||||
func GatewayWindowCostPrefetchStats() (cacheHit, cacheMiss, batchSQL, fallback, errCount int64) {
|
||||
@ -117,6 +127,15 @@ func GatewayModelsListCacheStats() (cacheHit, cacheMiss, store int64) {
|
||||
return modelsListCacheHitTotal.Load(), modelsListCacheMissTotal.Load(), modelsListCacheStoreTotal.Load()
|
||||
}
|
||||
|
||||
// GatewayUserPlatformQuotaIncrStats 返回 (mainPathErr, legacyPathErr)。
|
||||
// mainPathErr:finalizePostUsageBilling 异步 goroutine 写 DB 失败累计次数;
|
||||
// legacyPathErr:postUsageBilling fallback 路径写 DB 失败累计次数。
|
||||
// ops 监控面板可以按"持续上升斜率"做告警阈值。
|
||||
func GatewayUserPlatformQuotaIncrStats() (mainPathErr, legacyPathErr int64) {
|
||||
return userPlatformQuotaDBIncrErrorTotal.Load(),
|
||||
userPlatformQuotaDBIncrLegacyErrorTotal.Load()
|
||||
}
|
||||
|
||||
func openAIStreamEventIsTerminal(data string) bool {
|
||||
trimmed := strings.TrimSpace(data)
|
||||
if trimmed == "" {
|
||||
@ -575,6 +594,7 @@ type GatewayService struct {
|
||||
debugGatewayBodyFile atomic.Pointer[os.File] // non-nil when SUB2API_DEBUG_GATEWAY_BODY is set
|
||||
tlsFPProfileService *TLSFingerprintProfileService
|
||||
balanceNotifyService *BalanceNotifyService
|
||||
userPlatformQuotaRepo UserPlatformQuotaRepository
|
||||
}
|
||||
|
||||
// NewGatewayService creates a new GatewayService
|
||||
@ -605,41 +625,43 @@ func NewGatewayService(
|
||||
channelService *ChannelService,
|
||||
resolver *ModelPricingResolver,
|
||||
balanceNotifyService *BalanceNotifyService,
|
||||
userPlatformQuotaRepo UserPlatformQuotaRepository,
|
||||
) *GatewayService {
|
||||
userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg)
|
||||
modelsListTTL := resolveModelsListCacheTTL(cfg)
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: accountRepo,
|
||||
groupRepo: groupRepo,
|
||||
usageLogRepo: usageLogRepo,
|
||||
usageBillingRepo: usageBillingRepo,
|
||||
userRepo: userRepo,
|
||||
userSubRepo: userSubRepo,
|
||||
userGroupRateRepo: userGroupRateRepo,
|
||||
cache: cache,
|
||||
digestStore: digestStore,
|
||||
cfg: cfg,
|
||||
schedulerSnapshot: schedulerSnapshot,
|
||||
concurrencyService: concurrencyService,
|
||||
billingService: billingService,
|
||||
rateLimitService: rateLimitService,
|
||||
billingCacheService: billingCacheService,
|
||||
identityService: identityService,
|
||||
httpUpstream: httpUpstream,
|
||||
deferredService: deferredService,
|
||||
claudeTokenProvider: claudeTokenProvider,
|
||||
sessionLimitCache: sessionLimitCache,
|
||||
rpmCache: rpmCache,
|
||||
userGroupRateCache: gocache.New(userGroupRateTTL, time.Minute),
|
||||
settingService: settingService,
|
||||
modelsListCache: gocache.New(modelsListTTL, time.Minute),
|
||||
modelsListCacheTTL: modelsListTTL,
|
||||
responseHeaderFilter: compileResponseHeaderFilter(cfg),
|
||||
tlsFPProfileService: tlsFPProfileService,
|
||||
channelService: channelService,
|
||||
resolver: resolver,
|
||||
balanceNotifyService: balanceNotifyService,
|
||||
accountRepo: accountRepo,
|
||||
groupRepo: groupRepo,
|
||||
usageLogRepo: usageLogRepo,
|
||||
usageBillingRepo: usageBillingRepo,
|
||||
userRepo: userRepo,
|
||||
userSubRepo: userSubRepo,
|
||||
userGroupRateRepo: userGroupRateRepo,
|
||||
cache: cache,
|
||||
digestStore: digestStore,
|
||||
cfg: cfg,
|
||||
schedulerSnapshot: schedulerSnapshot,
|
||||
concurrencyService: concurrencyService,
|
||||
billingService: billingService,
|
||||
rateLimitService: rateLimitService,
|
||||
billingCacheService: billingCacheService,
|
||||
identityService: identityService,
|
||||
httpUpstream: httpUpstream,
|
||||
deferredService: deferredService,
|
||||
claudeTokenProvider: claudeTokenProvider,
|
||||
sessionLimitCache: sessionLimitCache,
|
||||
rpmCache: rpmCache,
|
||||
userGroupRateCache: gocache.New(userGroupRateTTL, time.Minute),
|
||||
settingService: settingService,
|
||||
modelsListCache: gocache.New(modelsListTTL, time.Minute),
|
||||
modelsListCacheTTL: modelsListTTL,
|
||||
responseHeaderFilter: compileResponseHeaderFilter(cfg),
|
||||
tlsFPProfileService: tlsFPProfileService,
|
||||
channelService: channelService,
|
||||
resolver: resolver,
|
||||
balanceNotifyService: balanceNotifyService,
|
||||
userPlatformQuotaRepo: userPlatformQuotaRepo,
|
||||
}
|
||||
svc.userGroupRateResolver = newUserGroupRateResolver(
|
||||
userGroupRateRepo,
|
||||
@ -7949,6 +7971,7 @@ type RecordUsageInput struct {
|
||||
RequestPayloadHash string // 请求体语义哈希,用于降低 request_id 误复用时的静默误去重风险
|
||||
ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换)
|
||||
APIKeyService APIKeyQuotaUpdater // 可选:用于更新API Key配额
|
||||
QuotaPlatform string // user×platform 配额计量平台:handler 在请求 ctx 内经 QuotaPlatform() 算定后传入(后扣运行在 worker 池 background ctx 上,取不到 ForcePlatform)
|
||||
|
||||
ChannelUsageFields // 渠道映射信息(由 handler 在 Forward 前解析)
|
||||
}
|
||||
@ -7978,6 +8001,31 @@ type postUsageBillingParams struct {
|
||||
IsSubscriptionBill bool
|
||||
AccountRateMultiplier float64
|
||||
APIKeyService APIKeyQuotaUpdater
|
||||
Platform string // 来自 APIKey 关联 Group 的平台标识
|
||||
}
|
||||
|
||||
// PlatformFromAPIKey 从 APIKey 关联的 Group 推导 platform 名称。
|
||||
// apiKey 为 nil 或 Group 信息缺失时返回空串(调用方据此 short-circuit quota 累加)。
|
||||
// 导出供 handler 层调用。
|
||||
func PlatformFromAPIKey(apiKey *APIKey) string {
|
||||
if apiKey == nil || apiKey.Group == nil {
|
||||
return ""
|
||||
}
|
||||
return apiKey.Group.Platform
|
||||
}
|
||||
|
||||
// QuotaPlatform 返回 user×platform 配额计量使用的平台标识。
|
||||
// 强制平台路由(如 /antigravity)优先按 ctx 中的 ForcePlatform 计量,否则回退到
|
||||
// APIKey 关联 Group 的平台。
|
||||
//
|
||||
// 注意:必须用带 ForcePlatform 的请求 context 调用(如 handler 的 c.Request.Context())。
|
||||
// 后扣运行在 worker 池的 background ctx 上没有 ForcePlatform,因此后扣平台由 handler
|
||||
// 预先算定、经 RecordUsageInput.QuotaPlatform 传入,不要在后扣链路用 worker ctx 调用本函数。
|
||||
func QuotaPlatform(ctx context.Context, apiKey *APIKey) string {
|
||||
if fp, ok := ctx.Value(ctxkey.ForcePlatform).(string); ok && fp != "" {
|
||||
return fp
|
||||
}
|
||||
return PlatformFromAPIKey(apiKey)
|
||||
}
|
||||
|
||||
func (p *postUsageBillingParams) shouldDeductAPIKeyQuota() bool {
|
||||
@ -8036,6 +8084,21 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill
|
||||
}
|
||||
}
|
||||
|
||||
// Platform quota DB-only 累加(与 finalizePostUsageBilling 行为对齐的兜底):
|
||||
// - 仅对 standard(余额)模式生效;订阅模式豁免
|
||||
// - 直接走 DB,不经 Redis Incr 队列:legacy 路径在 repo==nil(仓库未注入)
|
||||
// 时被触发,此时整套 billing repo 都不可用,没有"双队列"风险
|
||||
// - 失败仅记 ALERT log + counter,不阻断主扣费流程;与正常路径一致
|
||||
//
|
||||
// 历史背景:原 legacy path 完全跳过此累加,导致部署中如果 repo 偶然为 nil
|
||||
// 时用户消费可绕过 platform quota,存在静默资金风险。
|
||||
if !p.IsSubscriptionBill && p.Platform != "" && cost.ActualCost > 0 && p.User != nil && deps.userPlatformQuotaRepo != nil {
|
||||
if err := deps.userPlatformQuotaRepo.IncrementUsageWithReset(billingCtx, p.User.ID, p.Platform, cost.ActualCost, time.Now().UTC()); err != nil {
|
||||
userPlatformQuotaDBIncrLegacyErrorTotal.Add(1)
|
||||
logger.LegacyPrintf("service.gateway", "ALERT: legacy incr user platform quota DB failed user=%d platform=%s cost=%f: %v", p.User.ID, p.Platform, cost.ActualCost, err)
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE: finalizePostUsageBilling is NOT called here to avoid double-queuing
|
||||
// cache updates. The legacy path does DB writes directly; the finalize path
|
||||
// does cache queue + notifications. Notifications are dispatched separately
|
||||
@ -8159,11 +8222,11 @@ func applyUsageBilling(ctx context.Context, requestID string, usageLog *UsageLog
|
||||
}
|
||||
}
|
||||
|
||||
finalizePostUsageBilling(p, deps, result)
|
||||
finalizePostUsageBilling(billingCtx, p, deps, result)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps, result *UsageBillingApplyResult) {
|
||||
func finalizePostUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *billingDeps, result *UsageBillingApplyResult) {
|
||||
if p == nil || p.Cost == nil || deps == nil {
|
||||
return
|
||||
}
|
||||
@ -8182,6 +8245,32 @@ func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps, resu
|
||||
|
||||
deps.deferredService.ScheduleLastUsedUpdate(p.Account.ID)
|
||||
|
||||
// Platform quota 累加:仅在 standard(余额)模式生效;订阅模式豁免
|
||||
// Redis 同步写 + DB 异步持久化:
|
||||
// - Redis 同步:确保下次 preflight 立即看到最新 usage,把 TOCTOU 超支窗口
|
||||
// 限制在并发 in-flight 请求数量内(旧实现的异步入队会让超支无限累积直到 worker 处理)
|
||||
// - DB 异步:在独立 goroutine 中走 detached context,失败用 ALERT log 触发 oncall 对账
|
||||
if !p.IsSubscriptionBill && p.Platform != "" && p.Cost.ActualCost > 0 && p.User != nil && deps.userPlatformQuotaRepo != nil {
|
||||
deps.billingCacheService.IncrementUserPlatformQuotaUsage(p.User.ID, p.Platform, p.Cost.ActualCost)
|
||||
dbCtx, dbCancel := detachUpstreamContext(ctx)
|
||||
userID, platform, cost := p.User.ID, p.Platform, p.Cost.ActualCost
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
logger.LegacyPrintf("service.gateway", "ALERT: panic in user platform quota incr goroutine user=%d platform=%s: %v", userID, platform, r)
|
||||
}
|
||||
}()
|
||||
defer dbCancel()
|
||||
if err := deps.userPlatformQuotaRepo.IncrementUsageWithReset(dbCtx, userID, platform, cost, time.Now().UTC()); err != nil {
|
||||
// 失败计数器:暴露给 GatewayUserPlatformQuotaIncrStats(),由 ops 面板做斜率告警。
|
||||
userPlatformQuotaDBIncrErrorTotal.Add(1)
|
||||
// ALERT 级别:DB 持久化失败意味着 Redis cache 失效后该笔 cost 永久丢失,
|
||||
// 用户配额视图与实际消费会偏差,oncall 需要据此对账或人工补录。
|
||||
logger.LegacyPrintf("service.gateway", "ALERT: incr user platform quota DB failed user=%d platform=%s cost=%f: %v", userID, platform, cost, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Notification checks run async — all parameters are already captured,
|
||||
// no dependency on the request context or upstream connection.
|
||||
go notifyBalanceLow(p, deps, result)
|
||||
@ -8287,22 +8376,24 @@ func detachUpstreamContext(ctx context.Context) (context.Context, context.Cancel
|
||||
|
||||
// billingDeps 扣费逻辑依赖的服务(由各 gateway service 提供)
|
||||
type billingDeps struct {
|
||||
accountRepo AccountRepository
|
||||
userRepo UserRepository
|
||||
userSubRepo UserSubscriptionRepository
|
||||
billingCacheService *BillingCacheService
|
||||
deferredService *DeferredService
|
||||
balanceNotifyService *BalanceNotifyService
|
||||
accountRepo AccountRepository
|
||||
userRepo UserRepository
|
||||
userSubRepo UserSubscriptionRepository
|
||||
billingCacheService *BillingCacheService
|
||||
deferredService *DeferredService
|
||||
balanceNotifyService *BalanceNotifyService
|
||||
userPlatformQuotaRepo UserPlatformQuotaRepository
|
||||
}
|
||||
|
||||
func (s *GatewayService) billingDeps() *billingDeps {
|
||||
return &billingDeps{
|
||||
accountRepo: s.accountRepo,
|
||||
userRepo: s.userRepo,
|
||||
userSubRepo: s.userSubRepo,
|
||||
billingCacheService: s.billingCacheService,
|
||||
deferredService: s.deferredService,
|
||||
balanceNotifyService: s.balanceNotifyService,
|
||||
accountRepo: s.accountRepo,
|
||||
userRepo: s.userRepo,
|
||||
userSubRepo: s.userSubRepo,
|
||||
billingCacheService: s.billingCacheService,
|
||||
deferredService: s.deferredService,
|
||||
balanceNotifyService: s.balanceNotifyService,
|
||||
userPlatformQuotaRepo: s.userPlatformQuotaRepo,
|
||||
}
|
||||
}
|
||||
|
||||
@ -8360,6 +8451,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
RequestPayloadHash: input.RequestPayloadHash,
|
||||
ForceCacheBilling: input.ForceCacheBilling,
|
||||
APIKeyService: input.APIKeyService,
|
||||
QuotaPlatform: input.QuotaPlatform,
|
||||
ChannelUsageFields: input.ChannelUsageFields,
|
||||
}, &recordUsageOpts{
|
||||
EnableClaudePath: true,
|
||||
@ -8382,6 +8474,7 @@ type RecordUsageLongContextInput struct {
|
||||
LongContextMultiplier float64 // 超出阈值部分的倍率(如 2.0)
|
||||
ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换)
|
||||
APIKeyService APIKeyQuotaUpdater // API Key 配额服务(可选)
|
||||
QuotaPlatform string // user×platform 配额计量平台:handler 在请求 ctx 内经 QuotaPlatform() 算定后传入(后扣运行在 worker 池 background ctx 上,取不到 ForcePlatform)
|
||||
|
||||
ChannelUsageFields // 渠道映射信息(由 handler 在 Forward 前解析)
|
||||
}
|
||||
@ -8401,6 +8494,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
||||
RequestPayloadHash: input.RequestPayloadHash,
|
||||
ForceCacheBilling: input.ForceCacheBilling,
|
||||
APIKeyService: input.APIKeyService,
|
||||
QuotaPlatform: input.QuotaPlatform,
|
||||
ChannelUsageFields: input.ChannelUsageFields,
|
||||
}, &recordUsageOpts{
|
||||
LongContextThreshold: input.LongContextThreshold,
|
||||
@ -8422,6 +8516,7 @@ type recordUsageCoreInput struct {
|
||||
RequestPayloadHash string
|
||||
ForceCacheBilling bool
|
||||
APIKeyService APIKeyQuotaUpdater
|
||||
QuotaPlatform string
|
||||
ChannelUsageFields
|
||||
}
|
||||
|
||||
@ -8519,6 +8614,13 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
|
||||
return nil
|
||||
}
|
||||
|
||||
// 配额平台由 handler 在请求 ctx 内经 QuotaPlatform() 算定并通过 input 传入;
|
||||
// 后扣运行在 worker 池的 background ctx 上,无法再从 ctx 取 ForcePlatform。
|
||||
// 缺省(未设置)时回退到分组平台,保持对其它调用方的兼容。
|
||||
quotaPlatform := input.QuotaPlatform
|
||||
if quotaPlatform == "" {
|
||||
quotaPlatform = PlatformFromAPIKey(apiKey)
|
||||
}
|
||||
requestID := usageLog.RequestID
|
||||
_, billingErr := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{
|
||||
Cost: cost,
|
||||
@ -8530,6 +8632,7 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
|
||||
IsSubscriptionBill: isSubscriptionBilling,
|
||||
AccountRateMultiplier: accountRateMultiplier,
|
||||
APIKeyService: input.APIKeyService,
|
||||
Platform: quotaPlatform,
|
||||
}, s.billingDeps(), s.usageBillingRepo)
|
||||
|
||||
if billingErr != nil {
|
||||
|
||||
@ -155,6 +155,7 @@ func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo U
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil, // userPlatformQuotaRepo
|
||||
)
|
||||
svc.userGroupRateResolver = newUserGroupRateResolver(
|
||||
rateRepo,
|
||||
|
||||
@ -343,6 +343,7 @@ type OpenAIGatewayService struct {
|
||||
channelService *ChannelService
|
||||
balanceNotifyService *BalanceNotifyService
|
||||
settingService *SettingService
|
||||
userPlatformQuotaRepo UserPlatformQuotaRepository
|
||||
|
||||
openaiWSPoolOnce sync.Once
|
||||
openaiWSStateStoreOnce sync.Once
|
||||
@ -387,6 +388,7 @@ func NewOpenAIGatewayService(
|
||||
channelService *ChannelService,
|
||||
balanceNotifyService *BalanceNotifyService,
|
||||
settingService *SettingService,
|
||||
userPlatformQuotaRepo UserPlatformQuotaRepository,
|
||||
) *OpenAIGatewayService {
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: accountRepo,
|
||||
@ -418,6 +420,7 @@ func NewOpenAIGatewayService(
|
||||
channelService: channelService,
|
||||
balanceNotifyService: balanceNotifyService,
|
||||
settingService: settingService,
|
||||
userPlatformQuotaRepo: userPlatformQuotaRepo,
|
||||
responseHeaderFilter: compileResponseHeaderFilter(cfg),
|
||||
codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval),
|
||||
}
|
||||
@ -523,12 +526,13 @@ func (s *OpenAIGatewayService) getCodexSnapshotThrottle() *accountWriteThrottle
|
||||
|
||||
func (s *OpenAIGatewayService) billingDeps() *billingDeps {
|
||||
return &billingDeps{
|
||||
accountRepo: s.accountRepo,
|
||||
userRepo: s.userRepo,
|
||||
userSubRepo: s.userSubRepo,
|
||||
billingCacheService: s.billingCacheService,
|
||||
deferredService: s.deferredService,
|
||||
balanceNotifyService: s.balanceNotifyService,
|
||||
accountRepo: s.accountRepo,
|
||||
userRepo: s.userRepo,
|
||||
userSubRepo: s.userSubRepo,
|
||||
billingCacheService: s.billingCacheService,
|
||||
deferredService: s.deferredService,
|
||||
balanceNotifyService: s.balanceNotifyService,
|
||||
userPlatformQuotaRepo: s.userPlatformQuotaRepo,
|
||||
}
|
||||
}
|
||||
|
||||
@ -5634,6 +5638,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
IsSubscriptionBill: isSubscriptionBilling,
|
||||
AccountRateMultiplier: accountRateMultiplier,
|
||||
APIKeyService: input.APIKeyService,
|
||||
Platform: PlatformFromAPIKey(apiKey),
|
||||
}, s.billingDeps(), s.usageBillingRepo)
|
||||
return err
|
||||
}()
|
||||
|
||||
@ -619,6 +619,7 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) {
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil, // userPlatformQuotaRepo
|
||||
)
|
||||
|
||||
decision := svc.getOpenAIWSProtocolResolver().Resolve(nil)
|
||||
|
||||
81
backend/internal/service/post_billing_platform_test.go
Normal file
81
backend/internal/service/post_billing_platform_test.go
Normal file
@ -0,0 +1,81 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
)
|
||||
|
||||
func TestPlatformFromAPIKey_NilSafe(t *testing.T) {
|
||||
if got := PlatformFromAPIKey(nil); got != "" {
|
||||
t.Errorf("nil APIKey should yield empty string, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPlatformFromAPIKey_NilGroup(t *testing.T) {
|
||||
k := &APIKey{Group: nil}
|
||||
if got := PlatformFromAPIKey(k); got != "" {
|
||||
t.Errorf("APIKey with nil Group should yield empty string, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPlatformFromAPIKey_DerivesFromGroup(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
platform string
|
||||
}{
|
||||
{"anthropic", "anthropic"},
|
||||
{"openai", "openai"},
|
||||
{"gemini", "gemini"},
|
||||
{"antigravity", "antigravity"},
|
||||
{"empty", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
k := &APIKey{
|
||||
Group: &Group{Platform: tt.platform},
|
||||
}
|
||||
got := PlatformFromAPIKey(k)
|
||||
if got != tt.platform {
|
||||
t.Errorf("PlatformFromAPIKey(%q) = %q, want %q", tt.platform, got, tt.platform)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestQuotaPlatform 锁定配额计量口径:ForcePlatform 路由(如 /antigravity)按 ForcePlatform 计,
|
||||
// 否则回退到 Group 平台。preflight 与 post-billing 共用此口径,保证一致。
|
||||
func TestQuotaPlatform(t *testing.T) {
|
||||
apiKey := &APIKey{Group: &Group{Platform: PlatformAnthropic}}
|
||||
|
||||
t.Run("no force platform falls back to group platform", func(t *testing.T) {
|
||||
if got := QuotaPlatform(context.Background(), apiKey); got != PlatformAnthropic {
|
||||
t.Errorf("QuotaPlatform without force = %q, want %q", got, PlatformAnthropic)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("force platform overrides group platform", func(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), ctxkey.ForcePlatform, PlatformAntigravity)
|
||||
if got := QuotaPlatform(ctx, apiKey); got != PlatformAntigravity {
|
||||
t.Errorf("QuotaPlatform with force = %q, want %q", got, PlatformAntigravity)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty force platform falls back to group platform", func(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), ctxkey.ForcePlatform, "")
|
||||
if got := QuotaPlatform(ctx, apiKey); got != PlatformAnthropic {
|
||||
t.Errorf("QuotaPlatform with empty force = %q, want %q", got, PlatformAnthropic)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("nil api key with force platform returns force platform", func(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), ctxkey.ForcePlatform, PlatformAntigravity)
|
||||
if got := QuotaPlatform(ctx, nil); got != PlatformAntigravity {
|
||||
t.Errorf("QuotaPlatform(nil) with force = %q, want %q", got, PlatformAntigravity)
|
||||
}
|
||||
})
|
||||
}
|
||||
@ -165,12 +165,20 @@ type SettingService struct {
|
||||
openAICodexUASF singleflight.Group
|
||||
}
|
||||
|
||||
// DefaultPlatformQuotaSetting 单 platform 三档限额(nil = 沿用上层;0 = 显式禁用;>0 = 上限)
|
||||
type DefaultPlatformQuotaSetting struct {
|
||||
DailyLimitUSD *float64 `json:"daily"`
|
||||
WeeklyLimitUSD *float64 `json:"weekly"`
|
||||
MonthlyLimitUSD *float64 `json:"monthly"`
|
||||
}
|
||||
|
||||
type ProviderDefaultGrantSettings struct {
|
||||
Balance float64
|
||||
Concurrency int
|
||||
Subscriptions []DefaultSubscriptionSetting
|
||||
GrantOnSignup bool
|
||||
GrantOnFirstBind bool
|
||||
PlatformQuotas map[string]*DefaultPlatformQuotaSetting // key = platform name
|
||||
}
|
||||
|
||||
type AuthSourceDefaultSettings struct {
|
||||
@ -185,62 +193,80 @@ type AuthSourceDefaultSettings struct {
|
||||
}
|
||||
|
||||
type authSourceDefaultKeySet struct {
|
||||
// source 是 auth source 标识(如 "email"、"github"),仅用于 parse 时
|
||||
// slog.Warn 诊断输出,不再参与 key 拼接(platformQuotas 字段已存完整 key)。
|
||||
source string
|
||||
balance string
|
||||
concurrency string
|
||||
subscriptions string
|
||||
grantOnSignup string
|
||||
grantOnFirstBind string
|
||||
platformQuotas string // SettingKeyAuthSourcePlatformQuotas(source)
|
||||
}
|
||||
|
||||
var (
|
||||
emailAuthSourceDefaultKeys = authSourceDefaultKeySet{
|
||||
source: "email",
|
||||
balance: SettingKeyAuthSourceDefaultEmailBalance,
|
||||
concurrency: SettingKeyAuthSourceDefaultEmailConcurrency,
|
||||
subscriptions: SettingKeyAuthSourceDefaultEmailSubscriptions,
|
||||
grantOnSignup: SettingKeyAuthSourceDefaultEmailGrantOnSignup,
|
||||
grantOnFirstBind: SettingKeyAuthSourceDefaultEmailGrantOnFirstBind,
|
||||
platformQuotas: SettingKeyAuthSourcePlatformQuotas("email"),
|
||||
}
|
||||
linuxDoAuthSourceDefaultKeys = authSourceDefaultKeySet{
|
||||
source: "linuxdo",
|
||||
balance: SettingKeyAuthSourceDefaultLinuxDoBalance,
|
||||
concurrency: SettingKeyAuthSourceDefaultLinuxDoConcurrency,
|
||||
subscriptions: SettingKeyAuthSourceDefaultLinuxDoSubscriptions,
|
||||
grantOnSignup: SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup,
|
||||
grantOnFirstBind: SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind,
|
||||
platformQuotas: SettingKeyAuthSourcePlatformQuotas("linuxdo"),
|
||||
}
|
||||
oidcAuthSourceDefaultKeys = authSourceDefaultKeySet{
|
||||
source: "oidc",
|
||||
balance: SettingKeyAuthSourceDefaultOIDCBalance,
|
||||
concurrency: SettingKeyAuthSourceDefaultOIDCConcurrency,
|
||||
subscriptions: SettingKeyAuthSourceDefaultOIDCSubscriptions,
|
||||
grantOnSignup: SettingKeyAuthSourceDefaultOIDCGrantOnSignup,
|
||||
grantOnFirstBind: SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind,
|
||||
platformQuotas: SettingKeyAuthSourcePlatformQuotas("oidc"),
|
||||
}
|
||||
weChatAuthSourceDefaultKeys = authSourceDefaultKeySet{
|
||||
source: "wechat",
|
||||
balance: SettingKeyAuthSourceDefaultWeChatBalance,
|
||||
concurrency: SettingKeyAuthSourceDefaultWeChatConcurrency,
|
||||
subscriptions: SettingKeyAuthSourceDefaultWeChatSubscriptions,
|
||||
grantOnSignup: SettingKeyAuthSourceDefaultWeChatGrantOnSignup,
|
||||
grantOnFirstBind: SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind,
|
||||
platformQuotas: SettingKeyAuthSourcePlatformQuotas("wechat"),
|
||||
}
|
||||
gitHubAuthSourceDefaultKeys = authSourceDefaultKeySet{
|
||||
source: "github",
|
||||
balance: SettingKeyAuthSourceDefaultGitHubBalance,
|
||||
concurrency: SettingKeyAuthSourceDefaultGitHubConcurrency,
|
||||
subscriptions: SettingKeyAuthSourceDefaultGitHubSubscriptions,
|
||||
grantOnSignup: SettingKeyAuthSourceDefaultGitHubGrantOnSignup,
|
||||
grantOnFirstBind: SettingKeyAuthSourceDefaultGitHubGrantOnFirstBind,
|
||||
platformQuotas: SettingKeyAuthSourcePlatformQuotas("github"),
|
||||
}
|
||||
googleAuthSourceDefaultKeys = authSourceDefaultKeySet{
|
||||
source: "google",
|
||||
balance: SettingKeyAuthSourceDefaultGoogleBalance,
|
||||
concurrency: SettingKeyAuthSourceDefaultGoogleConcurrency,
|
||||
subscriptions: SettingKeyAuthSourceDefaultGoogleSubscriptions,
|
||||
grantOnSignup: SettingKeyAuthSourceDefaultGoogleGrantOnSignup,
|
||||
grantOnFirstBind: SettingKeyAuthSourceDefaultGoogleGrantOnFirstBind,
|
||||
platformQuotas: SettingKeyAuthSourcePlatformQuotas("google"),
|
||||
}
|
||||
dingTalkAuthSourceDefaultKeys = authSourceDefaultKeySet{
|
||||
source: "dingtalk",
|
||||
balance: SettingKeyAuthSourceDefaultDingTalkBalance,
|
||||
concurrency: SettingKeyAuthSourceDefaultDingTalkConcurrency,
|
||||
subscriptions: SettingKeyAuthSourceDefaultDingTalkSubscriptions,
|
||||
grantOnSignup: SettingKeyAuthSourceDefaultDingTalkGrantOnSignup,
|
||||
grantOnFirstBind: SettingKeyAuthSourceDefaultDingTalkGrantOnFirstBind,
|
||||
platformQuotas: SettingKeyAuthSourcePlatformQuotas("dingtalk"),
|
||||
}
|
||||
)
|
||||
|
||||
@ -1804,9 +1830,41 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
|
||||
updates[SettingKeyAccountQuotaNotifyEnabled] = strconv.FormatBool(settings.AccountQuotaNotifyEnabled)
|
||||
updates[SettingKeyAccountQuotaNotifyEmails] = MarshalNotifyEmails(settings.AccountQuotaNotifyEmails)
|
||||
|
||||
// 系统全局 platform quota:整体替换语义(null/缺省 = 不限制)。
|
||||
if settings.DefaultPlatformQuotas != nil {
|
||||
if err := validateDefaultPlatformQuotaMap(settings.DefaultPlatformQuotas); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
blob, err := json.Marshal(settings.DefaultPlatformQuotas)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal default platform quotas: %w", err)
|
||||
}
|
||||
updates[SettingKeyDefaultPlatformQuotas] = string(blob)
|
||||
}
|
||||
|
||||
return updates, nil
|
||||
}
|
||||
|
||||
// validateDefaultPlatformQuotaMap 校验 platform quota map 的合法性:
|
||||
// 平台名须在 AllowedQuotaPlatforms 白名单内,每个非 nil 上限须 finite 且 >= 0。
|
||||
// 系统层和 auth-source 层共用此 helper。
|
||||
func validateDefaultPlatformQuotaMap(m map[string]*DefaultPlatformQuotaSetting) error {
|
||||
for platform, pq := range m {
|
||||
if !IsAllowedQuotaPlatform(platform) {
|
||||
return infraerrors.BadRequest("INVALID_DEFAULT_PLATFORM_QUOTA", fmt.Sprintf("unknown platform %q", platform))
|
||||
}
|
||||
if pq == nil {
|
||||
continue
|
||||
}
|
||||
for _, v := range []*float64{pq.DailyLimitUSD, pq.WeeklyLimitUSD, pq.MonthlyLimitUSD} {
|
||||
if v != nil && (*v < 0 || math.IsNaN(*v) || math.IsInf(*v, 0)) {
|
||||
return infraerrors.BadRequest("INVALID_DEFAULT_PLATFORM_QUOTA", "platform quota limit must be a finite non-negative number")
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SettingService) buildAuthSourceDefaultUpdates(ctx context.Context, settings *AuthSourceDefaultSettings) (map[string]string, error) {
|
||||
if settings == nil {
|
||||
return nil, nil
|
||||
@ -1826,6 +1884,26 @@ func (s *SettingService) buildAuthSourceDefaultUpdates(ctx context.Context, sett
|
||||
}
|
||||
}
|
||||
|
||||
// 校验各 auth source 的 platform quota map(改动 C:对等系统层校验)
|
||||
for _, pgs := range []struct {
|
||||
name string
|
||||
pq map[string]*DefaultPlatformQuotaSetting
|
||||
}{
|
||||
{"email", settings.Email.PlatformQuotas},
|
||||
{"linuxdo", settings.LinuxDo.PlatformQuotas},
|
||||
{"oidc", settings.OIDC.PlatformQuotas},
|
||||
{"wechat", settings.WeChat.PlatformQuotas},
|
||||
{"github", settings.GitHub.PlatformQuotas},
|
||||
{"google", settings.Google.PlatformQuotas},
|
||||
{"dingtalk", settings.DingTalk.PlatformQuotas},
|
||||
} {
|
||||
if pgs.pq != nil {
|
||||
if err := validateDefaultPlatformQuotaMap(pgs.pq); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
updates := make(map[string]string, 36)
|
||||
writeProviderDefaultGrantUpdates(updates, emailAuthSourceDefaultKeys, settings.Email)
|
||||
writeProviderDefaultGrantUpdates(updates, linuxDoAuthSourceDefaultKeys, settings.LinuxDo)
|
||||
@ -2386,6 +2464,13 @@ func (s *SettingService) GetAuthSourceDefaultSettings(ctx context.Context) (*Aut
|
||||
SettingKeyAuthSourceDefaultDingTalkSubscriptions,
|
||||
SettingKeyAuthSourceDefaultDingTalkGrantOnSignup,
|
||||
SettingKeyAuthSourceDefaultDingTalkGrantOnFirstBind,
|
||||
SettingKeyAuthSourcePlatformQuotas("email"),
|
||||
SettingKeyAuthSourcePlatformQuotas("linuxdo"),
|
||||
SettingKeyAuthSourcePlatformQuotas("oidc"),
|
||||
SettingKeyAuthSourcePlatformQuotas("wechat"),
|
||||
SettingKeyAuthSourcePlatformQuotas("github"),
|
||||
SettingKeyAuthSourcePlatformQuotas("google"),
|
||||
SettingKeyAuthSourcePlatformQuotas("dingtalk"),
|
||||
SettingKeyForceEmailOnThirdPartySignup,
|
||||
}
|
||||
|
||||
@ -3179,6 +3264,16 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
result.AccountQuotaNotifyEmails = []NotifyEmailEntry{}
|
||||
}
|
||||
|
||||
// 系统层默认 platform quota(修复 Bug B:parseSettings 不填充导致回显恒为 nil)
|
||||
if raw := settings[SettingKeyDefaultPlatformQuotas]; raw != "" {
|
||||
parsed := map[string]*DefaultPlatformQuotaSetting{}
|
||||
if err := json.Unmarshal([]byte(raw), &parsed); err != nil {
|
||||
slog.Warn("[Setting] parseSettings: unmarshal default_platform_quotas failed", "error", err)
|
||||
} else {
|
||||
result.DefaultPlatformQuotas = parsed
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
@ -3271,6 +3366,15 @@ func parseProviderDefaultGrantSettings(settings map[string]string, keys authSour
|
||||
result.GrantOnFirstBind = raw == "true"
|
||||
}
|
||||
|
||||
if raw := settings[keys.platformQuotas]; raw != "" {
|
||||
parsed := map[string]*DefaultPlatformQuotaSetting{}
|
||||
if err := json.Unmarshal([]byte(raw), &parsed); err != nil {
|
||||
slog.Warn("[Setting] parseProviderDefaultGrantSettings: unmarshal auth source platform quotas failed", "source", keys.source, "error", err)
|
||||
} else {
|
||||
result.PlatformQuotas = parsed
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
@ -3289,6 +3393,17 @@ func writeProviderDefaultGrantUpdates(updates map[string]string, keys authSource
|
||||
updates[keys.subscriptions] = string(raw)
|
||||
updates[keys.grantOnSignup] = strconv.FormatBool(settings.GrantOnSignup)
|
||||
updates[keys.grantOnFirstBind] = strconv.FormatBool(settings.GrantOnFirstBind)
|
||||
|
||||
// auth source platform quota:整体替换语义。
|
||||
// nil = 请求未携带该字段,跳过写入以保留既有配置(与系统层 buildSystemSettingsUpdates 的
|
||||
// DefaultPlatformQuotas nil 守卫一致);非 nil(含空 map)才整体替换。二者语义不可混同。
|
||||
if keys.platformQuotas != "" && settings.PlatformQuotas != nil {
|
||||
blob, err := json.Marshal(settings.PlatformQuotas)
|
||||
if err != nil {
|
||||
blob = []byte("{}")
|
||||
}
|
||||
updates[keys.platformQuotas] = string(blob)
|
||||
}
|
||||
}
|
||||
|
||||
func mergeProviderDefaultGrantSettings(globalDefaults ProviderDefaultGrantSettings, providerDefaults ProviderDefaultGrantSettings) ProviderDefaultGrantSettings {
|
||||
@ -4493,3 +4608,63 @@ func (s *SettingService) SetStreamTimeoutSettings(ctx context.Context, settings
|
||||
|
||||
return s.settingRepo.Set(ctx, SettingKeyStreamTimeoutSettings, string(data))
|
||||
}
|
||||
|
||||
// GetDefaultPlatformQuotas 读取系统全局 platform quota JSON key,返回 4 platform x 3 window 的设置。
|
||||
// 永远返回包含全部 4 platform key 的 map(值可能为零值/nil 字段,表示"上层未配置 = 不限制")。
|
||||
//
|
||||
// 使用单个 JSON key(default_platform_quotas),一次 DB roundtrip,消除旧 12-KV 格式的 N+1 问题。
|
||||
// 容错语义:取值失败或 unmarshal 失败 → 返回补齐 4 key 的空 map(fail-open,注册不被阻断)。
|
||||
func (s *SettingService) GetDefaultPlatformQuotas(ctx context.Context) (map[string]*DefaultPlatformQuotaSetting, error) {
|
||||
out := map[string]*DefaultPlatformQuotaSetting{
|
||||
"anthropic": {},
|
||||
"openai": {},
|
||||
"gemini": {},
|
||||
"antigravity": {},
|
||||
}
|
||||
raw, err := s.settingRepo.GetValue(ctx, SettingKeyDefaultPlatformQuotas)
|
||||
if err != nil || raw == "" {
|
||||
return out, nil // 无配置 = 全部不限制
|
||||
}
|
||||
parsed := map[string]*DefaultPlatformQuotaSetting{}
|
||||
if err := json.Unmarshal([]byte(raw), &parsed); err != nil {
|
||||
slog.Warn("[Setting] unmarshal default_platform_quotas failed (fail-open)", "error", err)
|
||||
return out, nil
|
||||
}
|
||||
for _, platform := range AllowedQuotaPlatforms {
|
||||
if v := parsed[platform]; v != nil {
|
||||
out[platform] = v
|
||||
}
|
||||
}
|
||||
return out, nil // 补齐 4 platform key,保持与旧实现一致的下游契约
|
||||
}
|
||||
|
||||
// GetAuthSourcePlatformQuotas 读取指定 auth source 的 platform quota 覆盖(仅返回有配置的平台,override 语义)。
|
||||
func (s *SettingService) GetAuthSourcePlatformQuotas(ctx context.Context, source string) map[string]*DefaultPlatformQuotaSetting {
|
||||
out := map[string]*DefaultPlatformQuotaSetting{}
|
||||
raw, err := s.settingRepo.GetValue(ctx, SettingKeyAuthSourcePlatformQuotas(source))
|
||||
if err != nil || raw == "" {
|
||||
return out // 无 override
|
||||
}
|
||||
if err := json.Unmarshal([]byte(raw), &out); err != nil {
|
||||
slog.Warn("[Setting] unmarshal auth source platform quotas failed (fail-open)", "source", source, "error", err)
|
||||
return map[string]*DefaultPlatformQuotaSetting{}
|
||||
}
|
||||
return out // 仅含已配置平台,保持 override 语义
|
||||
}
|
||||
|
||||
// mergePlatformQuotaDefaults 按字段级 patch:src 中非 nil 字段覆盖 dst。
|
||||
// 区分 nil("未配置",保留 dst)vs &0.0("显式禁用",覆盖 dst 为 0)
|
||||
func mergePlatformQuotaDefaults(dst, src *DefaultPlatformQuotaSetting) {
|
||||
if src == nil || dst == nil {
|
||||
return
|
||||
}
|
||||
if src.DailyLimitUSD != nil {
|
||||
dst.DailyLimitUSD = src.DailyLimitUSD
|
||||
}
|
||||
if src.WeeklyLimitUSD != nil {
|
||||
dst.WeeklyLimitUSD = src.WeeklyLimitUSD
|
||||
}
|
||||
if src.MonthlyLimitUSD != nil {
|
||||
dst.MonthlyLimitUSD = src.MonthlyLimitUSD
|
||||
}
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user